diff --git a/build.sbt b/build.sbt index e3b855035e..0cb2462deb 100644 --- a/build.sbt +++ b/build.sbt @@ -8,8 +8,6 @@ import scala.xml.transform.{RewriteRule, RuleTransformer} import BuildUtils._ import xerial.sbt.Sonatype._ -import java.nio.file.Files - val condaEnvName = "mmlspark" val sparkVersion = "3.1.2" name := "mmlspark" @@ -41,7 +39,7 @@ val dependencies = coreDependencies ++ extraDependencies def txt(e: Elem, label: String): String = "\"" + e.child.filter(_.label == label).flatMap(_.text).mkString + "\"" -val omittedDeps = Set(s"spark-core_${scalaMajorVersion}", s"spark-mllib_${scalaMajorVersion}", "org.scala-lang") +val omittedDeps = Set(s"spark-core_$scalaMajorVersion", s"spark-mllib_$scalaMajorVersion", "org.scala-lang") // skip dependency elements with a scope def pomPostFunc(node: XmlNode): scala.xml.Node = { @@ -67,7 +65,7 @@ pomPostProcess := pomPostFunc val speechResolver = "Speech" at "https://mmlspark.blob.core.windows.net/maven/" val getDatasetsTask = TaskKey[Unit]("getDatasets", "download datasets used for testing") -val datasetName = "datasets-2020-08-27.tgz" +val datasetName = "datasets-2021-07-27.tgz" val datasetUrl = new URL(s"https://mmlspark.blob.core.windows.net/installers/$datasetName") val datasetDir = settingKey[File]("The directory that holds the dataset") ThisBuild / datasetDir := { @@ -196,7 +194,7 @@ ThisBuild / publishMavenStyle := true lazy val core = (project in file("core")) .enablePlugins(BuildInfoPlugin && SbtPlugin) - .settings((settings ++ Seq( + .settings(settings ++ Seq( libraryDependencies ++= dependencies, buildInfoKeys ++= Seq[BuildInfoKey]( datasetDir, @@ -207,48 +205,51 @@ lazy val core = (project in file("core")) ), name := "mmlspark-core", buildInfoPackage := "com.microsoft.ml.spark.build", - )): _*) + ): _*) lazy val deepLearning = (project in file("deep-learning")) .enablePlugins(SbtPlugin) - .dependsOn(core % "test->test;compile->compile") - .settings((settings ++ Seq( - libraryDependencies += ("com.microsoft.cntk" % "cntk" % "2.4"), + .dependsOn(core % "test->test;compile->compile", opencv % "test->test;compile->compile") + .settings(settings ++ Seq( + libraryDependencies ++= Seq( + "com.microsoft.cntk" % "cntk" % "2.4", + "com.microsoft.onnxruntime" % "onnxruntime_gpu" % "1.8.1" + ), name := "mmlspark-deep-learning", - )): _*) + ): _*) lazy val lightgbm = (project in file("lightgbm")) .enablePlugins(SbtPlugin) .dependsOn(core % "test->test;compile->compile") - .settings((settings ++ Seq( + .settings(settings ++ Seq( libraryDependencies += ("com.microsoft.ml.lightgbm" % "lightgbmlib" % "3.2.110"), name := "mmlspark-lightgbm" - )): _*) + ): _*) lazy val vw = (project in file("vw")) .enablePlugins(SbtPlugin) .dependsOn(core % "test->test;compile->compile") - .settings((settings ++ Seq( + .settings(settings ++ Seq( libraryDependencies += ("com.github.vowpalwabbit" % "vw-jni" % "8.9.1"), name := "mmlspark-vw" - )): _*) + ): _*) lazy val cognitive = (project in file("cognitive")) .enablePlugins(SbtPlugin) .dependsOn(core % "test->test;compile->compile") - .settings((settings ++ Seq( + .settings(settings ++ Seq( libraryDependencies += ("com.microsoft.cognitiveservices.speech" % "client-sdk" % "1.14.0"), resolvers += speechResolver, name := "mmlspark-cognitive" - )): _*) + ): _*) lazy val opencv = (project in file("opencv")) .enablePlugins(SbtPlugin) .dependsOn(core % "test->test;compile->compile") - .settings((settings ++ Seq( + .settings(settings ++ Seq( libraryDependencies += ("org.openpnp" % "opencv" % "3.2.0-1"), name := 
"mmlspark-opencv" - )): _*) + ): _*) lazy val root = (project in file(".")) .aggregate(core, deepLearning, cognitive, vw, lightgbm, opencv) @@ -297,7 +298,7 @@ pgpPassphrase := Some(Secrets.pgpPassword.toCharArray) pgpSecretRing := { val temp = File.createTempFile("secret", ".asc") new PrintWriter(temp) { - write(Secrets.pgpPrivate); + write(Secrets.pgpPrivate) close() } temp @@ -305,7 +306,7 @@ pgpSecretRing := { pgpPublicRing := { val temp = File.createTempFile("public", ".asc") new PrintWriter(temp) { - write(Secrets.pgpPublic); + write(Secrets.pgpPublic) close() } temp diff --git a/core/src/main/python/mmlspark/core/serialize/java_params_patch.py b/core/src/main/python/mmlspark/core/serialize/java_params_patch.py index aa1590b221..0bd07d4751 100644 --- a/core/src/main/python/mmlspark/core/serialize/java_params_patch.py +++ b/core/src/main/python/mmlspark/core/serialize/java_params_patch.py @@ -6,6 +6,7 @@ from pyspark.ml.common import _to_java_object_rdd, _java2py import pyspark from pyspark.ml import PipelineModel +from pyspark.sql.types import DataType @staticmethod @@ -21,7 +22,7 @@ def __get_class(clazz): """ Loads Python class from its name. """ - parts = clazz.split('.') + parts = clazz.split(".") module = ".".join(parts[:-1]) m = __import__(module) for comp in parts[1:]: @@ -41,8 +42,7 @@ def __get_class(clazz): elif hasattr(py_type, "_from_java"): py_stage = py_type._from_java(java_stage) else: - raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" - % stage_name) + raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" % stage_name) return py_stage @@ -68,6 +68,8 @@ def _mml_py2java(sc, obj): pass elif isinstance(obj, (int, float, bool, bytes, str)): pass + elif isinstance(obj, DataType): + obj = sc._jvm.org.apache.spark.sql.types.DataType.fromJson(obj.json()) else: data = bytearray(PickleSerializer().dumps(obj)) obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data) diff --git a/core/src/main/scala/com/microsoft/ml/spark/core/schema/ImageSchemaUtils.scala b/core/src/main/scala/com/microsoft/ml/spark/core/schema/ImageSchemaUtils.scala index b940b82af6..3139f97e5f 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/core/schema/ImageSchemaUtils.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/core/schema/ImageSchemaUtils.scala @@ -20,11 +20,10 @@ object ImageSchemaUtils { StructField("data", BinaryType, true) :: Nil) } - val ImageSchemaNullable = StructType(StructField("image", ColumnSchemaNullable, true) :: Nil) + val ImageSchemaNullable: StructType = StructType(StructField("image", ColumnSchemaNullable, nullable = true) :: Nil) def isImage(dataType: DataType): Boolean = { - dataType == ImageSchema.columnSchema || - dataType == ColumnSchemaNullable + DataType.equalsStructurally(dataType, ImageSchema.columnSchema, ignoreNullability = true) } def isImage(dataType: StructField): Boolean = { diff --git a/core/src/main/scala/com/microsoft/ml/spark/explainers/BreezeUtils.scala b/core/src/main/scala/com/microsoft/ml/spark/core/utils/BreezeUtils.scala similarity index 80% rename from core/src/main/scala/com/microsoft/ml/spark/explainers/BreezeUtils.scala rename to core/src/main/scala/com/microsoft/ml/spark/core/utils/BreezeUtils.scala index b77469e097..ca68b2c08a 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/explainers/BreezeUtils.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/core/utils/BreezeUtils.scala @@ -1,10 +1,10 @@ // Copyright (C) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. See LICENSE in project root for information. -package com.microsoft.ml.spark.explainers +package com.microsoft.ml.spark.core.utils -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, DenseMatrix => BDM} -import org.apache.spark.ml.linalg.{Vector, Vectors, Matrix, Matrices} +import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors} +import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, SparseVector => BSV, Vector => BV} object BreezeUtils { implicit class SparkVectorCanConvertToBreeze(sv: Vector) { diff --git a/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPBase.scala b/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPBase.scala index d6d25d1bb5..19838ad00e 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPBase.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPBase.scala @@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV} import com.microsoft.ml.spark.codegen.Wrappable import com.microsoft.ml.spark.core.schema.DatasetExtensions -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.ml.Transformer import org.apache.spark.ml.linalg.SQLDataTypes.VectorType diff --git a/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPSampler.scala b/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPSampler.scala index 8f52e2cece..412682fe52 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPSampler.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/explainers/KernelSHAPSampler.scala @@ -4,7 +4,7 @@ package com.microsoft.ml.spark.explainers import breeze.linalg.{sum, DenseVector => BDV} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import org.apache.commons.math3.util.CombinatoricsUtils.{binomialCoefficientDouble => comb} import org.apache.spark.ml.linalg.Vector diff --git a/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMEBase.scala b/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMEBase.scala index c21fbcc4ea..d2ca2f1bbd 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMEBase.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMEBase.scala @@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV} import com.microsoft.ml.spark.codegen.Wrappable import com.microsoft.ml.spark.core.schema.DatasetExtensions -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.injections.UDFUtils import org.apache.spark.ml.Transformer diff --git a/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMESampler.scala b/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMESampler.scala index f6c98022d3..07ff855a93 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMESampler.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/explainers/LIMESampler.scala @@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers import breeze.linalg.{BitVector, axpy, norm, DenseVector => BDV} import breeze.stats.distributions.RandBasis import org.apache.spark.ml.linalg.Vector -import 
com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ private[explainers] trait LIMESampler[TObservation] extends Sampler[TObservation, Vector] { def sample: (TObservation, Vector, Double) = { diff --git a/core/src/main/scala/com/microsoft/ml/spark/explainers/Sampler.scala b/core/src/main/scala/com/microsoft/ml/spark/explainers/Sampler.scala index 9fc2fc581b..25b51b6867 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/explainers/Sampler.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/explainers/Sampler.scala @@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers import breeze.linalg.{norm, DenseVector => BDV} import breeze.stats.distributions.RandBasis -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.RowUtils.RowCanGetAsDouble import com.microsoft.ml.spark.lime.{Superpixel, SuperpixelData} import org.apache.spark.ml.linalg.{Vector, Vectors} @@ -152,6 +152,8 @@ private[explainers] class LIMETabularSampler(val instance: Row, val featureStats (instance.get(i), 1d) case (_: ContinuousFeatureStats, i) => (instance.getAsDouble(i), instance.getAsDouble(i)) + case (_, _) => + throw new NotImplementedError("invalid state") }.unzip (Row.fromSeq(identityRow), Vectors.dense(identityState.toArray), 0d) diff --git a/core/src/main/scala/com/microsoft/ml/spark/lime/LIME.scala b/core/src/main/scala/com/microsoft/ml/spark/lime/LIME.scala index 6924537271..03557d15e1 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/lime/LIME.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/lime/LIME.scala @@ -164,10 +164,10 @@ trait LIMEBase extends LIMEParams with ComplexParamsWritable { } -@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3") +@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3") object TabularLIME extends ComplexParamsReadable[TabularLIME] -@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3") +@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3") class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel] with LIMEParams with Wrappable with ComplexParamsWritable with BasicLogging { logClass() @@ -200,10 +200,10 @@ class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel] } } -@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3") +@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3") object TabularLIMEModel extends ComplexParamsReadable[TabularLIMEModel] -@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3") +@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3") class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel] with LIMEBase with Wrappable with BasicLogging { logClass() @@ -256,7 +256,7 @@ class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel] } -@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3") +@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3") object ImageLIME extends ComplexParamsReadable[ImageLIME] /** Distributed implementation of @@ -264,7 +264,7 @@ object ImageLIME extends ComplexParamsReadable[ImageLIME] * * https://arxiv.org/pdf/1602.04938v1.pdf */ 
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3") +@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3") class ImageLIME(val uid: String) extends Transformer with LIMEBase with Wrappable with HasModifier with HasCellSize with BasicLogging { logClass() diff --git a/core/src/main/scala/com/microsoft/ml/spark/lime/TextLIME.scala b/core/src/main/scala/com/microsoft/ml/spark/lime/TextLIME.scala index bca480c744..21b010632d 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/lime/TextLIME.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/lime/TextLIME.scala @@ -25,7 +25,7 @@ object TextLIME extends ComplexParamsReadable[TextLIME] * * https://arxiv.org/pdf/1602.04938v1.pdf */ -@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-RC3") +@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-rc3") class TextLIME(val uid: String) extends Model[TextLIME] with LIMEBase with Wrappable with BasicLogging { logClass() diff --git a/core/src/main/scala/com/microsoft/ml/spark/stages/MiniBatchTransformer.scala b/core/src/main/scala/com/microsoft/ml/spark/stages/MiniBatchTransformer.scala index f06f704acb..2089827aeb 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/stages/MiniBatchTransformer.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/stages/MiniBatchTransformer.scala @@ -8,7 +8,8 @@ import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.ml.Transformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -30,16 +31,20 @@ trait MiniBatchBase extends Transformer with DefaultParamsWritable with Wrappabl def transform(dataset: Dataset[_]): DataFrame = { logTransform[DataFrame]({ + val outputSchema = transformSchema(dataset.schema) + implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema) dataset.toDF().mapPartitions { it => if (it.isEmpty) { it } else { - getBatcher(it).map(listOfRows => Row.fromSeq(transpose(listOfRows.map(r => r.toSeq)))) + getBatcher(it).map { + listOfRows => + new GenericRowWithSchema(transpose(listOfRows.map(r => r.toSeq)).toArray, outputSchema) + } } - }(RowEncoder(transformSchema(dataset.schema))) + } }) } - } object DynamicMiniBatchTransformer extends DefaultParamsReadable[DynamicMiniBatchTransformer] @@ -203,15 +208,20 @@ class FlattenBatch(val uid: String) override def transform(dataset: Dataset[_]): DataFrame = { logTransform[DataFrame]({ + val outputSchema = transformSchema(dataset.schema) + implicit val outputEncoder: ExpressionEncoder[Row] = RowEncoder(outputSchema) + dataset.toDF().mapPartitions(it => it.flatMap { rowOfLists => - val transposed = transpose( + val transposed: Seq[Seq[Any]] = transpose( (0 until rowOfLists.length) .filterNot(rowOfLists.isNullAt) .map(rowOfLists.getSeq)) - transposed.map(Row.fromSeq) + transposed.map { + values => new GenericRowWithSchema(values.toArray, outputSchema) + } } - )(RowEncoder(transformSchema(dataset.schema))) + ) }) } diff --git a/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala 
b/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala index 83aa9ee72d..21efc75d3f 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/core/test/base/TestBase.scala @@ -238,7 +238,7 @@ abstract class TestBase extends FunSuite with BeforeAndAfterEachTestData with Be } } - def breezeVectorEq[T: Field: ClassTag](tol: Double)(implicit normImpl: Impl[T, Double]): Equality[BDV[T]] = + def breezeVectorEq[T: Field](tol: Double)(implicit normImpl: Impl[T, Double]): Equality[BDV[T]] = (a: BDV[T], b: Any) => { b match { case p: BDV[T @unchecked] => @@ -246,4 +246,13 @@ abstract class TestBase extends FunSuite with BeforeAndAfterEachTestData with Be case _ => false } } + + def mapEq[K, V: Equality]: Equality[Map[K, V]] = { + (a: Map[K, V], b: Any) => { + b match { + case m: Map[K @unchecked, V @unchecked] => a.keySet == m.keySet && a.keySet.forall(key => a(key) === m(key)) + case _ => false + } + } + } } diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala index 0c4ea711ed..1716380a2f 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/SamplerSuite.scala @@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers.split1 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import breeze.stats.distributions.RandBasis import breeze.stats.{mean, stddev} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers._ import com.microsoft.ml.spark.io.image.ImageUtils import com.microsoft.ml.spark.lime.{Superpixel, SuperpixelData} diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularLIMEExplainerSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularLIMEExplainerSuite.scala index 7eb1f24b64..76b2181399 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularLIMEExplainerSuite.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularLIMEExplainerSuite.scala @@ -6,7 +6,7 @@ package com.microsoft.ml.spark.explainers.split1 import breeze.linalg.{DenseVector => BDV} import com.microsoft.ml.spark.core.test.base.TestBase import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.{LocalExplainer, TabularLIME} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler} diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularSHAPExplainerSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularSHAPExplainerSuite.scala index 994ac339c9..00ae5d11ba 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularSHAPExplainerSuite.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TabularSHAPExplainerSuite.scala @@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers.split1 import com.microsoft.ml.spark.core.test.base.TestBase import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} -import 
com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.{LocalExplainer, TabularSHAP} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler} diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TextExplainersSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TextExplainersSuite.scala index 9e843b87e2..4906b7ec54 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TextExplainersSuite.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/TextExplainersSuite.scala @@ -5,7 +5,7 @@ package com.microsoft.ml.spark.explainers.split1 import com.microsoft.ml.spark.core.test.base.TestBase import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.{LocalExplainer, TextLIME, TextSHAP} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorLIMEExplainerSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorLIMEExplainerSuite.scala index f0b0de4719..a47655ec4c 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorLIMEExplainerSuite.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorLIMEExplainerSuite.scala @@ -7,7 +7,7 @@ import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV} import breeze.stats.distributions.Rand import com.microsoft.ml.spark.core.test.base.{Flaky, TestBase} import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.{LocalExplainer, VectorLIME} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} diff --git a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorSHAPExplainerSuite.scala b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorSHAPExplainerSuite.scala index 910496dc38..9c1bc6a066 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorSHAPExplainerSuite.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/explainers/split1/VectorSHAPExplainerSuite.scala @@ -7,7 +7,7 @@ import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV} import breeze.stats.distributions.RandBasis import com.microsoft.ml.spark.core.test.base.TestBase import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.{LocalExplainer, VectorSHAP} import com.microsoft.ml.spark.stages.UDFTransformer import org.apache.spark.injections.UDFUtils diff --git a/core/src/test/scala/com/microsoft/ml/spark/lime/LIMESuite.scala b/core/src/test/scala/com/microsoft/ml/spark/lime/LIMESuite.scala index b58e597944..deb87ee43c 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/lime/LIMESuite.scala +++ 
b/core/src/test/scala/com/microsoft/ml/spark/lime/LIMESuite.scala @@ -12,6 +12,7 @@ import org.apache.spark.ml.param.DataFrameEquality import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.MLReadable +@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3") trait LimeTestBase extends TestBase { import spark.implicits._ @@ -41,6 +42,7 @@ trait LimeTestBase extends TestBase { lazy val limeModel = lime.fit(df) } +@deprecated("Please use 'com.microsoft.ml.spark.explainers.TabularLIME'.", since="1.0.0-rc3") class TabularLIMESuite extends EstimatorFuzzing[TabularLIME] with DataFrameEquality with LimeTestBase { @@ -57,6 +59,7 @@ class TabularLIMESuite extends EstimatorFuzzing[TabularLIME] with override def modelReader: MLReadable[_] = TabularLIMEModel } +@deprecated("Please use 'com.microsoft.ml.spark.explainers.TabularLIME'.", since="1.0.0-rc3") class TabularLIMEModelSuite extends TransformerFuzzing[TabularLIMEModel] with DataFrameEquality with LimeTestBase { diff --git a/core/src/test/scala/com/microsoft/ml/spark/lime/TextLIMESuite.scala b/core/src/test/scala/com/microsoft/ml/spark/lime/TextLIMESuite.scala index 2736196f3c..1151637b13 100644 --- a/core/src/test/scala/com/microsoft/ml/spark/lime/TextLIMESuite.scala +++ b/core/src/test/scala/com/microsoft/ml/spark/lime/TextLIMESuite.scala @@ -17,6 +17,7 @@ import org.apache.spark.sql.types.DoubleType import org.scalactic.Equality import org.scalatest.Assertion +@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-rc3") class TextLIMESuite extends TransformerFuzzing[TextLIME] { import spark.implicits._ diff --git a/deep-learning/src/main/python/mmlspark/onnx/ONNXModel.py b/deep-learning/src/main/python/mmlspark/onnx/ONNXModel.py new file mode 100644 index 0000000000..90e7634f8e --- /dev/null +++ b/deep-learning/src/main/python/mmlspark/onnx/ONNXModel.py @@ -0,0 +1,29 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +import sys + +if sys.version >= "3": + basestring = str + +from mmlspark.onnx._ONNXModel import _ONNXModel +from pyspark.ml.common import inherit_doc + + +@inherit_doc +class ONNXModel(_ONNXModel): + """ + + Args: + SparkSession (SparkSession): The SparkSession that will be used to find the model + location (str): The location of the model, either on the local file system or HDFS + """ + + def setModelLocation(self, location): + self._java_obj = self._java_obj.setModelLocation(location) + return self + + def setMiniBatchSize(self, n): + self._java_obj = self._java_obj.setMiniBatchSize(n) + return self + diff --git a/deep-learning/src/main/python/mmlspark/onnx/__init__.py b/deep-learning/src/main/python/mmlspark/onnx/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deep-learning/src/main/scala/com/microsoft/ml/spark/SharedParams.scala b/deep-learning/src/main/scala/com/microsoft/ml/spark/SharedParams.scala new file mode 100644 index 0000000000..1b43671743 --- /dev/null +++ b/deep-learning/src/main/scala/com/microsoft/ml/spark/SharedParams.scala @@ -0,0 +1,33 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.ml.spark + +import org.apache.spark.ml.param.{MapParam, Param, Params} +import spray.json.DefaultJsonProtocol._ + +trait HasFeedFetchDicts extends Params { + val feedDict: MapParam[String, String] = new MapParam[String, String]( + this, + "feedDict", + " Provide a map from CNTK/ONNX model input variable names (keys) to column names of the input dataframe (values)" + ) + + def setFeedDict(value: Map[String, String]): this.type = set(feedDict, value) + + def setFeedDict(k: String, v: String): this.type = set(feedDict, Map(k -> v)) + + def getFeedDict: Map[String, String] = $(feedDict) + + val fetchDict: MapParam[String, String] = new MapParam[String, String]( + this, + "fetchDict", + "Provide a map from column names of the output dataframe (keys) to CNTK/ONNX model output variable names (values)" + ) + + def setFetchDict(value: Map[String, String]): this.type = set("fetchDict", value) + + def setFetchDict(k: String, v: String): this.type = set(fetchDict, Map(k -> v)) + + def getFetchDict: Map[String, String] = $(fetchDict) +} diff --git a/deep-learning/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala b/deep-learning/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala index fadba325de..c4fc7040ea 100644 --- a/deep-learning/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala +++ b/deep-learning/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala @@ -6,6 +6,7 @@ package com.microsoft.ml.spark.cntk import com.microsoft.CNTK.CNTKExtensions._ import com.microsoft.CNTK.CNTKUtils._ import com.microsoft.CNTK.{CNTKExtensions, DataType => CNTKDataType, SerializableFunction => CNTKFunction, _} +import com.microsoft.ml.spark.HasFeedFetchDicts import com.microsoft.ml.spark.cntk.ConversionUtils.GVV import com.microsoft.ml.spark.codegen.Wrappable import com.microsoft.ml.spark.core.schema.DatasetExtensions.findUnusedColumnName @@ -144,7 +145,7 @@ private object CNTKModelUtils extends java.io.Serializable { object CNTKModel extends ComplexParamsReadable[CNTKModel] class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexParamsWritable - with HasMiniBatcher with Wrappable with BasicLogging { + with HasMiniBatcher with HasFeedFetchDicts with Wrappable with BasicLogging { logClass() override protected lazy val pyInternalWrapper = true @@ -178,24 +179,6 @@ class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexP setModel(CNTKFunction.loadModelFromBytes(modelBytes)) } - val batchInput = new BooleanParam(this, "batchInput", - "whether to use a batcher") - - setDefault(batchInput -> true) - - def setBatchInput(v: Boolean): this.type = set(batchInput, v) - - def getBatchInput: Boolean = $(batchInput) - - val shapeOutput = new BooleanParam(this, "shapeOutput", - "whether to shape the output") - - setDefault(shapeOutput -> false) - - def setShapeOutput(v: Boolean): this.type = set(shapeOutput, v) - - def getShapeOutput: Boolean = $(shapeOutput) - val convertOutputToDenseVector = new BooleanParam(this, "convertOutputToDenseVector", "whether to convert the output to dense vectors") @@ -205,26 +188,10 @@ class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexP def getConvertOutputToDenseVector: Boolean = $(convertOutputToDenseVector) - val feedDict: MapParam[String, String] = new MapParam[String, String](this, "feedDict", - " Map of CNTK Variable names (keys) and Column Names (values)") - - setDefault(feedDict -> Map((ArgumentPrefix + 0) -> (ArgumentPrefix + 0))) - - def setFeedDict(value: 
Map[String, String]): this.type = set(feedDict, value) - - def setFeedDict(k: String, v: String): this.type = set(feedDict, Map(k -> v)) - - def getFeedDict: Map[String, String] = $(feedDict) - - val fetchDict: MapParam[String, String] = new MapParam[String, String](this, "fetchDict", - " Map of Column Names (keys) and CNTK Variable names (values)") - setDefault(fetchDict -> Map((OutputPrefix + 0) -> (OutputPrefix + 0))) - - def setFetchDict(value: Map[String, String]): this.type = set("fetchDict", value) - - def setFetchDict(k: String, v: String): this.type = set(fetchDict, Map(k -> v)) - - def getFetchDict: Map[String, String] = $(fetchDict) + setDefault( + feedDict -> Map((ArgumentPrefix + 0) -> (ArgumentPrefix + 0)), + fetchDict -> Map((OutputPrefix + 0) -> (OutputPrefix + 0)) + ) // Alternative Input APIs @@ -370,20 +337,23 @@ class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexP } } + val batchInput = new BooleanParam(this, "batchInput", + "whether to use a batcher") + + def setBatchInput(v: Boolean): this.type = set(batchInput, v) + + def getBatchInput: Boolean = $(batchInput) + + setDefault( + batchInput -> true, + miniBatcher -> new FixedMiniBatchTransformer().setBatchSize(10) //scalastyle:ignore magic.number + ) + /** Returns the dimensions of the required input */ def getInputShapes: List[Array[Int]] = { getModel.getArguments.asScala.map(_.getShape.getDimensions.map(_.toInt)).toList } - setDefault(miniBatcher -> new FixedMiniBatchTransformer().setBatchSize(10)) //scalastyle:ignore magic.number - - private def getElementType(t: DataType): DataType = { - t match { - case ArrayType(et, _) => getElementType(et) - case et => et - } - } - def transformSchema(schema: StructType): StructType = { getFeedDict.foreach { case (_, colName) => val colType = schema(colName).dataType @@ -418,8 +388,8 @@ class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexP private val coercionPrefix = s"coerced_$uid" - private def coerceType(schema: StructType, colName: String, targetElementType: DataType): - (Option[(UserDefinedFunction, String)]) = { + private def coerceType(schema: StructType, colName: String, targetElementType: DataType) + : Option[(UserDefinedFunction, String)] = { val colType = schema(colName).dataType match { case ArrayType(dt, _) => dt } diff --git a/deep-learning/src/main/scala/com/microsoft/ml/spark/onnx/ONNXModel.scala b/deep-learning/src/main/scala/com/microsoft/ml/spark/onnx/ONNXModel.scala new file mode 100644 index 0000000000..b24012a625 --- /dev/null +++ b/deep-learning/src/main/scala/com/microsoft/ml/spark/onnx/ONNXModel.scala @@ -0,0 +1,647 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.ml.spark.onnx + +import ai.onnxruntime.OrtException.OrtErrorCode +import ai.onnxruntime.OrtSession.SessionOptions +import ai.onnxruntime.OrtSession.SessionOptions.OptLevel +import ai.onnxruntime._ +import breeze.linalg.{argmax, softmax, DenseVector => BDV} +import com.microsoft.ml.spark.HasFeedFetchDicts +import com.microsoft.ml.spark.codegen.Wrappable +import com.microsoft.ml.spark.core.env.StreamUtilities.using +import com.microsoft.ml.spark.core.schema.DatasetExtensions +import com.microsoft.ml.spark.core.utils.BreezeUtils._ +import com.microsoft.ml.spark.logging.BasicLogging +import com.microsoft.ml.spark.stages._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.injections.UDFUtils +import org.apache.spark.internal.Logging +import org.apache.spark.ml._ +import org.apache.spark.ml.linalg.SQLDataTypes._ +import org.apache.spark.ml.linalg.{SQLDataTypes, Vector} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ +import org.apache.spark.{SparkContext, TaskContext} +import spray.json.DefaultJsonProtocol._ + +import java.nio._ +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters.mapAsScalaMapConverter +import scala.reflect.ClassTag + +trait ONNXModelParams extends Params with HasMiniBatcher with HasFeedFetchDicts { + + val modelPayload: ByteArrayParam = new ByteArrayParam( + this, + "modelPayload", + "Array of bytes containing the serialized ONNX model." + ) + + def getModelPayload: Array[Byte] = $(modelPayload) + + val softMaxDict: MapParam[String, String] = new MapParam[String, String]( + this, + "softMaxDict", + "A map between output dataframe columns, where the value column will be computed from taking " + + "the softmax of the key column. If the 'rawPrediction' column contains logits outputs, then one can " + + "set softMaxDict to `Map(\"rawPrediction\" -> \"probability\")` to obtain the probability outputs." + ) + + def setSoftMaxDict(value: Map[String, String]): this.type = set(softMaxDict, value) + + def setSoftMaxDict(k: String, v: String): this.type = set(softMaxDict, Map(k -> v)) + + def getSoftMaxDict: Map[String, String] = get(softMaxDict).getOrElse(Map.empty) + + val argMaxDict: MapParam[String, String] = new MapParam[String, String]( + this, + "argMaxDict", + "A map between output dataframe columns, where the value column will be computed from taking " + + "the argmax of the key column. This can be used to convert probability output to predicted label." + ) + + def setArgMaxDict(value: Map[String, String]): this.type = set(argMaxDict, value) + + def setArgMaxDict(k: String, v: String): this.type = set(argMaxDict, Map(k -> v)) + + def getArgMaxDict: Map[String, String] = get(argMaxDict).getOrElse(Map.empty) + + val deviceType: Param[String] = new Param[String]( + this, + "deviceType", + "Specify a device type the model inference runs on. Supported types are: CPU or CUDA." + + "If not specified, auto detection will be used.", + ParamValidators.inArray(Array("CPU", "CUDA")) + ) + + def getDeviceType: String = $(deviceType) + + def setDeviceType(value: String): this.type = set(deviceType, value) + + val optimizationLevel: Param[String] = new Param[String]( + this, + "optimizationLevel", + "Specify the optimization level for the ONNX graph optimizations. 
Details at " + + "https://onnxruntime.ai/docs/resources/graph-optimizations.html#graph-optimization-levels. " + + "Supported values are: NO_OPT; BASIC_OPT; EXTENDED_OPT; ALL_OPT. Default: ALL_OPT.", + ParamValidators.inArray(Array("NO_OPT", "BASIC_OPT", "EXTENDED_OPT", "ALL_OPT")) + ) + + def getOptimizationLevel: String = $(optimizationLevel) + + def setOptimizationLevel(value: String): this.type = set(optimizationLevel, value) + + setDefault( + optimizationLevel -> "ALL_OPT", + miniBatcher -> new FixedMiniBatchTransformer().setBatchSize(10) //scalastyle:ignore magic.number + ) +} + +//noinspection ScalaStyle +private class ClosableIterator[+T](delegate: Iterator[T], cleanup: => Unit) extends Iterator[T] { + override def hasNext: Boolean = delegate.hasNext + + override def next(): T = { + val t = delegate.next() + + if (!delegate.hasNext) { + // Cleanup the resource if there is no more rows, but iterator does not have to exhaust. + cleanup + } + + t + } + + override def finalize(): Unit = { + try { + // Make sure resource is cleaned up. + cleanup + } + catch { + case _: Throwable => + } + + super.finalize() + } +} + +/** + * Object model for an ONNX model: + * OrtSession + * |-InputInfo: Map[String, NodeInfo] + * |-OutputInfo: Map[String, NodeInfo] + * OrtSession is the entry point for the object model. Most importantly it defines the InputInfo and OutputInfo maps. + * ------------------------------------ + * NodeInfo + * |-name: String + * |-info: ValueInfo + * Each NodeInfo is a name and ValueInfo tuple. ValueInfo has three implementations, explained below. + * ------------------------------------ + * TensorInfo extends ValueInfo + * |-shape: Array[Long] + * |-type: OnnxJavaType + * TensorInfo is the most common type of ValueInfo. It defines the type of the tensor elements, and the shape. + * The first dimension of the tensor is assumed to be the batch size. For example, FLOAT[-1, 3, 224, 224] + * could represent a unlimited batch size * 3 channels * 224 height * 224 width tensor, where each element is a float. + * ------------------------------------ + * SequenceInfo extends ValueInfo + * |-sequenceOfMaps: Boolean + * |-sequenceType: OnnxJavaType + * |-mapInfo: MapInfo + * |-length: Int + * SequenceInfo can be a sequence of values (value type specified by sequenceType) if sequenceOfMaps is false, + * or a sequence of MapInfo if sequenceOfMaps is true. Sequence of MapInfo is usually used for ZipMap type of output, + * where the sequence represent the batch, and each MapInfo represents probability or logits outcome per class for + * each observation. + * ------------------------------------ + * MapInfo extends ValueInfo + * |-keyType: OnnxJavaType + * |-valueType: OnnxJavaType + * |-size: Int + * MapInfo defines keyType, valueType and size. It is usually used inside SequenceInfo. + */ +object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging { + private[onnx] def initializeOrt(modelContent: Array[Byte], + ortEnv: OrtEnvironment, + optLevel: OptLevel = OptLevel.ALL_OPT, + gpuDeviceId: Option[Int] = None) + : OrtSession = { + val options = new SessionOptions() + + try { + gpuDeviceId.foreach(options.addCUDA) + } catch { + case exp: OrtException if exp.getCode == OrtErrorCode.ORT_INVALID_ARGUMENT => + val err = s"GPU device is found on executor nodes with id ${gpuDeviceId.get}, " + + s"but adding CUDA support failed. Most likely the ONNX runtime supplied to the cluster " + + s"does not support GPU. 
Please install com.microsoft.onnxruntime:onnxruntime_gpu:{version} " + + s"instead for optimal performance. Exception details: ${exp.toString}" + logError(err) + } + + options.setOptimizationLevel(optLevel) + ortEnv.createSession(modelContent, options) + } + + private[onnx] def mapOnnxJavaTypeToDataType(javaType: OnnxJavaType): DataType = { + javaType match { + case OnnxJavaType.INT8 => ByteType + case OnnxJavaType.INT16 => ShortType + case OnnxJavaType.INT32 => IntegerType + case OnnxJavaType.INT64 => LongType + case OnnxJavaType.FLOAT => FloatType + case OnnxJavaType.DOUBLE => DoubleType + case OnnxJavaType.BOOL => BooleanType + case OnnxJavaType.STRING => StringType + case OnnxJavaType.UNKNOWN => BinaryType + } + } + + private[onnx] def mapTensorInfoToDataType(tensorInfo: TensorInfo): DataType = { + val dataType = mapOnnxJavaTypeToDataType(tensorInfo.`type`) + + def nestedArrayType(depth: Int, dataType: DataType): ArrayType = { + if (depth == 1) + ArrayType(dataType) + else + ArrayType(nestedArrayType(depth - 1, dataType)) + } + + if (tensorInfo.isScalar) { + dataType + } else if (tensorInfo.getShape.length == 1) { + // first dimension is assumed to be batch size. + dataType + } else { + nestedArrayType(tensorInfo.getShape.length - 1, dataType) + } + } + + @tailrec + private[onnx] def mapValueInfoToDataType(valueInfo: ValueInfo): DataType = { + valueInfo match { + case mapInfo: MapInfo => + val keyDataType = mapOnnxJavaTypeToDataType(mapInfo.keyType) + val valueDataType = mapOnnxJavaTypeToDataType(mapInfo.valueType) + MapType(keyDataType, valueDataType) + case seqInfo: SequenceInfo => + if (seqInfo.sequenceOfMaps) { + mapValueInfoToDataType(seqInfo.mapInfo) + } else { + mapOnnxJavaTypeToDataType(seqInfo.sequenceType) + } + case tensorInfo: TensorInfo => + mapTensorInfoToDataType(tensorInfo) + } + } + + private[onnx] def mapOnnxValueToArray(value: OnnxValue): Seq[Any] = { + value.getInfo match { + case tensorInfo: TensorInfo => + if (tensorInfo.isScalar) + Seq(value.getValue) + else { + value.getValue.asInstanceOf[Array[_]].toSeq + } + case sequenceInfo: SequenceInfo => + if (sequenceInfo.sequenceOfMaps) { + value.getValue.asInstanceOf[java.util.List[java.util.Map[_, _]]] + .asScala.toArray.map(_.asScala.toMap) + } else { + value.getValue.asInstanceOf[java.util.List[_]].asScala + } + case _: MapInfo => + Array(value.getValue.asInstanceOf[java.util.Map[_, _]].asScala.toMap) + } + } + + private def flattenNested[T: ClassTag](nestedSeq: Seq[_]): Seq[T] = { + nestedSeq.flatMap { + case x: T => Array(x) + case s: Seq[_] => + flattenNested(s) + case a: Array[_] => + flattenNested(a) + } + } + + private[onnx] def selectGpuDevice(deviceType: Option[String]): Option[Int] = { + deviceType match { + case None | Some("CUDA") => + val gpuNum = TaskContext.get().resources().get("gpu").flatMap(_.addresses.map(_.toInt).headOption) + gpuNum + case Some("CPU") => + None + case _ => + None + } + } + + private[onnx] def applyModel(session: OrtSession, + env: OrtEnvironment, + feedMap: Map[String, String], + fetchMap: Map[String, String], + inputSchema: StructType + )(rows: Iterator[Row]): Iterator[Row] = { + val results = rows.map { + row => + // Each row contains a batch + // Get the input tensors for each input node. + val inputTensors = session.getInputInfo.asScala.map { + case (inputName, inputNodeInfo) => + + val batchedValues: Seq[Any] = row.getAs[Seq[Any]](feedMap(inputName)) + + inputNodeInfo.getInfo match { + case tensorInfo: TensorInfo => // Only supports tensor input. 
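+ // Build the ONNX tensor for this input from the batched column values, using the element type and shape from the model's TensorInfo (with the first dimension set to the batch size).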
+ val tensor = createTensor(env, tensorInfo, batchedValues) + (inputName, tensor) + case other => + throw new NotImplementedError(s"Only tensor input type is supported, but got $other instead.") + } + } + + // Run the tensors through the ONNX runtime. + val outputBatches: Seq[Seq[Any]] = using(session.run(inputTensors.asJava)) { + result => + // Map the output tensors to batches. + fetchMap.map { + case (_, outputName) => + val i = session.getOutputInfo.asScala.keysIterator.indexOf(outputName) + val outputValue: OnnxValue = result.get(i) + mapOnnxValueToArray(outputValue) + }.toSeq + }.get + + // Close the tensor and clean up native handles + inputTensors.valuesIterator.foreach { + _.close() + } + + // Return a row for each output batch: original payload appended with model output. + val data = inputSchema.map(f => row.getAs[Any](f.name)) + Row.fromSeq(data ++ outputBatches) + } + + new ClosableIterator[Row](results, { + session.close() + env.close() + }) + } + + private def createTensor(env: OrtEnvironment, tensorInfo: TensorInfo, batchedValues: Seq[_]) = { + val classTag = ClassTag(tensorInfo.`type`.clazz) + val flattened: Array[_] = flattenNested(batchedValues)(classTag).toArray + + val shape: Array[Long] = tensorInfo.getShape + // the first dimension of the shape can be -1 when multiple inputs are allowed. Setting it to the real + // input size. Otherwise we cannot create the tensor from the 1D array buffer. + shape(0) = batchedValues.length + + tensorInfo.`type` match { + case OnnxJavaType.FLOAT => + val buffer = FloatBuffer.wrap(flattened.map(_.asInstanceOf[Float])) + OnnxTensor.createTensor(env, buffer, shape) + case OnnxJavaType.DOUBLE => + val buffer = DoubleBuffer.wrap(flattened.map(_.asInstanceOf[Double])) + OnnxTensor.createTensor(env, buffer, shape) + case OnnxJavaType.INT8 => + val buffer = ByteBuffer.wrap(flattened.map(_.asInstanceOf[Byte])) + OnnxTensor.createTensor(env, buffer, shape) + case OnnxJavaType.INT16 => + val buffer = ShortBuffer.wrap(flattened.map(_.asInstanceOf[Short])) + OnnxTensor.createTensor(env, buffer, shape) + case OnnxJavaType.INT32 => + val buffer = IntBuffer.wrap(flattened.map(_.asInstanceOf[Int])) + OnnxTensor.createTensor(env, buffer, shape) + case OnnxJavaType.INT64 => + val buffer = LongBuffer.wrap(flattened.map(_.asInstanceOf[Long])) + OnnxTensor.createTensor(env, buffer, shape) + case OnnxJavaType.STRING => + OnnxTensor.createTensor(env, flattened.map(_.asInstanceOf[String]), shape) + case other => + throw new NotImplementedError(s"Tensor input type $other not supported. " + + s"Only FLOAT, DOUBLE, INT8, INT16, INT32, INT64, STRING types are supported.") + } + } + + /** + * Returns true if the two data types are compatible. They are compatible if they share the same "shape", and + * 1. The element types from both sides are numeric types, or + * 2. The element types from both sides are the same. 
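+ * For example, a Spark VectorType column is compatible with a tensor input expecting ArrayType(FloatType), because both element types are numeric.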
+ */ + @tailrec + private def compatible(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (VectorType, right: ArrayType) => + compatible(DoubleType, right.elementType) + case (left: ArrayType, right: ArrayType) => + compatible(left.elementType, right.elementType) + case (_: NumericType, _: NumericType) => true + case (fromDataType, toDataType) => fromDataType == toDataType + } + } +} + +class ONNXModel(override val uid: String) + extends Transformer + with ComplexParamsWritable + with ONNXModelParams + with Wrappable + with BasicLogging { + + import ONNXModel._ + + override protected lazy val pyInternalWrapper = true + + logClass() + + def this() = this(Identifiable.randomUID("ONNXModel")) + + def modelInput: Map[String, NodeInfo] = { + using(OrtEnvironment.getEnvironment) { + env => + using(initializeOrt(getModelPayload, env)) { + session => session.getInputInfo.asScala.toMap + } + }.flatten.get + } + + def modelOutput: Map[String, NodeInfo] = { + using(OrtEnvironment.getEnvironment) { + env => + using(initializeOrt(getModelPayload, env)) { + session => session.getOutputInfo.asScala.toMap + } + }.flatten.get + } + + private var broadcastedModelPayload: Option[Broadcast[Array[Byte]]] = None + + def setModelPayload(value: Array[Byte]): this.type = { + this.broadcastedModelPayload.foreach(_.destroy) + broadcastedModelPayload = None + this.set(modelPayload, value) + } + + def rebroadcastModelPayload(spark: SparkSession): Broadcast[Array[Byte]] = { + val bc = spark.sparkContext.broadcast(getModelPayload) + broadcastedModelPayload = Some(bc) + bc + } + + def setModelLocation(path: String): this.type = { + val modelBytes = SparkContext.getOrCreate().binaryFiles(path).first()._2.toArray + this.setModelPayload(modelBytes) + } + + override def transform(dataset: Dataset[_]): DataFrame = logTransform { + val inputSchema = dataset.schema + this.validateSchema(inputSchema) + + val modelOutputSchema = getModelOutputSchema(inputSchema) + + implicit val enc: Encoder[Row] = RowEncoder( + StructType(modelOutputSchema.map(f => StructField(f.name, ArrayType(f.dataType)))) + ) + + // The cache call is a workaround for GH issue 1075: + // https://github.com/Azure/mmlspark/issues/1075 + val batchedDF = getMiniBatcher.transform(dataset) + val batchedCache = if (batchedDF.isStreaming) batchedDF else batchedDF.cache().unpersist() + val (coerced, feedDict) = coerceBatchedDf(batchedCache) + val modelBc = broadcastedModelPayload.getOrElse(rebroadcastModelPayload(dataset.sparkSession)) + val (fetchDicts, devType, optLevel) = (getFetchDict, get(deviceType), OptLevel.valueOf(getOptimizationLevel)) + val outputDf = coerced.mapPartitions { + rows => + val payload = modelBc.value + val taskId = TaskContext.get().taskAttemptId() + val gpuDeviceId = selectGpuDevice(devType) + val env = OrtEnvironment.getEnvironment + logInfo(s"Task:$taskId;DeviceType=$devType;DeviceId=$gpuDeviceId;OptimizationLevel=$optLevel") + val session = initializeOrt(payload, env, optLevel, gpuDeviceId) + applyModel(session, env, feedDict, fetchDicts, inputSchema)(rows) + } + + // The cache call is a workaround for GH issue 1075: + // https://github.com/Azure/mmlspark/issues/1075 + val outputCache = if (outputDf.isStreaming) outputDf else outputDf.cache().unpersist() + + val flattenedDF = new FlattenBatch().transform(outputCache) + + (softMaxTransform _ andThen argMaxTransform) (flattenedDF) + } + + private def softMaxTransform(input: DataFrame): DataFrame = { + this.getSoftMaxDict.foldLeft(input) { + case (df, (input, output)) => 
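+ // Convert the raw model output column (an array or map of numeric scores) into softmax probabilities, returned as a Spark ML vector.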
+ val softmaxCol = df.schema(input).dataType match { + case ArrayType(_: NumericType, _) => + val softmaxUdf = UDFUtils.oldUdf({ + array: Seq[Double] => + val data = BDV(array: _*) + (data - softmax(data)).mapValues(math.exp).toSpark + }, VectorType) + softmaxUdf(col(input).cast(ArrayType(DoubleType))) + case MapType(_: NumericType, _: NumericType, _) => + val softmaxUdf = UDFUtils.oldUdf({ + map: Map[Double, Double] => + val data = BDV(map.toSeq.sortBy(_._1).map(_._2): _*) + (data - softmax(data)).mapValues(math.exp).toSpark + }, VectorType) + softmaxUdf(col(input).cast(MapType(DoubleType, DoubleType))) + } + + df.withColumn(output, softmaxCol) + } + } + + private def argMaxTransform(input: DataFrame): DataFrame = { + this.getArgMaxDict.foldLeft(input) { + case (df, (input, output)) => + val argmaxCol = df.schema(input).dataType match { + case ArrayType(_: NumericType, _) => + val argmaxUdf = UDFUtils.oldUdf({ + array: Seq[Double] => argmax(array.toArray).toDouble + }, DoubleType) + argmaxUdf(col(input).cast(ArrayType(DoubleType))) + case MapType(_: NumericType, _: NumericType, _) => + val argmaxUdf = UDFUtils.oldUdf({ + map: Map[Double, Double] => + map.maxBy(_._2)._1 + }, DoubleType) + argmaxUdf(col(input).cast(MapType(DoubleType, DoubleType))) + } + + df.withColumn(output, argmaxCol) + } + } + + private def coerceBatchedDf(df: DataFrame): (DataFrame, Map[String, String]) = { + val toArray = UDFUtils.oldUdf({ + (vectors: Seq[Vector]) => vectors.map(_.toArray) + }, ArrayType(ArrayType(DoubleType))) + + this.modelInput.mapValues(f => ArrayType(mapValueInfoToDataType(f.getInfo))) + .foldLeft((df, this.getFeedDict)) { + case ((accDf, feedDict), (onnxInputName, dataType)) => + val originalColName = this.getFeedDict(onnxInputName) + val coercedColName = DatasetExtensions.findUnusedColumnName(originalColName, accDf) + val originalCol = df.schema(originalColName).dataType match { + case ArrayType(VectorType, _) => toArray(col(originalColName)) + case _ => col(originalColName) + } + + ( + accDf.withColumn(coercedColName, originalCol.cast(dataType)), + feedDict.updated(onnxInputName, coercedColName) + ) + } + } + + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + + override def transformSchema(schema: StructType): StructType = { + this.validateSchema(schema) + + val modelOutputSchema = getModelOutputSchema(schema) + + val softmaxFields = this.getSoftMaxDict.map { + case (inputCol, outputCol) => + getSoftMaxOutputField(inputCol, outputCol, modelOutputSchema) + } + + val argmaxFields = this.getArgMaxDict.map { + case (inputCol, outputCol) => + getArgMaxOutputField(inputCol, outputCol, modelOutputSchema) + } + + (softmaxFields ++ argmaxFields).foldLeft(modelOutputSchema)(_ add _) + } + + private def validateSchema(schema: StructType): Unit = { + // Validate that input schema matches with onnx model expected input types. + this.modelInput.foreach { + case (onnxInputName, onnxInputInfo) => + val colName = this.getFeedDict.getOrElse( + onnxInputName, + throw new IllegalArgumentException( + s"Onnx model input $onnxInputName is not defined in onnxInputMap parameter." 
+ ) + ) + + val inputDataType: DataType = schema(colName).dataType + val onnxExpectedDataType: DataType = mapValueInfoToDataType(onnxInputInfo.getInfo) + + if (!compatible(inputDataType, onnxExpectedDataType)) { + throw new IllegalArgumentException( + s"Onnx model input $onnxInputName expects type ${onnxExpectedDataType.simpleString}, " + + s"but got type ${inputDataType.simpleString}") + } + } + + // Validate the output col names do not conflict with input schema. + (this.getFetchDict.keySet ++ this.getSoftMaxDict.values ++ this.getArgMaxDict.values).foreach { + colName => + if (schema.fieldNames.map(_.toLowerCase).contains(colName.toLowerCase)) { + throw new IllegalArgumentException(s"Output field $colName already exists in the input schema.") + } + } + } + + /** + * Gets the output schema from the ONNX model + */ + private def getModelOutputSchema(inputSchema: StructType): StructType = { + // Get ONNX model output cols. + val modelOutputFields = this.getFetchDict.map { + case (colName, onnxOutputName) => + val onnxOutput = this.modelOutput.getOrElse(onnxOutputName, + throw new IllegalArgumentException(s"""Onnx model does not have an output named "$onnxOutputName"""") + ) + + val dataType = mapValueInfoToDataType(onnxOutput.getInfo) + + StructField(colName, dataType) + } + + StructType(inputSchema.fields ++ modelOutputFields) + } + + private def getSoftMaxOutputField(inputCol: String, outputCol: String, schema: StructType) = { + val inputField = schema(inputCol) + + val outputType = inputField.dataType match { + case MapType(_: NumericType, _: NumericType, _) => SQLDataTypes.VectorType + case ArrayType(_: NumericType, _) => SQLDataTypes.VectorType + case t => throw new IllegalArgumentException( + s"Input type for Softmax must be ArrayType(NumericType) or MapType(NumericType, NumericType), " + + s"but got $t instead." + ) + } + + StructField(outputCol, outputType) + } + + private def getArgMaxOutputField(inputCol: String, outputCol: String, schema: StructType) = { + val inputField = schema(inputCol) + + val outputType = inputField.dataType match { + case ArrayType(_: NumericType, _) => DoubleType + case MapType(_: NumericType, _: NumericType, _) => DoubleType + case t => throw new IllegalArgumentException( + s"Input type for ArgMax must be ArrayType(NumericType) or MapType(NumericType, NumericType), " + + s"but got $t instead." 
+ ) + } + + StructField(outputCol, outputType) + } +} diff --git a/deep-learning/src/test/scala/com/microsoft/ml/spark/cntk/CNTKModelSuite.scala b/deep-learning/src/test/scala/com/microsoft/ml/spark/cntk/CNTKModelSuite.scala index 8d2285be0a..f4398e691b 100644 --- a/deep-learning/src/test/scala/com/microsoft/ml/spark/cntk/CNTKModelSuite.scala +++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/cntk/CNTKModelSuite.scala @@ -116,7 +116,6 @@ class CNTKModelSuite extends LinuxOnly with ImageTestUtils with TransformerFuzzi assert(m.getInputShapes.map(_.toList) === List(List(32, 32, 3), List(10))) assert(m.getOutputNode === "OUTPUT_3") assert(m.getOutputNodeIndex === 3) - assert(m.getShapeOutput === false) assert(m.getInputCol === inputCol) assert(m.getOutputCol === outputCol) } diff --git a/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageLIMEExplainerSuite.scala b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageLIMEExplainerSuite.scala index 131b69f6fd..3869438f71 100644 --- a/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageLIMEExplainerSuite.scala +++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split2/ImageLIMEExplainerSuite.scala @@ -4,7 +4,7 @@ package com.microsoft.ml.spark.explainers.split2 import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.{ImageExplainersSuite, ImageFormat, ImageLIME, LocalExplainer} import com.microsoft.ml.spark.io.IOImplicits._ import com.microsoft.ml.spark.lime.SuperpixelData diff --git a/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageSHAPExplainerSuite.scala b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageSHAPExplainerSuite.scala index 1de490a4a8..9c2bee7055 100644 --- a/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageSHAPExplainerSuite.scala +++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/explainers/split3/ImageSHAPExplainerSuite.scala @@ -4,7 +4,7 @@ package com.microsoft.ml.spark.explainers.split3 import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.ml.spark.explainers.BreezeUtils._ +import com.microsoft.ml.spark.core.utils.BreezeUtils._ import com.microsoft.ml.spark.explainers.{ImageExplainersSuite, ImageFormat, ImageSHAP, LocalExplainer} import com.microsoft.ml.spark.lime.SuperpixelData import org.apache.spark.ml.linalg.Vector diff --git a/deep-learning/src/test/scala/com/microsoft/ml/spark/lime/ImageLIMESuite.scala b/deep-learning/src/test/scala/com/microsoft/ml/spark/lime/ImageLIMESuite.scala index b53d206137..f46249137a 100644 --- a/deep-learning/src/test/scala/com/microsoft/ml/spark/lime/ImageLIMESuite.scala +++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/lime/ImageLIMESuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.{NamespaceInjections, PipelineModel} import org.apache.spark.sql.functions.col import org.apache.spark.sql.{DataFrame, Row} +@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3") class ImageLIMESuite extends TransformerFuzzing[ImageLIME] with DataFrameEquality with TrainedCNTKModelUtils with FileReaderUtils { diff --git a/deep-learning/src/test/scala/com/microsoft/ml/spark/onnx/ONNXModelSuite.scala 
b/deep-learning/src/test/scala/com/microsoft/ml/spark/onnx/ONNXModelSuite.scala new file mode 100644 index 0000000000..dd4b5ece47 --- /dev/null +++ b/deep-learning/src/test/scala/com/microsoft/ml/spark/onnx/ONNXModelSuite.scala @@ -0,0 +1,293 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.spark.onnx + +import breeze.linalg.{argmax, argtopk} +import com.microsoft.ml.spark.build.BuildInfo +import com.microsoft.ml.spark.core.env.FileUtilities +import com.microsoft.ml.spark.core.test.base.TestBase +import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} +import com.microsoft.ml.spark.core.utils.BreezeUtils._ +import com.microsoft.ml.spark.io.IOImplicits._ +import com.microsoft.ml.spark.opencv.ImageTransformer +import org.apache.commons.io.FileUtils +import org.apache.spark.injections.UDFUtils +import org.apache.spark.ml.image.ImageSchema +import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} +import org.apache.spark.ml.util.MLReadable +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{FloatType, IntegerType} +import org.apache.spark.sql.{DataFrame, Row} +import org.scalactic.{Equality, TolerantNumerics} + +import java.io.File +import java.net.URL + +class ONNXModelSuite extends TestBase + with TransformerFuzzing[ONNXModel] { + + override def testObjects(): Seq[TestObject[ONNXModel]] = Seq( + new TestObject(onnxIris, testDfIrisFloat), + new TestObject(onnxIris, testDfIrisDouble), + new TestObject(onnxIris, testDfIrisVector), + new TestObject(onnxMNIST, testDfMNIST), + new TestObject(onnxAdultsIncome, testDfAdultsIncome), + new TestObject(onnxResNet50, testDfResNet50) + ) + + override def reader: MLReadable[_] = ONNXModel + + private val baseUrl = "https://mmlspark.blob.core.windows.net/publicwasb/ONNXModels/" + private implicit val eqFloat: Equality[Float] = TolerantNumerics.tolerantFloatEquality(1E-5f) + private implicit val eqMap: Equality[Map[Long, Float]] = mapEq[Long, Float] + + import spark.implicits._ + + def downloadModel(modelName: String, baseUrl: String): File = { + val f = FileUtilities.join(BuildInfo.datasetDir, "ONNXModel", modelName) + if (!f.exists()) { + FileUtils.copyURLToFile(new URL(new URL(baseUrl), modelName), f) + } + f + } + + private lazy val onnxIris: ONNXModel = { + // Making sure spark context is initialized + spark + val model = downloadModel("iris.onnx", baseUrl) + new ONNXModel() + .setModelLocation(model.getPath) + .setFeedDict(Map("float_input" -> "features")) + .setFetchDict(Map("prediction" -> "output_label", "rawProbability" -> "output_probability")) + } + + private lazy val testDfIrisFloat: DataFrame = Seq( + Array(6.7f, 3.1f, 4.7f, 1.5f), + Array(4.9f, 3.0f, 1.4f, 0.2f), + Array(5.8f, 2.7f, 5.1f, 1.9f) + ) toDF "features" + + private lazy val testDfIrisDouble: DataFrame = Seq( + Array(6.7d, 3.1d, 4.7d, 1.5d), + Array(4.9d, 3.0d, 1.4d, 0.2d), + Array(5.8d, 2.7d, 5.1d, 1.9d) + ) toDF "features" + + private lazy val testDfIrisVector: DataFrame = Seq( + Tuple1(Vectors.dense(6.7d, 3.1d, 4.7d, 1.5d)), + Tuple1(Vectors.dense(4.9d, 3.0d, 1.4d, 0.2d)), + Tuple1(Vectors.dense(5.8d, 2.7d, 5.1d, 1.9d)) + ) toDF "features" + + test("ONNXModel can infer observations of matching input types") { + val predicted = onnxIris.transform(testDfIrisFloat).as[(Seq[Float], Long, Map[Long, Float])].collect() + + assert(predicted(0)._2 == 
1L) + assert(predicted(0)._3 === Map(0L -> 0.0032624616f, 1L -> 0.78214455f, 2L -> 0.214593f)) + + assert(predicted(1)._2 == 0L) + assert(predicted(1)._3 === Map(0L -> 0.9666327f, 1L -> 0.033367135f, 2L -> 1.5725234E-7f)) + + assert(predicted(2)._2 == 2L) + assert(predicted(2)._3 === Map(0L -> 5.4029905E-4f, 1L -> 0.24569187f, 2L -> 0.75376785f)) + } + + test("ONNXModel can infer observations of compatible input types") { + val predicted = onnxIris.transform(testDfIrisDouble).as[(Seq[Double], Long, Map[Long, Float])].collect() + + assert(predicted(0)._2 == 1L) + assert(predicted(0)._3 === Map(0L -> 0.0032624616f, 1L -> 0.78214455f, 2L -> 0.214593f)) + + assert(predicted(1)._2 == 0L) + assert(predicted(1)._3 === Map(0L -> 0.9666327f, 1L -> 0.033367135f, 2L -> 1.5725234E-7f)) + + assert(predicted(2)._2 == 2L) + assert(predicted(2)._3 === Map(0L -> 5.4029905E-4f, 1L -> 0.24569187f, 2L -> 0.75376785f)) + } + + test("ONNXModel can infer observations of vector input types") { + val predicted = onnxIris.transform(testDfIrisVector).as[(DenseVector, Long, Map[Long, Float])].collect() + + assert(predicted(0)._2 == 1L) + assert(predicted(0)._3 === Map(0L -> 0.0032624616f, 1L -> 0.78214455f, 2L -> 0.214593f)) + + assert(predicted(1)._2 == 0L) + assert(predicted(1)._3 === Map(0L -> 0.9666327f, 1L -> 0.033367135f, 2L -> 1.5725234E-7f)) + + assert(predicted(2)._2 == 2L) + assert(predicted(2)._3 === Map(0L -> 5.4029905E-4f, 1L -> 0.24569187f, 2L -> 0.75376785f)) + } + + private lazy val onnxMNIST: ONNXModel = { + // Making sure spark context is initialized + spark + val model = downloadModel("mnist-8.onnx", baseUrl) + new ONNXModel() + .setModelLocation(model.getPath) + .setFeedDict(Map("Input3" -> "features")) + .setFetchDict(Map("rawPrediction" -> "Plus214_Output_0")) + .setSoftMaxDict(Map("rawPrediction" -> "probability")) + .setArgMaxDict(Map("rawPrediction" -> "prediction")) + .setMiniBatchSize(1) + } + + def getLibSVM2ImageUdf(origin: String, height: Int, + width: Int, nChannels: Int, mode: Int): UserDefinedFunction = { + UDFUtils.oldUdf( + { + data: Vector => + val array = data.toArray.map(_.toByte) + Row(origin, height, width, nChannels, mode, array) + }, + + ImageSchema.columnSchema + ) + } + + private lazy val testDfMNIST: DataFrame = { + val mnistDataLocation: String = { + val loc = "/tmp/mnist.t" + val f = new File(loc) + if (f.exists()) { + f.delete() + } + + FileUtils.copyURLToFile(new URL("https://mmlspark.blob.core.windows.net/publicwasb/ONNXModels/mnist.t"), f) + loc + } + + val libSVM2ImageFunc = getLibSVM2ImageUdf( + origin = "mnist.t", + height = 28, + width = 28, + nChannels = 1, + mode = 0 + ) + + //noinspection ScalaCustomHdfsFormat + val imageDf: DataFrame = spark.read + .format("libsvm") + .option("numFeatures", (28 * 28).toString) + .load(mnistDataLocation) + .withColumn("label", col("label").cast(IntegerType)) + .withColumn("image", libSVM2ImageFunc(col("features"))) + + val imageTransformer = new ImageTransformer() + .setInputCol("image") + .setOutputCol("features") + .resize(28, 28) + .centerCrop(28, 28) + .normalize(mean = Array(0d), std = Array(1d), colorScaleFactor = 1d / 255d) + .setTensorElementType(FloatType) + + imageTransformer.transform(imageDf).cache() + } + + test("ONNXModel can infer for MNIST model") { + val prediction = onnxMNIST.transform(testDfMNIST) + .select("label", "rawPrediction", "probability", "prediction") + + val rows = prediction.as[(Int, Array[Float], Vector, Double)].head(10) + + val epsilon = 1e-4 + implicit lazy val doubleEq: Equality[Double] = 
TolerantNumerics.tolerantDoubleEquality(epsilon) + + rows.foreach { + case (label, rawPrediction, probability, prediction) => + assert(label == prediction.toInt) + assert(argmax(rawPrediction) == label) + assert(probability.argmax == label) + assert(probability.toArray.sum === 1.0) + } + } + + private lazy val featuresAdultsIncome = Array("Age", "WorkClass", "fnlwgt", "Education", "EducationNum", + "MaritalStatus", "Occupation", "Relationship", "Race", "Gender", "CapitalGain", "CapitalLoss", "HoursPerWeek", + "NativeCountry") + + private lazy val onnxAdultsIncome = { + spark + val model = downloadModel("adults_income.onnx", baseUrl) + new ONNXModel() + .setModelLocation(model.getPath) + .setFeedDict(featuresAdultsIncome.map(v => (v, v)).toMap) + .setFetchDict(Map("probability" -> "output_probability")) + .setArgMaxDict(Map("probability" -> "prediction")) + } + + private lazy val testDfAdultsIncome = { + val testDf = Seq( + (39L, " State-gov", 77516L, " Bachelors", 13L, " Never-married", " Adm-clerical", + " Not-in-family", " White", " Male", 2174L, 0L, 40L, " United-States"), + (52L, " Self-emp-not-inc", 209642L, " Doctorate", 16L, " Married-civ-spouse", " Exec-managerial", + " Husband", " White", " Male", 0L, 0L, 45L, " United-States") + ).toDF(featuresAdultsIncome: _*) + + featuresAdultsIncome.foldLeft(testDf) { + case (acc, feature) => + acc.withColumn(feature, array(col(feature))) + }.repartition(1) + } + + test("ONNXModel can translate zipmap output properly") { + val Array(row1, row2) = onnxAdultsIncome.transform(testDfAdultsIncome) + .select("probability", "prediction") + .orderBy(col("prediction")) + .as[(Map[Long, Float], Double)] + .collect() + + assert(row1._1 == Map(0L -> 0.99f, 1L -> 0.01f)) + assert(row1._2 == 0.0) + + assert(row2._1 == Map(0L -> 0.19000047f, 1L -> 0.8099995f)) + assert(row2._2 == 1.0) + } + + private lazy val onnxResNet50 = { + spark + val model = downloadModel("resnet50-v2-7.onnx", baseUrl) + new ONNXModel() + .setModelLocation(model.getPath) + .setFeedDict(Map("data" -> "features")) + .setFetchDict(Map("rawPrediction" -> "resnetv24_dense0_fwd")) + .setSoftMaxDict(Map("rawPrediction" -> "probability")) + .setArgMaxDict(Map("rawPrediction" -> "prediction")) + .setMiniBatchSize(1) + } + + private lazy val testDfResNet50: DataFrame = { + val greyhoundImageLocation: String = { + val loc = "/tmp/greyhound.jpg" + val f = new File(loc) + if (f.exists()) { + f.delete() + } + FileUtils.copyURLToFile(new URL("https://mmlspark.blob.core.windows.net/datasets/LIME/greyhound.jpg"), f) + loc + } + + val imageDf = spark.read.image.load(greyhoundImageLocation) + val imageTransformer = new ImageTransformer() + .setInputCol("image") + .setOutputCol("features") + .resize(224, 224) + .centerCrop(224, 224) + .normalize(mean = Array(0.485, 0.456, 0.406), std = Array(0.229, 0.224, 0.225), colorScaleFactor = 1d / 255d) + .setTensorElementType(FloatType) + + imageTransformer.transform(imageDf).cache() + } + + test("ONNXModel can infer for resnet50 model") { + val (probability, prediction) = onnxResNet50.transform(testDfResNet50) + .select("probability", "prediction") + .as[(Vector, Double)] + .head + + val top2 = argtopk(probability.toBreeze, 2).toArray + assert(top2 sameElements Array(176, 172)) + assert(prediction.toInt == 176) + } +} diff --git a/notebooks/DeepLearning - Flower Image Classification.ipynb b/notebooks/DeepLearning - Flower Image Classification.ipynb index 165bd30ce1..b76f914837 100644 --- a/notebooks/DeepLearning - Flower Image Classification.ipynb +++ 
b/notebooks/DeepLearning - Flower Image Classification.ipynb @@ -58,7 +58,7 @@ "# Make some featurizers\n", "it = ImageTransformer()\\\n", " .setOutputCol(\"scaled\")\\\n", - " .resize(height = 60, width = 60)\n", + " .resize(size=(60, 60))\n", "\n", "ur = UnrollImage()\\\n", " .setInputCol(\"scaled\")\\\n", diff --git a/notebooks/OpenCV - Pipeline Image Transformations.ipynb b/notebooks/OpenCV - Pipeline Image Transformations.ipynb index c3c2d3f469..0368d08757 100644 --- a/notebooks/OpenCV - Pipeline Image Transformations.ipynb +++ b/notebooks/OpenCV - Pipeline Image Transformations.ipynb @@ -146,7 +146,7 @@ "\n", "tr = (ImageTransformer() # images are resized and then cropped\n", " .setOutputCol(\"transformed\")\n", - " .resize(height = 200, width = 200)\n", + " .resize(size=(200, 200))\n", " .crop(0, 0, height = 180, width = 180) )\n", "\n", "small = tr.transform(images).select(\"transformed\")\n", diff --git a/opencv/src/main/python/mmlspark/opencv/ImageTransformer.py b/opencv/src/main/python/mmlspark/opencv/ImageTransformer.py index ec370686fe..1ef0a210b6 100644 --- a/opencv/src/main/python/mmlspark/opencv/ImageTransformer.py +++ b/opencv/src/main/python/mmlspark/opencv/ImageTransformer.py @@ -2,11 +2,13 @@ # Licensed under the MIT License. See LICENSE in project root for information. import sys +from typing import List -if sys.version >= '3': +if sys.version >= "3": basestring = str import pyspark +from pyspark import SparkContext from pyspark.ml.common import inherit_doc from pyspark.sql.types import * from pyspark.sql.types import Row, _create_row @@ -15,13 +17,17 @@ ImageFields = ["origin", "height", "width", "nChannels", "mode", "data"] -ImageSchema = StructType([ - StructField(ImageFields[0], StringType(), True), - StructField(ImageFields[1], IntegerType(), True), - StructField(ImageFields[2], IntegerType(), True), - StructField(ImageFields[3], IntegerType(), True), - StructField(ImageFields[4], IntegerType(), True), # OpenCV type: CV_8U in most cases - StructField(ImageFields[5], BinaryType(), True) ]) # OpenCV bytes: row-wise BGR in most cases +ImageSchema = StructType( + [ + StructField(ImageFields[0], StringType(), True), + StructField(ImageFields[1], IntegerType(), True), + StructField(ImageFields[2], IntegerType(), True), + StructField(ImageFields[3], IntegerType(), True), + StructField(ImageFields[4], IntegerType(), True), # OpenCV type: CV_8U in most cases + StructField(ImageFields[5], BinaryType(), True), + ] +) # OpenCV bytes: row-wise BGR in most cases + def toNDArray(image): """ @@ -33,9 +39,10 @@ def toNDArray(image): Returns: array: The image as a 1-dimensional array """ - return np.asarray(image.data, dtype = np.uint8).reshape((image.height, image.width, 3))[:,:,(2,1,0)] + return np.asarray(image.data, dtype=np.uint8).reshape((image.height, image.width, 3))[:, :, (2, 1, 0)] + -def toImage(array, path = "", mode = 16): +def toImage(array, path="", mode=16): """ Converts a one-dimensional array to a 2 dimensional image @@ -50,36 +57,39 @@ def toImage(array, path = "", mode = 16): """ length = np.prod(array.shape) - data = bytearray(array.astype(dtype=np.int8)[:,:,(2,1,0)].reshape(length)) + data = bytearray(array.astype(dtype=np.int8)[:, :, (2, 1, 0)].reshape(length)) height = array.shape[0] width = array.shape[1] # Creating new Row with _create_row(), because Row(name = value, ... 
) orders fields by name, # which conflicts with expected ImageSchema order when the new DataFrame is created by UDF - return _create_row(ImageFields, [path, height, width, 3, mode, data]) + return _create_row(ImageFields, [path, height, width, 3, mode, data]) + from pyspark.ml.common import inherit_doc + + @inherit_doc class ImageTransformer(_ImageTransformer): """ - Resizes the image to the given width and height - - Args: - height (int): The height to resize to (>=0) - width (int): The width to resize to (>=0) - + Transformer for common image processing stages. """ - def resize(self, height, width): + def resize(self, size, keep_aspect_ratio=True): """ - Resizes the image to the given width and height - - Args: - height (int): The height to resize to (>=0) - width (int): The width to resize to (>=0) + Resizes the image to the given size. - """ - self._java_obj.resize(height, width) - return self + Args: + size (int or tuple(width, height)): The size to resize to (>=0). + keep_aspect_ratio (bool): Whether to keep aspect ratio. + If true, the shorter side of the image will be resized to the specified size. + """ + if type(size) is tuple: + height, width = (size[1], size[0]) + self._java_obj.resize(height, width) + return self + else: + self._java_obj.resize(size, keep_aspect_ratio) + return self def crop(self, x, y, height, width): """ @@ -93,7 +103,19 @@ def crop(self, x, y, height, width): width (int): The width to crop to (>=0) """ - self._java_obj.crop(x,y,height,width) + self._java_obj.crop(x, y, height, width) + return self + + def centerCrop(self, height, width): + """ + Center crops the image given the width and height. + + Args: + height (int): The height to crop to (>= 0) + width (int): The width to crop to (>= 0) + + """ + self._java_obj.centerCrop(height, width) return self def colorFormat(self, format): @@ -119,39 +141,46 @@ def blur(self, height, width): self._java_obj.blur(height, width) return self - def threshold(self, threshold, maxVal, thresholdType): + def threshold(self, threshold, max_val, threshold_type): """ Thresholds the image, please see OpenCV threshold function documentation for more information Args: threshold: (double) The threshold value - maxVal (double): The maximum value to use - thresholdType (double): The type of threshold, can be binary, binary_inv, trunc, zero, zero_inv + max_val (double): The maximum value to use + threshold_type (double): The type of threshold, can be binary, binary_inv, trunc, zero, zero_inv """ - self._java_obj.threshold(threshold, maxVal, thresholdType) + self._java_obj.threshold(threshold, max_val, threshold_type) return self - def gaussianKernel(self, appertureSize, sigma): + def gaussianKernel(self, aperture_size, sigma): """ Blurs the image by applying a gaussian kernel Args: - appertureSize (double): The aperture size, which should be odd and positive + aperture_size (double): The aperture size, which should be odd and positive sigma (double): The standard deviation of the gaussian """ - self._java_obj.gaussianKernel(appertureSize, sigma) + self._java_obj.gaussianKernel(aperture_size, sigma) return self - """ - Flips the image - :param int flipCode: a flag to specify how to flip the image - - 0 means flipping around the x-axis (up-down) - - positive value (for example, 1) means flipping around y-axis (left-right, default) - - negative value (for example, -1) means flipping around both axes (diagonally) - See OpenCV documentation for details. 
- """ - def flip(self, flipCode = 1): - self._java_obj.flip(flipCode) + def flip(self, flip_code=1): + """ + Flips the image + :param int flip_code: a flag to specify how to flip the image + - 0 means flipping around the x-axis (up-down) + - positive value (for example, 1) means flipping around y-axis (left-right, default) + - negative value (for example, -1) means flipping around both axes (diagonally) + See OpenCV documentation for details. + """ + self._java_obj.flip(flip_code) + return self + + def normalize(self, mean, std, color_scale_factor): + """ + Normalizes the image by multiplying the color_scale_factor, substracting mean and dividing by std + """ + self._java_obj.normalize(mean, std, color_scale_factor) return self diff --git a/opencv/src/main/scala/com/microsoft/ml/spark/opencv/ImageTransformer.scala b/opencv/src/main/scala/com/microsoft/ml/spark/opencv/ImageTransformer.scala index 0b99092243..9caa45c54f 100644 --- a/opencv/src/main/scala/com/microsoft/ml/spark/opencv/ImageTransformer.scala +++ b/opencv/src/main/scala/com/microsoft/ml/spark/opencv/ImageTransformer.scala @@ -9,14 +9,15 @@ import com.microsoft.ml.spark.core.schema.{BinaryFileSchema, ImageSchemaUtils} import com.microsoft.ml.spark.logging.BasicLogging import org.apache.spark.injections.UDFUtils import org.apache.spark.ml.image.ImageSchema -import org.apache.spark.ml.param.{ParamMap, _} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} -import org.apache.spark.ml.{ImageInjections, Transformer} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable} +import org.apache.spark.ml.{ComplexParamsWritable, ImageInjections, Transformer} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.opencv.core.{Core, Mat, Rect, Size} +import org.opencv.core._ import org.opencv.imgproc.Imgproc +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer //scalastyle:off field.name @@ -30,23 +31,60 @@ abstract class ImageTransformerStage(params: Map[String, Any]) extends Serializa val stageName: String } +object ImageTransformerStage { + // every stage has a name like "resize", "normalize", "unroll" + val stageNameKey = "action" + + def apply(stage: Map[String, Any]): ImageTransformerStage = { + stage(stageNameKey) match { + case ResizeImage.stageName => new ResizeImage(stage) + case CropImage.stageName => new CropImage(stage) + case ColorFormat.stageName => new ColorFormat(stage) + case Blur.stageName => new Blur(stage) + case Threshold.stageName => new Threshold(stage) + case GaussianKernel.stageName => new GaussianKernel(stage) + case Flip.stageName => new Flip(stage) + case CenterCropImage.stageName => new CenterCropImage(stage) + case unsupported: String => throw new IllegalArgumentException(s"unsupported transformation $unsupported") + } + } +} + /** Resizes the image. The parameters of the ParameterMap are: - * "height" - the height of the image - * "width" + * "height" - the height of the resized image + * "width" - the width of the resized image * "stageName" + * "size" - the shorter side of the resized image if keep aspect ratio is true, otherwise, + * the side length of both height and width. 
+ * "keepAspectRatio" - if true, then the shorter side will be resized to "size" parameter * Please refer to [[http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html#resize OpenCV]] * for more information * * @param params ParameterMap of the parameters */ class ResizeImage(params: Map[String, Any]) extends ImageTransformerStage(params) { - val height: Double = params(ResizeImage.height).asInstanceOf[Int].toDouble - val width: Double = params(ResizeImage.width).asInstanceOf[Int].toDouble override val stageName: String = ResizeImage.stageName override def apply(image: Mat): Mat = { val resized = new Mat() - val sz = new Size(width, height) + + val sz = if (params.isDefinedAt(ResizeImage.size)) { + val specifiedSize = params(ResizeImage.size).asInstanceOf[Int] + if (params(ResizeImage.keepAspectRatio).asInstanceOf[Boolean]) { + val (originalWidth, originalHeight) = (image.width, image.height) + val shorterSize = math.min(originalWidth, originalHeight) + val ratio = 1.0 * specifiedSize / shorterSize + val (targetWidth, targetHeight) = (math.round(ratio * originalWidth), math.round(ratio * originalHeight)) + new Size(targetWidth, targetHeight) + } else { + new Size(specifiedSize, specifiedSize) + } + } else { + val height: Double = params(ResizeImage.height).asInstanceOf[Int].toDouble + val width: Double = params(ResizeImage.width).asInstanceOf[Int].toDouble + new Size(width, height) + } + Imgproc.resize(image, resized, sz) resized } @@ -61,6 +99,8 @@ object ResizeImage { val stageName = "resize" val height = "height" val width = "width" + val size = "size" + val keepAspectRatio = "keepAspectRatio" } /** Crops the image for processing. The parameters are: @@ -93,6 +133,26 @@ object CropImage { val width = "width" } +class CenterCropImage(params: Map[String, Any]) extends ImageTransformerStage(params) { + val height: Int = params(CropImage.height).asInstanceOf[Int] + val width: Int = params(CropImage.width).asInstanceOf[Int] + + override val stageName: String = CenterCropImage.stageName + + override def apply(image: Mat): Mat = { + val (cropWidth, cropHeight) = (math.min(width, image.width), math.min(height, image.height)) + val (midX, midY) = (image.width / 2, image.height / 2) + val rect = new Rect(midX - cropWidth / 2, midY - cropHeight / 2, cropWidth, cropHeight) + new Mat(image, rect) + } +} + +object CenterCropImage { + val stageName = "centercrop" + val height = "height" + val width = "width" +} + /** Converts an image from one color space to another, eg COLOR_BGR2GRAY. Refer to * [[http://docs.opencv.org/2.4/modules/imgproc/doc/miscellaneous_transformations.html#cvtcolor OpenCV]] * for more information. @@ -117,11 +177,11 @@ object ColorFormat { /** Flips the image * - * @param params + * @param params Map of parameters and values */ class Flip(params: Map[String, Any]) extends ImageTransformerStage(params) { - val flipCode = params(Flip.flipCode).asInstanceOf[Int] - override val stageName = Flip.stageName + val flipCode: Int = params(Flip.flipCode).asInstanceOf[Int] + override val stageName: String = Flip.stageName override def apply(image: Mat): Mat = { val dst = new Mat() @@ -143,7 +203,7 @@ object Flip { * The com.microsoft.ml.spark.core.serialize.params are a map of the dimensions of the blurring box. Please refer to * [[http://docs.opencv.org/2.4/modules/imgproc/doc/filtering.html#blur OpenCV]] for more information. 
* - * @param params + * @param params Map of parameters and values */ class Blur(params: Map[String, Any]) extends ImageTransformerStage(params) { val height: Double = params(Blur.height).asInstanceOf[Double] @@ -167,7 +227,7 @@ object Blur { * [[http://docs.opencv.org/2.4/modules/imgproc/doc/miscellaneous_transformations.html#threshold threshold]] for * more information * - * @param params + * @param params Map of parameters and values */ class Threshold(params: Map[String, Any]) extends ImageTransformerStage(params) { val threshold: Double = params(Threshold.threshold).asInstanceOf[Double] @@ -243,36 +303,111 @@ object ImageTransformer extends DefaultParamsReadable[ImageTransformer] { Row(path, img.height, img.width, img.channels(), img.`type`, ocvBytes) } - /** Apply all OpenCV transformation stages to a single image; unroll the result if needed - * For null inputs or binary files that could not be parsed, return None. - * Break on OpenCV errors. - */ - def process(stages: Seq[ImageTransformerStage], decodeMode: String)(r: Any): Option[Row] = { - - if (r == null) return None - - val decoded = (r, decodeMode) match { - case (row: Row, "binaryfile") => - val path = BinaryFileSchema.getPath(row) - val bytes = BinaryFileSchema.getBytes(row) - - //early return if the image can't be decompressed - ImageInjections.decode(path, bytes).getOrElse(return None).getStruct(0) - case (bytes: Array[Byte], "binary") => - ImageInjections.decode(null, bytes).getOrElse(return None).getStruct(0) - case (row: Row, "image") => - row - case (_, mode) => - throw new MatchError(s"Unknown decoder mode $mode") - } + /** + * Convert Spark image representation to OpenCV format. + */ + def decodeImage(decodeMode: String)(r: Any): Option[(String, Mat)] = { + Option(r).flatMap { + row => + (row, decodeMode) match { + case (row: Row, "binaryfile") => + val path = BinaryFileSchema.getPath(row) + val bytes = BinaryFileSchema.getBytes(row) + //early return if the image can't be decompressed + ImageInjections.decode(path, bytes).map(_.getStruct(0)) + case (bytes: Array[Byte], "binary") => + //noinspection ScalaStyle + ImageInjections.decode(null, bytes).map(_.getStruct(0)) + case (row: Row, "image") => + Some(row) + case (_, mode) => + throw new MatchError(s"Unknown decoder mode $mode") + } + } map row2mat + } - val (path, img) = row2mat(decoded) - val result = stages.foldLeft(img) { + /** + * Apply all OpenCV transformation stages to a single image. Break on OpenCV errors. + */ + def processImage(stages: Seq[ImageTransformerStage])(image: Mat): Mat = { + stages.foldLeft(image) { case (imgInternal, stage) => stage.apply(imgInternal) } - Some(mat2row(result, path)) } + /** + * Extract channels from image. + */ + def extractChannels(channelOrder: String)(image: Mat): Array[Mat] = { + // OpenCV channel order is BGR - reverse the order if the intended order is RGB. + // Also remove alpha channel if nChannels is 4. + val converted = if (image.channels == 4) { + // remove alpha channel and order color channels if necessary + val dest = new Mat(image.rows, image.cols, CvType.CV_8UC3) + val colorConversion = if (channelOrder.toLowerCase == "rgb") Imgproc.COLOR_BGRA2RGB else Imgproc.COLOR_BGRA2BGR + Imgproc.cvtColor(image, dest, colorConversion) + dest + } else if (image.channels == 3 && channelOrder.toLowerCase == "rgb") { + // Reorder channel if nChannel is 3 and intended tensor channel order is RGB. 
+ val dest = new Mat(image.rows, image.cols, CvType.CV_8UC3) + Imgproc.cvtColor(image, dest, Imgproc.COLOR_BGR2RGB) + dest + } else { + image + } + + val channelLength = converted.channels + val channelMats = ListBuffer.fill(channelLength)(Mat.zeros(converted.rows, converted.cols, CvType.CV_8U)) + Core.split(converted, channelMats.asJava) + + channelMats.toArray + } + + /** + * Normalize each channel. + */ + def normalizeChannels(means: Option[Array[Double]], stds: Option[Array[Double]], scaleFactor: Option[Double]) + (channels: Array[Mat]): Array[Mat] = { + val channelLength = channels.length + require(means.forall(channelLength == _.length)) + require(stds.forall(channelLength == _.length)) + + channels + .zip(means.getOrElse(Array.fill(channelLength)(0d))) + .zip(stds.getOrElse(Array.fill(channelLength)(1d))) + .map { + case ((matrix: Mat, m: Double), sd: Double) => + val t = new Mat(matrix.rows, matrix.cols, CvType.CV_64F) + matrix.convertTo(t, CvType.CV_64F) + Core.multiply(t, new Scalar(scaleFactor.getOrElse(1d)), t) // Standardized + Core.subtract(t, new Scalar(m), t) // Centered + Core.divide(t, new Scalar(sd), t) // Normalized + t + } + } + + private def to2DArray(m: Mat): Array[Array[Double]] = { + val array = Array.ofDim[Double](m.rows, m.cols) + array.indices foreach { + i => m.get(i, 0, array(i)) + } + + array + } + + /** + * Convert channel matrices to tensor in the shape of (C * H * W) + */ + def convertToTensor(matrices: Array[Mat]): Array[Array[Array[Double]]] = { + matrices.map(to2DArray) + } + + /** + * Convert from OpenCV format to Dataframe Row. + */ + def encodeImage(path: String, image: Mat): Row = { + mat2row(image, path) + } } /** Image processing stage. Please refer to OpenCV for additional information @@ -280,10 +415,11 @@ object ImageTransformer extends DefaultParamsReadable[ImageTransformer] { * @param uid The id of the module */ class ImageTransformer(val uid: String) extends Transformer - with HasInputCol with HasOutputCol with Wrappable with DefaultParamsWritable with BasicLogging { + with HasInputCol with HasOutputCol with Wrappable with ComplexParamsWritable with BasicLogging { logClass() import ImageTransformer._ + import ImageTransformerStage._ override protected lazy val pyInternalWrapper = true @@ -293,45 +429,153 @@ class ImageTransformer(val uid: String) extends Transformer def setStages(value: Array[Map[String, Any]]): this.type = set(stages, value) - val emptyStages = Array[Map[String, Any]]() + val emptyStages: Array[Map[String, Any]] = Array[Map[String, Any]]() def getStages: Array[Map[String, Any]] = if (isDefined(stages)) $(stages) else emptyStages private def addStage(stage: Map[String, Any]): this.type = set(stages, getStages :+ stage) - setDefault(inputCol -> "image", outputCol -> (uid + "_output")) + val toTensor: BooleanParam = new BooleanParam( + this, + "toTensor", + "Convert output image to tensor in the shape of (C * H * W)" + ) + + def getToTensor: Boolean = $(toTensor) + def setToTensor(value: Boolean): this.type = this.set(toTensor, value) + + @transient + private lazy val validElementTypes: Array[DataType] = Array(FloatType, DoubleType) + val tensorElementType: DataTypeParam = new DataTypeParam( + parent = this, + name = "tensorElementType", + doc = "The element data type for the output tensor. Only used when toTensor is set to true. " + + "Valid values are DoubleType or FloatType. 
Default value: FloatType.", + isValid = ParamValidators.inArray(validElementTypes) + ) + + def getTensorElementType: DataType = $(tensorElementType) + def setTensorElementType(value: DataType): this.type = this.set(tensorElementType, value) + + val tensorChannelOrder: Param[String] = new Param[String]( + parent = this, + name = "tensorChannelOrder", + doc = "The color channel order of the output channels. Valid values are RGB and BGR. Default: RGB.", + isValid = ParamValidators.inArray(Array("rgb", "RGB", "bgr", "BGR")) + ) + + def getTensorChannelOrder: String = $(tensorChannelOrder) + def setTensorChannelOrder(value: String): this.type = this.set(tensorChannelOrder, value) + + val normalizeMean: DoubleArrayParam = new DoubleArrayParam( + this, + "normalizeMean", + "The mean value to use for normalization for each channel. " + + "The length of the array must match the number of channels of the input image." + ) + + def getNormalizeMean: Array[Double] = $(normalizeMean) + def setNormalizeMean(value: Array[Double]): this.type = this.set(normalizeMean, value) + + val normalizeStd: DoubleArrayParam = new DoubleArrayParam( + this, + "normalizeStd", + "The standard deviation to use for normalization for each channel. " + + "The length of the array must match the number of channels of the input image." + ) + + def getNormalizeStd: Array[Double] = $(normalizeStd) + def setNormalizeStd(value: Array[Double]): this.type = this.set(normalizeStd, value) + + val colorScaleFactor: DoubleParam = new DoubleParam( + this, + "colorScaleFactor", + "The scale factor for color values. Used for normalization. " + + "The color values will be multiplied with the scale factor.", + ParamValidators.gt(0d) + ) + + def getColorScaleFactor: Double = $(colorScaleFactor) + def setColorScaleFactor(value: Double): this.type = this.set(colorScaleFactor, value) + + setDefault( + inputCol -> "image", + outputCol -> (uid + "_output"), + toTensor -> false, + tensorChannelOrder -> "RGB", + tensorElementType -> FloatType + ) + + def normalize(mean: Array[Double], std: Array[Double], colorScaleFactor: Double): this.type = { + this + .setToTensor(true) + .setNormalizeMean(mean) + .setNormalizeStd(std) + .setColorScaleFactor(colorScaleFactor) + } - // every stage has a name like "resize", "normalize", "unroll" - val stageName = "action" + /** + * For py4j invocation. + */ + def normalize(mean: java.util.List[Double], std: java.util.List[Double], colorScaleFactor: Double): this.type = { + this + .setToTensor(true) + .setNormalizeMean(mean.asScala.toArray) + .setNormalizeStd(std.asScala.toArray) + .setColorScaleFactor(colorScaleFactor) + } def resize(height: Int, width: Int): this.type = { - require(width >= 0 && height >= 0, "width and height should be nonnegative") + require(width >= 0 && height >= 0, "width and height should be non-negative") - addStage(Map(stageName -> ResizeImage.stageName, + addStage(Map(stageNameKey -> ResizeImage.stageName, ResizeImage.width -> width, ResizeImage.height -> height)) } + /** + * If keep aspect ratio is set to true, the shorter side of the image will be resized to the specified size. 
+ */ + def resize(size: Int, keepAspectRatio: Boolean): this.type = { + require(size >= 0, "size should be non-negative") + addStage(Map(stageNameKey -> ResizeImage.stageName, + ResizeImage.size -> size, + ResizeImage.keepAspectRatio -> keepAspectRatio + )) + } + def crop(x: Int, y: Int, height: Int, width: Int): this.type = { - require(x >= 0 && y >= 0 && width >= 0 && height >= 0, "crop values should be nonnegative") + require(x >= 0 && y >= 0 && width >= 0 && height >= 0, "crop values should be non-negative") - addStage(Map(stageName -> CropImage.stageName, + addStage(Map(stageNameKey -> CropImage.stageName, CropImage.width -> width, CropImage.height -> height, CropImage.x -> x, CropImage.y -> y)) } + def centerCrop(height: Int, width: Int): this.type = { + require(width >= 0 && height >= 0, "crop values should be non-negative") + + addStage( + Map( + stageNameKey -> CenterCropImage.stageName, + CenterCropImage.width -> width, + CenterCropImage.height -> height + ) + ) + } + def colorFormat(format: Int): this.type = { - addStage(Map(stageName -> ColorFormat.stageName, ColorFormat.format -> format)) + addStage(Map(stageNameKey -> ColorFormat.stageName, ColorFormat.format -> format)) } def blur(height: Double, width: Double): this.type = { - addStage(Map(stageName -> Blur.stageName, Blur.height -> height, Blur.width -> width)) + addStage(Map(stageNameKey -> Blur.stageName, Blur.height -> height, Blur.width -> width)) } def threshold(threshold: Double, maxVal: Double, thresholdType: Int): this.type = { - addStage(Map(stageName -> Threshold.stageName, + addStage(Map(stageNameKey -> Threshold.stageName, Threshold.maxVal -> maxVal, Threshold.threshold -> threshold, Threshold.thresholdType -> thresholdType)) @@ -347,11 +591,11 @@ class ImageTransformer(val uid: String) extends Transformer * @return */ def flip(flipCode: Int): this.type = { - addStage(Map(stageName -> Flip.stageName, Flip.flipCode -> flipCode)) + addStage(Map(stageNameKey -> Flip.stageName, Flip.flipCode -> flipCode)) } def gaussianKernel(apertureSize: Int, sigma: Double): this.type = { - addStage(Map(stageName -> GaussianKernel.stageName, + addStage(Map(stageNameKey -> GaussianKernel.stageName, GaussianKernel.apertureSize -> apertureSize, GaussianKernel.sigma -> sigma)) } @@ -362,41 +606,66 @@ class ImageTransformer(val uid: String) extends Transformer // load native OpenCV library on each partition // TODO: figure out more elegant way val df = OpenCVUtils.loadOpenCV(dataset.toDF) - val decodeMode = df.schema(getInputCol).dataType match { - case s if ImageSchemaUtils.isImage(s) => "image" - case s if BinaryFileSchema.isBinaryFile(s) => "binaryfile" - case s if s == BinaryType => "binary" - case s => - throw new IllegalArgumentException(s"input column should have Image or BinaryFile type, got $s") - + val inputDataType = df.schema(getInputCol).dataType + val decodeMode = getDecodeType(inputDataType) + + val transforms = getStages.map(ImageTransformerStage.apply) + + val outputColumnSchema = if ($(toTensor)) tensorUdfSchema else imageColumnSchema + val processStep = processImage(transforms) _ + val extractStep = extractChannels(getTensorChannelOrder) _ + val normalizeStep = normalizeChannels(get(normalizeMean), get(normalizeStd), get(colorScaleFactor)) _ + val toTensorStep = convertToTensor _ + + val convertFunc = if ($(toTensor)) { + inputRow: Any => + decodeImage(decodeMode)(inputRow) map { + case (_, image) => + processStep + .andThen(extractStep) + .andThen(normalizeStep) + .andThen(toTensorStep) + .apply(image) + } + } else 
{ + inputRow: Any => + decodeImage(decodeMode)(inputRow) map { + case (path, image) => + val encodeStep = encodeImage(path, _) + processStep.andThen(encodeStep).apply(image) + } } - val transforms = ListBuffer[ImageTransformerStage]() - for (stage <- getStages) { - stage(stageName) match { - case ResizeImage.stageName => transforms += new ResizeImage(stage) - case CropImage.stageName => transforms += new CropImage(stage) - case ColorFormat.stageName => transforms += new ColorFormat(stage) - case Blur.stageName => transforms += new Blur(stage) - case Threshold.stageName => transforms += new Threshold(stage) - case GaussianKernel.stageName => transforms += new GaussianKernel(stage) - case Flip.stageName => transforms += new Flip(stage) - case unsupported: String => throw new IllegalArgumentException(s"unsupported transformation $unsupported") - } + val convert = UDFUtils.oldUdf(convertFunc, outputColumnSchema) + if ($(toTensor)) { + df.withColumn(getOutputCol, convert(df(getInputCol)).cast(tensorColumnSchema)) + } else { + df.withColumn(getOutputCol, convert(df(getInputCol))) } - - val convert = UDFUtils.oldUdf(process(transforms, decodeMode = decodeMode) _, ImageSchema.columnSchema) - - df.withColumn(getOutputCol, convert(df(getInputCol))) }) } + private def getDecodeType(inputDataType: DataType): String = { + inputDataType match { + case s if ImageSchemaUtils.isImage(s) => "image" + case s if BinaryFileSchema.isBinaryFile(s) => "binaryfile" + case s if s == BinaryType => "binary" + case s => + throw new IllegalArgumentException(s"input column should have Image or BinaryFile type, got $s") + } + } + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + private lazy val tensorUdfSchema = ArrayType(ArrayType(ArrayType(DoubleType))) + private lazy val tensorColumnSchema = ArrayType(ArrayType(ArrayType($(tensorElementType)))) + private lazy val imageColumnSchema = ImageSchema.columnSchema override def transformSchema(schema: StructType): StructType = { - schema.add(getOutputCol, ImageSchema.columnSchema) - } + assert(!schema.fieldNames.contains(getOutputCol), s"Input schema already contains output field $getOutputCol") + val outputColumnSchema = if ($(toTensor)) tensorColumnSchema else imageColumnSchema + schema.add(getOutputCol, outputColumnSchema) + } } //scalastyle:on field.name diff --git a/opencv/src/test/scala/com/microsoft/ml/spark/opencv/ImageTransformerSuite.scala b/opencv/src/test/scala/com/microsoft/ml/spark/opencv/ImageTransformerSuite.scala index 62a43aa5e9..554b2d0776 100644 --- a/opencv/src/test/scala/com/microsoft/ml/spark/opencv/ImageTransformerSuite.scala +++ b/opencv/src/test/scala/com/microsoft/ml/spark/opencv/ImageTransformerSuite.scala @@ -3,28 +3,28 @@ package com.microsoft.ml.spark.opencv -import java.awt.GridLayout -import java.nio.file.Paths - import com.microsoft.ml.spark.build.BuildInfo import com.microsoft.ml.spark.core.env.FileUtilities -import com.microsoft.ml.spark.io.IOImplicits._ import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} import com.microsoft.ml.spark.image.{UnrollBinaryImage, UnrollImage} -import javax.swing._ +import com.microsoft.ml.spark.io.IOImplicits._ import org.apache.hadoop.fs.Path import org.apache.spark.ml.linalg.DenseVector import org.apache.spark.ml.param.DataFrameEquality -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.col -import org.opencv.core.{Mat, MatOfByte} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} +import 
org.opencv.core.{CvType, Mat, MatOfByte} import org.opencv.imgcodecs.Imgcodecs import org.opencv.imgproc.Imgproc import org.scalactic.Equality -import org.scalatest.Assertion + +import java.awt.GridLayout +import java.io.File +import javax.swing._ trait OpenCVTestUtils { - lazy protected val fileLocation = FileUtilities.join(BuildInfo.datasetDir, "Images", "Grocery") + lazy protected val fileLocation: File = FileUtilities.join(BuildInfo.datasetDir, "Images", "Grocery") protected def selectTestImageBytes(images: DataFrame): Array[Byte] = { images.filter(row => row.getString(4).contains("negative") && row.getString(4).endsWith("5.jpg")) @@ -83,8 +83,8 @@ trait OpenCVTestUtils { class UnrollImageSuite extends TransformerFuzzing[UnrollImage] with OpenCVTestUtils with DataFrameEquality { - lazy val filesRoot = BuildInfo.datasetDir - lazy val imagePath = FileUtilities.join(filesRoot,"Images", "CIFAR").toString + lazy val filesRoot: File = BuildInfo.datasetDir + lazy val imagePath: String = FileUtilities.join(filesRoot,"Images", "CIFAR").toString lazy val images: DataFrame = spark.read.image.load(imagePath) test("roll and unroll") { @@ -130,8 +130,8 @@ class UnrollImageSuite extends TransformerFuzzing[UnrollImage] with OpenCVTestUt class UnrollBinaryImageSuite extends TransformerFuzzing[UnrollBinaryImage] with OpenCVTestUtils with DataFrameEquality { - lazy val filesRoot = BuildInfo.datasetDir - lazy val imagePath = FileUtilities.join(filesRoot, "Images", "CIFAR").toString + lazy val filesRoot: File = BuildInfo.datasetDir + lazy val imagePath: String = FileUtilities.join(filesRoot, "Images", "CIFAR").toString lazy val images: DataFrame = spark.read.image.load(imagePath) lazy val binaryImages: DataFrame = spark.read.binary.load(imagePath) .withColumn("image", col("value.bytes")) @@ -147,14 +147,12 @@ class UnrollBinaryImageSuite extends TransformerFuzzing[UnrollBinaryImage] // This is needed for some small 256!=0 issue in unroll. 
// It only happens at one place throughout the tests though - override implicit lazy val dvEq: Equality[DenseVector] = new Equality[DenseVector] { - def areEqual(a: DenseVector, b: Any): Boolean = b match { - case bArr: DenseVector => - a.values.zip(bArr.values).map { - case (x, y) if doubleEq.areEqual(x, y) => 0 - case _ => 0 - }.sum <= 1 - } + override implicit lazy val dvEq: Equality[DenseVector] = (a: DenseVector, b: Any) => b match { + case bArr: DenseVector => + a.values.zip(bArr.values).map { + case (x, y) if doubleEq.areEqual(x, y) => 0 + case _ => 0 + }.sum <= 1 } override def testObjects(): Seq[TestObject[UnrollBinaryImage]] = @@ -167,10 +165,11 @@ class ImageTransformerSuite extends TransformerFuzzing[ImageTransformer] with Op //TODO this is needed to stop the build from freezing override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { + //noinspection NameBooleanParameters assert(true) } - lazy val images = spark.read.image.option("dropInvalid",true) + lazy val images: DataFrame = spark.read.image.option("dropInvalid", value = true) .load(FileUtilities.join(fileLocation, "**").toString) test("general workflow") { @@ -185,7 +184,7 @@ class ImageTransformerSuite extends TransformerFuzzing[ImageTransformer] with Op val outSizes = preprocessed.select(preprocessed("out.height"), preprocessed("out.width")).collect outSizes.foreach { row: Row => - assert(row.getInt(0) == 15 && row.getInt(1) == 10, "output images have incorrect size") + assert(row.getInt(0) == 15 && row.getInt(1) == 10, "output images have incorrect size") } val unroll = new UnrollImage() @@ -195,12 +194,12 @@ class ImageTransformerSuite extends TransformerFuzzing[ImageTransformer] with Op unroll.transform(preprocessed) .select("final") .collect().foreach(row => - assert(row.getAs[DenseVector](0).toArray.length == 10 * 15 * 3, "unrolled image is incorrect")) + assert(row.getAs[DenseVector](0).toArray.length == 10 * 15 * 3, "unrolled image is incorrect")) } test("binary file input") { - val binaries = spark.read.binary.load(FileUtilities.join(fileLocation,"**").toString) + val binaries = spark.read.binary.load(FileUtilities.join(fileLocation, "**").toString) assert(binaries.count() == 31) binaries.printSchema() @@ -348,4 +347,106 @@ class ImageTransformerSuite extends TransformerFuzzing[ImageTransformer] with Op .gaussianKernel(20, 10), images)) override def reader: ImageTransformer.type = ImageTransformer + + test("image transformer can convert a 3-channel image to tensor") { + val fileLocation = FileUtilities.join(BuildInfo.datasetDir, "Images", "ImageTransformer") + val image = spark.read.image.load(FileUtilities.join(fileLocation, "red_bgr.png").toString) + + val output = new ImageTransformer() + .setOutputCol("features") + .setToTensor(true) + .normalize(Array(0.5, 0.5, 0.5), Array(1.0, 1.0, 1.0), colorScaleFactor = 1 / 255d) + .setTensorElementType(FloatType) + .transform(image) + + val row = output.select("image.height", "image.width", "image.nChannels", "image.mode", "features").head + + assert(row.getAs[Int]("height") == 500) + assert(row.getAs[Int]("width") == 600) + assert(row.getAs[Int]("nChannels") == 3) + assert(row.getAs[Int]("mode") == 16) + + val tensor = row.getAs[Seq[Seq[Seq[Float]]]]("features") + + val channelRed = tensor.head + assert(channelRed.length == 500) + assert(channelRed.forall(_.length == 600)) + assert(channelRed.flatten.forall(_ == 0.5f)) + + val channelGreen = tensor(1) + assert(channelGreen.length == 500) + assert(channelGreen.forall(_.length 
== 600)) + assert(channelGreen.flatten.forall(_ == -0.5f)) + + val channelBlue = tensor(2) + assert(channelBlue.length == 500) + assert(channelBlue.forall(_.length == 600)) + assert(channelBlue.flatten.forall(_ == -0.5f)) + } + + test("image transformer can convert a 4-channel image to tensor") { + val fileLocation = FileUtilities.join(BuildInfo.datasetDir, "Images", "ImageTransformer") + val image = spark.read.image.load(FileUtilities.join(fileLocation, "red_bgra.png").toString) + + val output = new ImageTransformer() + .setOutputCol("features") + .setToTensor(true) + .centerCrop(100, 200) + .normalize(Array(0.5, 0.5, 0.5), Array(1.0, 1.0, 1.0), colorScaleFactor = 1 / 255d) + .setTensorElementType(DoubleType) + .transform(image) + + val row = output.select("image.nChannels", "image.mode", "features").head + + assert(row.getAs[Int]("nChannels") == 4) + assert(row.getAs[Int]("mode") == 24) + + val tensor = row.getAs[Seq[Seq[Seq[Double]]]]("features") + + val channelRed = tensor.head + assert(channelRed.length == 100) + assert(channelRed.forall(_.length == 200)) + assert(channelRed.flatten.forall(_ == 0.5d)) + + val channelGreen = tensor(1) + assert(channelGreen.length == 100) + assert(channelGreen.forall(_.length == 200)) + assert(channelGreen.flatten.forall(_ == -0.5d)) + + val channelBlue = tensor(2) + assert(channelBlue.length == 100) + assert(channelBlue.forall(_.length == 200)) + assert(channelBlue.flatten.forall(_ == -0.5d)) + } + + test("image transformer can convert a single-channel (grayscale) image to tensor") { + val fileLocation = FileUtilities.join(BuildInfo.datasetDir, "Images", "ImageTransformer") + val image = spark.read.image.load(FileUtilities.join(fileLocation, "grayscale.jpg").toString) + + val output = new ImageTransformer() + .setOutputCol("features") + .setToTensor(true) + .normalize(Array(0.5), Array(1.0), colorScaleFactor = 1 / 255d) + .setTensorElementType(DoubleType) + .transform(image) + + val row = output.select("image.height", "image.width", "image.nChannels", "image.mode", "features").head + + assert(row.getAs[Int]("height") == 200) + assert(row.getAs[Int]("width") == 256) + assert(row.getAs[Int]("nChannels") == 1) + assert(row.getAs[Int]("mode") == CvType.CV_8UC1) + + val tensor = row.getAs[Seq[Seq[Seq[Double]]]]("features") + assert(tensor.length == 1) + + val channel = tensor.head + assert(channel.length == 200) + assert(channel.forall(_.length == 256)) + assert(channel.flatten.forall(v => v >= -0.5d && v <= 0.5d)) + channel.foreach { + // check for monotonicity, making sure each row is properly ordered. + row => assert((row, row.drop(1)).zipped.forall(_ < _)) + } + } }
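For review context only, here is a minimal usage sketch (not part of the patch) of how the new ImageTransformer tensor output feeds the new ONNXModel, mirroring the resnet50 test added above. The ONNX node names ("data", "resnetv24_dense0_fwd") come from that test; the model path and the input DataFrame imageDf are placeholders and would need to match a real model and image source.

import com.microsoft.ml.spark.onnx.ONNXModel
import com.microsoft.ml.spark.opencv.ImageTransformer
import org.apache.spark.sql.types.FloatType

// Preprocess images into a CHW float tensor column named "features".
// normalize(...) also enables toTensor, per the ImageTransformer changes in this patch.
val featurizer = new ImageTransformer()
  .setInputCol("image")
  .setOutputCol("features")
  .resize(224, keepAspectRatio = true)
  .centerCrop(224, 224)
  .normalize(mean = Array(0.485, 0.456, 0.406), std = Array(0.229, 0.224, 0.225), colorScaleFactor = 1d / 255d)
  .setTensorElementType(FloatType)

// Feed the tensor column to the ONNX model; softmax and argmax columns are derived from the raw output.
val onnx = new ONNXModel()
  .setModelLocation("/tmp/resnet50-v2-7.onnx") // placeholder path
  .setFeedDict(Map("data" -> "features"))
  .setFetchDict(Map("rawPrediction" -> "resnetv24_dense0_fwd"))
  .setSoftMaxDict(Map("rawPrediction" -> "probability"))
  .setArgMaxDict(Map("rawPrediction" -> "prediction"))
  .setMiniBatchSize(1)

// imageDf: any DataFrame with a Spark image column, e.g. spark.read.image.load(path)
val scored = onnx.transform(featurizer.transform(imageDf))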