Commit 8f29786

Merge pull request #1 from Azure/master

feat: ONNX model inference on Spark (microsoft#1152)

ms-kashyap authored Aug 11, 2021
2 parents a5135b2 + 448f893 commit 8f29786
Showing 36 changed files with 1,656 additions and 258 deletions.
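The merged feature (microsoft#1152) wires ONNX Runtime into the deep-learning module so that ONNX models can be scored on Spark DataFrames (note the new onnxruntime_gpu dependency in build.sbt below). The captured hunks do not include the new transformer itself, so the following is only a sketch of plausible usage: the class name ONNXModel, its package path, and every setter shown are assumptions based on the commit title and mmlspark naming conventions, not confirmed by this page.

```scala
// Hypothetical usage sketch; ONNXModel, its package path, and all setters
// below are assumed, since the transformer's source is not in the hunks shown.
import com.microsoft.ml.spark.onnx.ONNXModel
import org.apache.spark.sql.SparkSession

object OnnxScoringSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("onnx").getOrCreate()
  val df = spark.read.parquet("features.parquet") // assumed: column "features"

  val model = new ONNXModel()
    .setModelLocation("/models/classifier.onnx")        // assumed setter
    .setFeedDict(Map("input" -> "features"))            // ONNX input <- DataFrame column
    .setFetchDict(Map("probability" -> "sigmoid_out"))  // output column <- ONNX output
    .setMiniBatchSize(64)                               // assumed batching knob

  model.transform(df).select("probability").show()
}
```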
build.sbt (41 changes: 21 additions & 20 deletions)
@@ -8,8 +8,6 @@ import scala.xml.transform.{RewriteRule, RuleTransformer}
 import BuildUtils._
 import xerial.sbt.Sonatype._
 
-import java.nio.file.Files
-
 val condaEnvName = "mmlspark"
 val sparkVersion = "3.1.2"
 name := "mmlspark"
@@ -41,7 +39,7 @@ val dependencies = coreDependencies ++ extraDependencies
 
 def txt(e: Elem, label: String): String = "\"" + e.child.filter(_.label == label).flatMap(_.text).mkString + "\""
 
-val omittedDeps = Set(s"spark-core_${scalaMajorVersion}", s"spark-mllib_${scalaMajorVersion}", "org.scala-lang")
+val omittedDeps = Set(s"spark-core_$scalaMajorVersion", s"spark-mllib_$scalaMajorVersion", "org.scala-lang")
 // skip dependency elements with a scope
 
 def pomPostFunc(node: XmlNode): scala.xml.Node = {
@@ -67,7 +65,7 @@ pomPostProcess := pomPostFunc
 val speechResolver = "Speech" at "https://mmlspark.blob.core.windows.net/maven/"
 
 val getDatasetsTask = TaskKey[Unit]("getDatasets", "download datasets used for testing")
-val datasetName = "datasets-2020-08-27.tgz"
+val datasetName = "datasets-2021-07-27.tgz"
 val datasetUrl = new URL(s"https://mmlspark.blob.core.windows.net/installers/$datasetName")
 val datasetDir = settingKey[File]("The directory that holds the dataset")
 ThisBuild / datasetDir := {
@@ -196,7 +194,7 @@ ThisBuild / publishMavenStyle := true
 
 lazy val core = (project in file("core"))
   .enablePlugins(BuildInfoPlugin && SbtPlugin)
-  .settings((settings ++ Seq(
+  .settings(settings ++ Seq(
     libraryDependencies ++= dependencies,
     buildInfoKeys ++= Seq[BuildInfoKey](
      datasetDir,
@@ -207,48 +205,51 @@ lazy val core = (project in file("core"))
     ),
     name := "mmlspark-core",
     buildInfoPackage := "com.microsoft.ml.spark.build",
-  )): _*)
+  ): _*)
 
 lazy val deepLearning = (project in file("deep-learning"))
   .enablePlugins(SbtPlugin)
-  .dependsOn(core % "test->test;compile->compile")
-  .settings((settings ++ Seq(
-    libraryDependencies += ("com.microsoft.cntk" % "cntk" % "2.4"),
+  .dependsOn(core % "test->test;compile->compile", opencv % "test->test;compile->compile")
+  .settings(settings ++ Seq(
+    libraryDependencies ++= Seq(
+      "com.microsoft.cntk" % "cntk" % "2.4",
+      "com.microsoft.onnxruntime" % "onnxruntime_gpu" % "1.8.1"
+    ),
     name := "mmlspark-deep-learning",
-  )): _*)
+  ): _*)
 
 lazy val lightgbm = (project in file("lightgbm"))
   .enablePlugins(SbtPlugin)
   .dependsOn(core % "test->test;compile->compile")
-  .settings((settings ++ Seq(
+  .settings(settings ++ Seq(
     libraryDependencies += ("com.microsoft.ml.lightgbm" % "lightgbmlib" % "3.2.110"),
     name := "mmlspark-lightgbm"
-  )): _*)
+  ): _*)
 
 lazy val vw = (project in file("vw"))
   .enablePlugins(SbtPlugin)
   .dependsOn(core % "test->test;compile->compile")
-  .settings((settings ++ Seq(
+  .settings(settings ++ Seq(
     libraryDependencies += ("com.github.vowpalwabbit" % "vw-jni" % "8.9.1"),
     name := "mmlspark-vw"
-  )): _*)
+  ): _*)
 
 lazy val cognitive = (project in file("cognitive"))
   .enablePlugins(SbtPlugin)
   .dependsOn(core % "test->test;compile->compile")
-  .settings((settings ++ Seq(
+  .settings(settings ++ Seq(
     libraryDependencies += ("com.microsoft.cognitiveservices.speech" % "client-sdk" % "1.14.0"),
     resolvers += speechResolver,
     name := "mmlspark-cognitive"
-  )): _*)
+  ): _*)
 
 lazy val opencv = (project in file("opencv"))
   .enablePlugins(SbtPlugin)
   .dependsOn(core % "test->test;compile->compile")
-  .settings((settings ++ Seq(
+  .settings(settings ++ Seq(
     libraryDependencies += ("org.openpnp" % "opencv" % "3.2.0-1"),
     name := "mmlspark-opencv"
-  )): _*)
+  ): _*)
 
 lazy val root = (project in file("."))
   .aggregate(core, deepLearning, cognitive, vw, lightgbm, opencv)
@@ -297,15 +298,15 @@ pgpPassphrase := Some(Secrets.pgpPassword.toCharArray)
 pgpSecretRing := {
   val temp = File.createTempFile("secret", ".asc")
   new PrintWriter(temp) {
-    write(Secrets.pgpPrivate);
+    write(Secrets.pgpPrivate)
     close()
   }
   temp
 }
 pgpPublicRing := {
   val temp = File.createTempFile("public", ".asc")
   new PrintWriter(temp) {
-    write(Secrets.pgpPublic);
+    write(Secrets.pgpPublic)
     close()
   }
   temp
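A recurring cleanup in build.sbt above replaces `.settings((settings ++ Seq(...)): _*)` with `.settings(settings ++ Seq(...): _*)`: since `settings ++ Seq(...)` already evaluates to a `Seq[Setting[_]]`, the extra pair of parentheses was redundant, and `: _*` splats the sequence into the varargs parameter of `.settings`. A tiny self-contained illustration of that ascription (plain Scala, not taken from this build):

```scala
// Minimal illustration of the ": _*" varargs ascription used by .settings(...).
def printAll(items: String*): Unit = items.foreach(println)

val base  = Seq("a", "b")
val extra = Seq("c")

printAll(base ++ extra: _*) // the Seq is expanded into varargs: prints a, b, c
```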
@@ -6,6 +6,7 @@
 from pyspark.ml.common import _to_java_object_rdd, _java2py
 import pyspark
 from pyspark.ml import PipelineModel
+from pyspark.sql.types import DataType
 
 
 @staticmethod
@@ -21,7 +22,7 @@ def __get_class(clazz):
     """
     Loads Python class from its name.
     """
-    parts = clazz.split('.')
+    parts = clazz.split(".")
     module = ".".join(parts[:-1])
     m = __import__(module)
     for comp in parts[1:]:
@@ -41,8 +42,7 @@ def __get_class(clazz):
     elif hasattr(py_type, "_from_java"):
         py_stage = py_type._from_java(java_stage)
     else:
-        raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
-                                  % stage_name)
+        raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" % stage_name)
     return py_stage


@@ -68,6 +68,8 @@ def _mml_py2java(sc, obj):
         pass
     elif isinstance(obj, (int, float, bool, bytes, str)):
         pass
+    elif isinstance(obj, DataType):
+        obj = sc._jvm.org.apache.spark.sql.types.DataType.fromJson(obj.json())
     else:
         data = bytearray(PickleSerializer().dumps(obj))
         obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data)
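The new `DataType` branch above converts a Python-side Spark SQL type to its JVM counterpart by serializing it to JSON and re-parsing it with `DataType.fromJson`. The JVM half of that round trip looks like this (illustrative Scala, not part of the diff):

```scala
import org.apache.spark.sql.types._

val schema: DataType = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, nullable = true)))

// The same json/fromJson round trip the Python bridge relies on.
val restored: DataType = DataType.fromJson(schema.json)
assert(restored == schema)
```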
@@ -20,11 +20,10 @@ object ImageSchemaUtils {
       StructField("data", BinaryType, true) :: Nil)
   }
 
-  val ImageSchemaNullable = StructType(StructField("image", ColumnSchemaNullable, true) :: Nil)
+  val ImageSchemaNullable: StructType = StructType(StructField("image", ColumnSchemaNullable, nullable = true) :: Nil)
 
   def isImage(dataType: DataType): Boolean = {
-    dataType == ImageSchema.columnSchema ||
-      dataType == ColumnSchemaNullable
+    DataType.equalsStructurally(dataType, ImageSchema.columnSchema, ignoreNullability = true)
   }
 
   def isImage(dataType: StructField): Boolean = {
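The rewritten `isImage` stops comparing schemas with `==`, which distinguishes otherwise identical schemas that differ only in nullability flags, and instead uses Spark's `DataType.equalsStructurally` with `ignoreNullability = true`. A standalone illustration of the difference, using only stock Spark types:

```scala
import org.apache.spark.sql.types._

val strict  = StructType(Seq(StructField("data", BinaryType, nullable = false)))
val relaxed = StructType(Seq(StructField("data", BinaryType, nullable = true)))

println(strict == relaxed) // false: the nullability flags differ
println(DataType.equalsStructurally(strict, relaxed, ignoreNullability = true)) // true
```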
@@ -1,10 +1,10 @@
 // Copyright (C) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License. See LICENSE in project root for information.
 
-package com.microsoft.ml.spark.explainers
+package com.microsoft.ml.spark.core.utils
 
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, DenseMatrix => BDM}
-import org.apache.spark.ml.linalg.{Vector, Vectors, Matrix, Matrices}
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
+import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, SparseVector => BSV, Vector => BV}
 
 object BreezeUtils {
   implicit class SparkVectorCanConvertToBreeze(sv: Vector) {
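This hunk moves BreezeUtils from the explainers package into the shared com.microsoft.ml.spark.core.utils package so its Spark-to-Breeze conversions can be reused outside the explainers, which is what the import updates in the files that follow reflect. A usage sketch; only the implicit class header is visible above, so the `toBreeze` method name and its return type are assumptions:

```scala
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.ml.linalg.Vectors
import com.microsoft.ml.spark.core.utils.BreezeUtils._

val sv = Vectors.dense(1.0, 2.0, 3.0)
// Assumed conversion method exposed by SparkVectorCanConvertToBreeze.
val bv: BDV[Double] = sv.toBreeze
println(bv * 2.0) // Breeze operations now apply: DenseVector(2.0, 4.0, 6.0)
```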
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers
 import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV}
 import com.microsoft.ml.spark.codegen.Wrappable
 import com.microsoft.ml.spark.core.schema.DatasetExtensions
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 import com.microsoft.ml.spark.logging.BasicLogging
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
@@ -4,7 +4,7 @@
 package com.microsoft.ml.spark.explainers
 
 import breeze.linalg.{sum, DenseVector => BDV}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 import org.apache.commons.math3.util.CombinatoricsUtils.{binomialCoefficientDouble => comb}
 import org.apache.spark.ml.linalg.Vector
 
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers
 import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV}
 import com.microsoft.ml.spark.codegen.Wrappable
 import com.microsoft.ml.spark.core.schema.DatasetExtensions
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 import com.microsoft.ml.spark.logging.BasicLogging
 import org.apache.spark.injections.UDFUtils
 import org.apache.spark.ml.Transformer
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers
 import breeze.linalg.{BitVector, axpy, norm, DenseVector => BDV}
 import breeze.stats.distributions.RandBasis
 import org.apache.spark.ml.linalg.Vector
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 
 private[explainers] trait LIMESampler[TObservation] extends Sampler[TObservation, Vector] {
   def sample: (TObservation, Vector, Double) = {
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers
 
 import breeze.linalg.{norm, DenseVector => BDV}
 import breeze.stats.distributions.RandBasis
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 import com.microsoft.ml.spark.explainers.RowUtils.RowCanGetAsDouble
 import com.microsoft.ml.spark.lime.{Superpixel, SuperpixelData}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -152,6 +152,8 @@ private[explainers] class LIMETabularSampler(val instance: Row, val featureStats
         (instance.get(i), 1d)
       case (_: ContinuousFeatureStats, i) =>
         (instance.getAsDouble(i), instance.getAsDouble(i))
+      case (_, _) =>
+        throw new NotImplementedError("invalid state")
     }.unzip
 
     (Row.fromSeq(identityRow), Vectors.dense(identityState.toArray), 0d)
core/src/main/scala/com/microsoft/ml/spark/lime/LIME.scala (12 changes: 6 additions & 6 deletions)
@@ -164,10 +164,10 @@ trait LIMEBase extends LIMEParams with ComplexParamsWritable {
 
 }
 
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
 object TabularLIME extends ComplexParamsReadable[TabularLIME]
 
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
 class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel]
   with LIMEParams with Wrappable with ComplexParamsWritable with BasicLogging {
   logClass()
@@ -200,10 +200,10 @@ class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel]
   }
 }
 
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
 object TabularLIMEModel extends ComplexParamsReadable[TabularLIMEModel]
 
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
 class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel]
   with LIMEBase with Wrappable with BasicLogging {
   logClass()
@@ -256,15 +256,15 @@ class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel]
 
 }
 
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3")
 object ImageLIME extends ComplexParamsReadable[ImageLIME]
 
 /** Distributed implementation of
   * Local Interpretable Model-Agnostic Explanations (LIME)
   *
   * https://arxiv.org/pdf/1602.04938v1.pdf
   */
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3")
 class ImageLIME(val uid: String) extends Transformer with LIMEBase
   with Wrappable with HasModifier with HasCellSize with BasicLogging {
   logClass()
@@ -25,7 +25,7 @@ object TextLIME extends ComplexParamsReadable[TextLIME]
  *
  * https://arxiv.org/pdf/1602.04938v1.pdf
  */
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-rc3")
 class TextLIME(val uid: String) extends Model[TextLIME]
   with LIMEBase with Wrappable with BasicLogging {
   logClass()
@@ -8,7 +8,8 @@ import com.microsoft.ml.spark.logging.BasicLogging
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
@@ -30,16 +31,20 @@ trait MiniBatchBase extends Transformer with DefaultParamsWritable with Wrappable
 
   def transform(dataset: Dataset[_]): DataFrame = {
     logTransform[DataFrame]({
+      val outputSchema = transformSchema(dataset.schema)
+      implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema)
       dataset.toDF().mapPartitions { it =>
         if (it.isEmpty) {
           it
         } else {
-          getBatcher(it).map(listOfRows => Row.fromSeq(transpose(listOfRows.map(r => r.toSeq))))
+          getBatcher(it).map {
+            listOfRows =>
+              new GenericRowWithSchema(transpose(listOfRows.map(r => r.toSeq)).toArray, outputSchema)
+          }
         }
-      }(RowEncoder(transformSchema(dataset.schema)))
+      }
     })
   }
 
 }
 
 object DynamicMiniBatchTransformer extends DefaultParamsReadable[DynamicMiniBatchTransformer]
@@ -203,15 +208,20 @@ class FlattenBatch(val uid: String)
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     logTransform[DataFrame]({
+      val outputSchema = transformSchema(dataset.schema)
+      implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema)
+
       dataset.toDF().mapPartitions(it =>
         it.flatMap { rowOfLists =>
-          val transposed = transpose(
+          val transposed: Seq[Seq[Any]] = transpose(
             (0 until rowOfLists.length)
               .filterNot(rowOfLists.isNullAt)
               .map(rowOfLists.getSeq))
-          transposed.map(Row.fromSeq)
+          transposed.map {
+            values => new GenericRowWithSchema(values.toArray, outputSchema)
+          }
         }
-      )(RowEncoder(transformSchema(dataset.schema)))
+      )
     })
   }
 
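Both batching transformers now follow the same pattern: compute the output schema once, put an ExpressionEncoder for it in implicit scope, and emit GenericRowWithSchema instances instead of schemaless Row.fromSeq rows, so mapPartitions no longer needs the encoder passed explicitly at the call site. A reduced, self-contained sketch of that pattern outside any transformer (assumes a SparkSession named spark):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

val outputSchema = StructType(Seq(
  StructField("id", LongType), StructField("doubled", LongType)))

// With the encoder implicit, mapPartitions can return Rows for a known schema.
implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema)

val out = spark.range(4).toDF("id").mapPartitions { it =>
  it.map { r =>
    val id = r.getLong(0)
    // Upcast to Row so the implicit Encoder[Row] is selected.
    new GenericRowWithSchema(Array[Any](id, id * 2), outputSchema): Row
  }
}
out.show()
```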
@@ -238,12 +238,21 @@ abstract class TestBase extends FunSuite with BeforeAndAfterEachTestData
     }
   }
 
-  def breezeVectorEq[T: Field: ClassTag](tol: Double)(implicit normImpl: Impl[T, Double]): Equality[BDV[T]] =
+  def breezeVectorEq[T: Field](tol: Double)(implicit normImpl: Impl[T, Double]): Equality[BDV[T]] =
     (a: BDV[T], b: Any) => {
       b match {
         case p: BDV[T @unchecked] =>
           a.length == p.length && norm(a - p) < tol
         case _ => false
       }
     }
+
+  def mapEq[K, V: Equality]: Equality[Map[K, V]] = {
+    (a: Map[K, V], b: Any) => {
+      b match {
+        case m: Map[K @unchecked, V @unchecked] => a.keySet == m.keySet && a.keySet.forall(key => a(key) === m(key))
+        case _ => false
+      }
+    }
+  }
 }
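The new `mapEq` builds a Scalactic `Equality` for maps that checks the key sets for equality and compares values with whatever element `Equality` is in scope, so a tolerant `Double` comparison lifts to `Map[K, Double]`. A usage sketch that repeats the definition from the hunk above so it runs standalone (assumes Scalactic's TripleEquals and TolerantNumerics):

```scala
import org.scalactic.{Equality, TolerantNumerics}
import org.scalactic.TripleEquals._

// Same definition as in the hunk above, repeated for a self-contained example.
def mapEq[K, V: Equality]: Equality[Map[K, V]] =
  (a: Map[K, V], b: Any) => b match {
    case m: Map[K @unchecked, V @unchecked] =>
      a.keySet == m.keySet && a.keySet.forall(key => a(key) === m(key))
    case _ => false
  }

implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-6)
implicit val mapOfDoubleEq: Equality[Map[String, Double]] = mapEq[String, Double]

assert(Map("a" -> 1.0) === Map("a" -> 1.0 + 1e-9)) // equal within tolerance
```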
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers.split1
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
 import breeze.stats.distributions.RandBasis
 import breeze.stats.{mean, stddev}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 import com.microsoft.ml.spark.explainers._
 import com.microsoft.ml.spark.io.image.ImageUtils
 import com.microsoft.ml.spark.lime.{Superpixel, SuperpixelData}
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers.split1
 import breeze.linalg.{DenseVector => BDV}
 import com.microsoft.ml.spark.core.test.base.TestBase
 import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 import com.microsoft.ml.spark.explainers.{LocalExplainer, TabularLIME}
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers.split1
 
 import com.microsoft.ml.spark.core.test.base.TestBase
 import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
-import com.microsoft.ml.spark.explainers.BreezeUtils._
+import com.microsoft.ml.spark.core.utils.BreezeUtils._
 import com.microsoft.ml.spark.explainers.{LocalExplainer, TabularSHAP}
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
(The remaining changed files did not load in this capture of the page.)