feat: ONNX model inference on Spark #1152

Merged · 21 commits · Aug 9, 2021
Changes from all commits
41 changes: 21 additions & 20 deletions build.sbt
@@ -8,8 +8,6 @@ import scala.xml.transform.{RewriteRule, RuleTransformer}
import BuildUtils._
import xerial.sbt.Sonatype._

-import java.nio.file.Files

val condaEnvName = "mmlspark"
val sparkVersion = "3.1.2"
name := "mmlspark"
@@ -41,7 +39,7 @@ val dependencies = coreDependencies ++ extraDependencies

def txt(e: Elem, label: String): String = "\"" + e.child.filter(_.label == label).flatMap(_.text).mkString + "\""

-val omittedDeps = Set(s"spark-core_${scalaMajorVersion}", s"spark-mllib_${scalaMajorVersion}", "org.scala-lang")
+val omittedDeps = Set(s"spark-core_$scalaMajorVersion", s"spark-mllib_$scalaMajorVersion", "org.scala-lang")
// skip dependency elements with a scope

def pomPostFunc(node: XmlNode): scala.xml.Node = {
@@ -67,7 +65,7 @@ pomPostProcess := pomPostFunc
val speechResolver = "Speech" at "https://mmlspark.blob.core.windows.net/maven/"

val getDatasetsTask = TaskKey[Unit]("getDatasets", "download datasets used for testing")
-val datasetName = "datasets-2020-08-27.tgz"
+val datasetName = "datasets-2021-07-27.tgz"
val datasetUrl = new URL(s"https://mmlspark.blob.core.windows.net/installers/$datasetName")
val datasetDir = settingKey[File]("The directory that holds the dataset")
ThisBuild / datasetDir := {
@@ -196,7 +194,7 @@ ThisBuild / publishMavenStyle := true

lazy val core = (project in file("core"))
.enablePlugins(BuildInfoPlugin && SbtPlugin)
-.settings((settings ++ Seq(
+.settings(settings ++ Seq(
libraryDependencies ++= dependencies,
buildInfoKeys ++= Seq[BuildInfoKey](
datasetDir,
Expand All @@ -207,48 +205,51 @@ lazy val core = (project in file("core"))
),
name := "mmlspark-core",
buildInfoPackage := "com.microsoft.ml.spark.build",
-)): _*)
+): _*)
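A recurring cleanup in this file drops a redundant pair of parentheses around the settings sequence before the : _* varargs expansion. A standalone sketch (toy method, not from the build) of why the two forms are interchangeable:

def total(xs: Int*): Int = xs.sum // any varargs method
val base = Seq(1, 2)
// `: _*` expands the whole expression into the argument list;
// the extra parentheses around it change nothing.
assert(total(base ++ Seq(3): _*) == total((base ++ Seq(3)): _*))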

lazy val deepLearning = (project in file("deep-learning"))
.enablePlugins(SbtPlugin)
-.dependsOn(core % "test->test;compile->compile")
-.settings((settings ++ Seq(
-libraryDependencies += ("com.microsoft.cntk" % "cntk" % "2.4"),
+.dependsOn(core % "test->test;compile->compile", opencv % "test->test;compile->compile")
+.settings(settings ++ Seq(
+libraryDependencies ++= Seq(
+"com.microsoft.cntk" % "cntk" % "2.4",
+"com.microsoft.onnxruntime" % "onnxruntime_gpu" % "1.8.1"
+),
name := "mmlspark-deep-learning",
-)): _*)
+): _*)
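The new onnxruntime_gpu coordinate pulls in the ai.onnxruntime Java API that the deep-learning module builds on. A minimal sketch of batch scoring with that API (the model path is hypothetical; this shows only the raw runtime calls, not the PR's Spark transformer):

import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}

val modelPath = "/tmp/model.onnx" // hypothetical path, for illustration only
// OrtEnvironment is a process-wide singleton; a session can be reused across calls.
val env = OrtEnvironment.getEnvironment
val session: OrtSession = env.createSession(modelPath, new OrtSession.SessionOptions())

def score(batch: Array[Array[Float]]): Array[Array[Float]] = {
  val tensor = OnnxTensor.createTensor(env, batch) // shape [batchSize, nFeatures]
  val inputs = java.util.Collections.singletonMap(session.getInputNames.iterator.next, tensor)
  val result = session.run(inputs)
  try result.get(0).getValue.asInstanceOf[Array[Array[Float]]]
  finally { result.close(); tensor.close() }
}

On Spark, this pattern typically runs inside mapPartitions so each partition pays the session-construction cost once.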

lazy val lightgbm = (project in file("lightgbm"))
.enablePlugins(SbtPlugin)
.dependsOn(core % "test->test;compile->compile")
-.settings((settings ++ Seq(
+.settings(settings ++ Seq(
libraryDependencies += ("com.microsoft.ml.lightgbm" % "lightgbmlib" % "3.2.110"),
name := "mmlspark-lightgbm"
-)): _*)
+): _*)

lazy val vw = (project in file("vw"))
.enablePlugins(SbtPlugin)
.dependsOn(core % "test->test;compile->compile")
-.settings((settings ++ Seq(
+.settings(settings ++ Seq(
libraryDependencies += ("com.github.vowpalwabbit" % "vw-jni" % "8.9.1"),
name := "mmlspark-vw"
-)): _*)
+): _*)

lazy val cognitive = (project in file("cognitive"))
.enablePlugins(SbtPlugin)
.dependsOn(core % "test->test;compile->compile")
-.settings((settings ++ Seq(
+.settings(settings ++ Seq(
libraryDependencies += ("com.microsoft.cognitiveservices.speech" % "client-sdk" % "1.14.0"),
resolvers += speechResolver,
name := "mmlspark-cognitive"
-)): _*)
+): _*)

lazy val opencv = (project in file("opencv"))
.enablePlugins(SbtPlugin)
.dependsOn(core % "test->test;compile->compile")
-.settings((settings ++ Seq(
+.settings(settings ++ Seq(
libraryDependencies += ("org.openpnp" % "opencv" % "3.2.0-1"),
name := "mmlspark-opencv"
-)): _*)
+): _*)

lazy val root = (project in file("."))
.aggregate(core, deepLearning, cognitive, vw, lightgbm, opencv)
@@ -297,15 +298,15 @@ pgpPassphrase := Some(Secrets.pgpPassword.toCharArray)
pgpSecretRing := {
val temp = File.createTempFile("secret", ".asc")
new PrintWriter(temp) {
-write(Secrets.pgpPrivate);
+write(Secrets.pgpPrivate)
close()
}
temp
}
pgpPublicRing := {
val temp = File.createTempFile("public", ".asc")
new PrintWriter(temp) {
-write(Secrets.pgpPublic);
+write(Secrets.pgpPublic)
close()
}
temp
@@ -6,6 +6,7 @@
from pyspark.ml.common import _to_java_object_rdd, _java2py
import pyspark
from pyspark.ml import PipelineModel
+from pyspark.sql.types import DataType


@staticmethod
@@ -21,7 +22,7 @@ def __get_class(clazz):
"""
Loads Python class from its name.
"""
-parts = clazz.split('.')
+parts = clazz.split(".")
module = ".".join(parts[:-1])
m = __import__(module)
for comp in parts[1:]:
@@ -41,8 +42,7 @@ def __get_class(clazz):
elif hasattr(py_type, "_from_java"):
py_stage = py_type._from_java(java_stage)
else:
-raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
-% stage_name)
+raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" % stage_name)
return py_stage


@@ -68,6 +68,8 @@ def _mml_py2java(sc, obj):
pass
elif isinstance(obj, (int, float, bool, bytes, str)):
pass
+elif isinstance(obj, DataType):
+obj = sc._jvm.org.apache.spark.sql.types.DataType.fromJson(obj.json())
else:
data = bytearray(PickleSerializer().dumps(obj))
obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data)
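The new elif branch converts a PySpark DataType to its JVM counterpart by round-tripping through JSON. A sketch of the JVM side of that contract, in Scala:

import org.apache.spark.sql.types._

val schema: DataType = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("probabilities", ArrayType(FloatType))
))
// Every Spark DataType serializes to JSON and parses back to an equal value,
// which is what DataType.fromJson(obj.json()) relies on across the Py4J bridge.
assert(DataType.fromJson(schema.json) == schema)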
@@ -20,11 +20,10 @@ object ImageSchemaUtils {
StructField("data", BinaryType, true) :: Nil)
}

-val ImageSchemaNullable = StructType(StructField("image", ColumnSchemaNullable, true) :: Nil)
+val ImageSchemaNullable: StructType = StructType(StructField("image", ColumnSchemaNullable, nullable = true) :: Nil)

def isImage(dataType: DataType): Boolean = {
-dataType == ImageSchema.columnSchema ||
-dataType == ColumnSchemaNullable
+DataType.equalsStructurally(dataType, ImageSchema.columnSchema, ignoreNullability = true)
}

def isImage(dataType: StructField): Boolean = {
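The rewrite of isImage replaces two exact equality checks with one structural comparison. Plain equality on StructType is sensitive to nullability flags, which is exactly the case the old ColumnSchemaNullable branch existed to cover; a small sketch:

import org.apache.spark.sql.types._

val strict = StructType(Seq(StructField("data", BinaryType, nullable = false)))
val relaxed = StructType(Seq(StructField("data", BinaryType, nullable = true)))

assert(strict != relaxed) // == distinguishes nullability
assert(DataType.equalsStructurally(strict, relaxed, ignoreNullability = true))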
@@ -1,10 +1,10 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

-package com.microsoft.ml.spark.explainers
+package com.microsoft.ml.spark.core.utils

-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, DenseMatrix => BDM}
-import org.apache.spark.ml.linalg.{Vector, Vectors, Matrix, Matrices}
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
+import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, SparseVector => BSV, Vector => BV}

object BreezeUtils {
implicit class SparkVectorCanConvertToBreeze(sv: Vector) {
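BreezeUtils moves from the explainers package to core.utils so other modules can share it. In the dense case, the conversions such helpers wrap are plain array copies between the two vector types; a standalone sketch:

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.ml.linalg.{Vector, Vectors}

val sv: Vector = Vectors.dense(1.0, 2.0, 3.0)
val bv: BDV[Double] = BDV(sv.toArray) // Spark -> Breeze
val back: Vector = Vectors.dense(bv.toArray) // Breeze -> Spark
assert(back == sv)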
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers
import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV}
import com.microsoft.ml.spark.codegen.Wrappable
import com.microsoft.ml.spark.core.schema.DatasetExtensions
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
@@ -4,7 +4,7 @@
package com.microsoft.ml.spark.explainers

import breeze.linalg.{sum, DenseVector => BDV}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
import org.apache.commons.math3.util.CombinatoricsUtils.{binomialCoefficientDouble => comb}
import org.apache.spark.ml.linalg.Vector

@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers
import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV}
import com.microsoft.ml.spark.codegen.Wrappable
import com.microsoft.ml.spark.core.schema.DatasetExtensions
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.Transformer
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers
import breeze.linalg.{BitVector, axpy, norm, DenseVector => BDV}
import breeze.stats.distributions.RandBasis
import org.apache.spark.ml.linalg.Vector
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._

private[explainers] trait LIMESampler[TObservation] extends Sampler[TObservation, Vector] {
def sample: (TObservation, Vector, Double) = {
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers

import breeze.linalg.{norm, DenseVector => BDV}
import breeze.stats.distributions.RandBasis
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
import com.microsoft.ml.spark.explainers.RowUtils.RowCanGetAsDouble
import com.microsoft.ml.spark.lime.{Superpixel, SuperpixelData}
import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -152,6 +152,8 @@ private[explainers] class LIMETabularSampler(val instance: Row, val featureStats
(instance.get(i), 1d)
case (_: ContinuousFeatureStats, i) =>
(instance.getAsDouble(i), instance.getAsDouble(i))
+case (_, _) =>
+throw new NotImplementedError("invalid state")
}.unzip

(Row.fromSeq(identityRow), Vectors.dense(identityState.toArray), 0d)
12 changes: 6 additions & 6 deletions core/src/main/scala/com/microsoft/ml/spark/lime/LIME.scala
@@ -164,10 +164,10 @@ trait LIMEBase extends LIMEParams with ComplexParamsWritable {

}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
object TabularLIME extends ComplexParamsReadable[TabularLIME]

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel]
with LIMEParams with Wrappable with ComplexParamsWritable with BasicLogging {
logClass()
@@ -200,10 +200,10 @@ class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel]
}
}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
object TabularLIMEModel extends ComplexParamsReadable[TabularLIMEModel]

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel]
with LIMEBase with Wrappable with BasicLogging {
logClass()
@@ -256,15 +256,15 @@ class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel]

}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3")
object ImageLIME extends ComplexParamsReadable[ImageLIME]

/** Distributed implementation of
* Local Interpretable Model-Agnostic Explanations (LIME)
*
* https://arxiv.org/pdf/1602.04938v1.pdf
*/
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3")
class ImageLIME(val uid: String) extends Transformer with LIMEBase
with Wrappable with HasModifier with HasCellSize with BasicLogging {
logClass()
@@ -25,7 +25,7 @@ object TextLIME extends ComplexParamsReadable[TextLIME]
*
* https://arxiv.org/pdf/1602.04938v1.pdf
*/
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-rc3")
class TextLIME(val uid: String) extends Model[TextLIME]
with LIMEBase with Wrappable with BasicLogging {
logClass()
@@ -8,7 +8,8 @@ import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

@@ -30,16 +31,20 @@ trait MiniBatchBase extends Transformer with DefaultParamsWritable with Wrappable

def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
+val outputSchema = transformSchema(dataset.schema)
+implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema)
dataset.toDF().mapPartitions { it =>
if (it.isEmpty) {
it
} else {
-getBatcher(it).map(listOfRows => Row.fromSeq(transpose(listOfRows.map(r => r.toSeq))))
+getBatcher(it).map {
+listOfRows =>
+new GenericRowWithSchema(transpose(listOfRows.map(r => r.toSeq)).toArray, outputSchema)
+}
}
-}(RowEncoder(transformSchema(dataset.schema)))
+}
})
}

}

object DynamicMiniBatchTransformer extends DefaultParamsReadable[DynamicMiniBatchTransformer]
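Both minibatch transformers now build the output schema once, expose the row encoder implicitly, and emit GenericRowWithSchema rather than schemaless Row.fromSeq rows. A self-contained sketch of that encoder pattern (the toy schema and doubling logic are illustrative only):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

val spark = SparkSession.builder.master("local[2]").appName("encoder-demo").getOrCreate()
val outputSchema = StructType(Seq(StructField("doubled", LongType, nullable = false)))
// mapPartitions resolves this implicit Encoder[Row] instead of taking one explicitly.
implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema)

val out = spark.range(4).toDF("x").mapPartitions { it =>
  // Rows built with GenericRowWithSchema carry their schema with them.
  it.map(r => new GenericRowWithSchema(Array[Any](r.getAs[Long]("x") * 2), outputSchema): Row)
}
out.show()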
@@ -203,15 +208,20 @@ class FlattenBatch(val uid: String)

override def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
+val outputSchema = transformSchema(dataset.schema)
+implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema)

dataset.toDF().mapPartitions(it =>
it.flatMap { rowOfLists =>
-val transposed = transpose(
+val transposed: Seq[Seq[Any]] = transpose(
(0 until rowOfLists.length)
.filterNot(rowOfLists.isNullAt)
.map(rowOfLists.getSeq))
-transposed.map(Row.fromSeq)
+transposed.map {
+values => new GenericRowWithSchema(values.toArray, outputSchema)
+}
}
-)(RowEncoder(transformSchema(dataset.schema)))
+)
})
}

@@ -238,12 +238,21 @@ abstract class TestBase extends FunSuite with BeforeAndAfterEachTestData with Be
}
}

-def breezeVectorEq[T: Field: ClassTag](tol: Double)(implicit normImpl: Impl[T, Double]): Equality[BDV[T]] =
+def breezeVectorEq[T: Field](tol: Double)(implicit normImpl: Impl[T, Double]): Equality[BDV[T]] =
(a: BDV[T], b: Any) => {
b match {
case p: BDV[T @unchecked] =>
a.length == p.length && norm(a - p) < tol
case _ => false
}
}

+def mapEq[K, V: Equality]: Equality[Map[K, V]] = {
+(a: Map[K, V], b: Any) => {
+b match {
+case m: Map[K @unchecked, V @unchecked] => a.keySet == m.keySet && a.keySet.forall(key => a(key) === m(key))
+case _ => false
+}
+}
+}
}
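Within a TestBase suite, both helpers plug into scalatest's === through implicit Equality instances. A sketch, assuming scalactic's TolerantNumerics supplies the per-value tolerance for the map's values:

import breeze.linalg.{DenseVector => BDV}
import org.scalactic.{Equality, TolerantNumerics}

implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-6)
implicit val vectorEq: Equality[BDV[Double]] = breezeVectorEq[Double](1e-6)
implicit val metricsEq: Equality[Map[String, Double]] = mapEq[String, Double]

assert(BDV(1.0, 2.0) === BDV(1.0 + 1e-9, 2.0)) // within tolerance
assert(Map("auc" -> 0.91) === Map("auc" -> 0.91 + 1e-9))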
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers.split1
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
import breeze.stats.distributions.RandBasis
import breeze.stats.{mean, stddev}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
import com.microsoft.ml.spark.explainers._
import com.microsoft.ml.spark.io.image.ImageUtils
import com.microsoft.ml.spark.lime.{Superpixel, SuperpixelData}
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers.split1
import breeze.linalg.{DenseVector => BDV}
import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
import com.microsoft.ml.spark.explainers.{LocalExplainer, TabularLIME}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers.split1

import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
import com.microsoft.ml.spark.explainers.{LocalExplainer, TabularSHAP}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}