Commit a36a1ca: Fixing code review issues

memoryz committed Aug 3, 2021
1 parent 638b47d commit a36a1ca
Showing 8 changed files with 48 additions and 42 deletions.
@@ -15,7 +15,7 @@ private[spark] object SlicerFunctions {
Vectors.dense(indices.map(values.apply).map(n.toDouble).toArray)
}

-  val DataTypeToNumericMap: Map[NumericType, Numeric[_]] = Map(
+  private val DataTypeToNumericMap: Map[NumericType, Numeric[_]] = Map(
FloatType -> implicitly[Numeric[Float]],
DoubleType -> implicitly[Numeric[Double]],
ByteType -> implicitly[Numeric[Byte]],
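The only change here narrows DataTypeToNumericMap to private. For context, the map resolves a scala.math.Numeric typeclass instance per Spark NumericType, and n.toDouble then widens values generically, as in the Vectors.dense call above. A minimal self-contained sketch of the same pattern (names invented for illustration, not code from the repository):

  import org.apache.spark.sql.types.{DoubleType, FloatType, NumericType}

  object NumericWideningSketch {
    // One Numeric instance per Spark numeric type, as in DataTypeToNumericMap.
    val numerics: Map[NumericType, Numeric[_]] = Map(
      FloatType -> implicitly[Numeric[Float]],
      DoubleType -> implicitly[Numeric[Double]]
    )

    // n.toDouble widens any T that has a Numeric instance.
    def toDoubles[T](values: Seq[T])(implicit n: Numeric[T]): Seq[Double] =
      values.map(n.toDouble)
  }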
12 changes: 6 additions & 6 deletions core/src/main/scala/com/microsoft/ml/spark/lime/LIME.scala
@@ -164,10 +164,10 @@ trait LIMEBase extends LIMEParams with ComplexParamsWritable {

}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
object TabularLIME extends ComplexParamsReadable[TabularLIME]

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel]
with LIMEParams with Wrappable with ComplexParamsWritable with BasicLogging {
logClass()
@@ -200,10 +200,10 @@ class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel]
}
}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
object TabularLIMEModel extends ComplexParamsReadable[TabularLIMEModel]

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel]
with LIMEBase with Wrappable with BasicLogging {
logClass()
@@ -256,15 +256,15 @@ class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel]

}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3")
object ImageLIME extends ComplexParamsReadable[ImageLIME]

/** Distributed implementation of
* Local Interpretable Model-Agnostic Explanations (LIME)
*
* https://arxiv.org/pdf/1602.04938v1.pdf
*/
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.ImageLIME'.", since="1.0.0-rc3")
class ImageLIME(val uid: String) extends Transformer with LIMEBase
with Wrappable with HasModifier with HasCellSize with BasicLogging {
logClass()
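Every LIME class in this file carries Scala's @deprecated annotation; the commit only normalizes the since argument to the lowercase "1.0.0-rc3". As a reminder of the mechanism, a generic sketch (illustrative names, not code from the repository):

  object DeprecationSketch {
    @deprecated("Please use 'NewLime'.", since = "1.0.0-rc3")
    class OldLime

    // Still compiles, but the compiler emits a deprecation warning at this call site.
    val legacy = new OldLime()
  }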
@@ -25,7 +25,7 @@ object TextLIME extends ComplexParamsReadable[TextLIME]
*
* https://arxiv.org/pdf/1602.04938v1.pdf
*/
-@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-rc3")
class TextLIME(val uid: String) extends Model[TextLIME]
with LIMEBase with Wrappable with BasicLogging {
logClass()
@@ -12,7 +12,7 @@ import org.apache.spark.ml.param.DataFrameEquality
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.MLReadable

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.VectorLIME'.", since="1.0.0-rc3")
trait LimeTestBase extends TestBase {

import spark.implicits._
@@ -42,7 +42,7 @@ trait LimeTestBase extends TestBase {
lazy val limeModel = lime.fit(df)
}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.TabularLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.TabularLIME'.", since="1.0.0-rc3")
class TabularLIMESuite extends EstimatorFuzzing[TabularLIME] with
DataFrameEquality with LimeTestBase {

@@ -59,7 +59,7 @@ class TabularLIMESuite extends EstimatorFuzzing[TabularLIME] with
override def modelReader: MLReadable[_] = TabularLIMEModel
}

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-rc3")
class TabularLIMEModelSuite extends TransformerFuzzing[TabularLIMEModel] with
DataFrameEquality with LimeTestBase {

@@ -17,7 +17,7 @@ import org.apache.spark.sql.types.DoubleType
import org.scalactic.Equality
import org.scalatest.Assertion

-@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-RC3")
+@deprecated("Please use 'com.microsoft.ml.spark.explainers.TextLIME'.", since="1.0.0-rc3")
class TextLIMESuite extends TransformerFuzzing[TextLIME] {

import spark.implicits._
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark
import org.apache.spark.ml.param.{MapParam, Params}
import spray.json.DefaultJsonProtocol._

-trait HasFeedFetchMaps extends Params {
+trait HasFeedFetchDicts extends Params {
val feedDict: MapParam[String, String] = new MapParam[String, String](
this,
"feedDict",
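The trait is renamed from HasFeedFetchMaps to HasFeedFetchDicts, matching the Dict-suffixed params it declares (feedDict, and by symmetry fetchDict). A hedged wiring sketch: the setters and the feedDict direction (model input name to DataFrame column) follow the CNTK-style convention and are assumptions here, while the fetchDict direction (DataFrame column to model output name) matches the getFetchDict usage later in this diff; the model value and all names are invented:

  // Hypothetical wiring of a model that mixes in HasFeedFetchDicts.
  val wired = model
    .setFeedDict(Map("input_1" -> "features"))        // model input -> DataFrame column (assumed direction)
    .setFetchDict(Map("rawPrediction" -> "dense_2"))  // DataFrame column -> model output (per this diff)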
@@ -6,7 +6,7 @@ package com.microsoft.ml.spark.cntk
import com.microsoft.CNTK.CNTKExtensions._
import com.microsoft.CNTK.CNTKUtils._
import com.microsoft.CNTK.{CNTKExtensions, DataType => CNTKDataType, SerializableFunction => CNTKFunction, _}
-import com.microsoft.ml.spark.HasFeedFetchMaps
+import com.microsoft.ml.spark.HasFeedFetchDicts
import com.microsoft.ml.spark.cntk.ConversionUtils.GVV
import com.microsoft.ml.spark.codegen.Wrappable
import com.microsoft.ml.spark.core.schema.DatasetExtensions.findUnusedColumnName
@@ -145,7 +145,7 @@ private object CNTKModelUtils extends java.io.Serializable {
object CNTKModel extends ComplexParamsReadable[CNTKModel]

class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexParamsWritable
-  with HasMiniBatcher with HasFeedFetchMaps with Wrappable with BasicLogging {
+  with HasMiniBatcher with HasFeedFetchDicts with Wrappable with BasicLogging {
logClass()

override protected lazy val pyInternalWrapper = true
@@ -7,7 +7,7 @@ import ai.onnxruntime.OrtSession.SessionOptions
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel
import ai.onnxruntime.{OrtSession, _}
import breeze.linalg.{argmax, softmax, DenseVector => BDV}
-import com.microsoft.ml.spark.HasFeedFetchMaps
+import com.microsoft.ml.spark.HasFeedFetchDicts
import com.microsoft.ml.spark.codegen.Wrappable
import com.microsoft.ml.spark.core.schema.DatasetExtensions
import com.microsoft.ml.spark.core.utils.BreezeUtils._
@@ -33,7 +33,7 @@ import scala.collection.JavaConverters._
import scala.jdk.CollectionConverters.mapAsScalaMapConverter
import scala.reflect.ClassTag

-trait ONNXModelParams extends Params with HasMiniBatcher with HasFeedFetchMaps {
+trait ONNXModelParams extends Params with HasMiniBatcher with HasFeedFetchDicts {
val modelPayload: ByteArrayParam = new ByteArrayParam(
this,
"modelPayload",
@@ -52,7 +52,9 @@ trait ONNXModelParams extends Params with HasMiniBatcher with HasFeedFetchMaps {
  val softMaxDict: MapParam[String, String] = new MapParam[String, String](
    this,
    "softMaxDict",
-    " A between output dataframe columns, the value column will be computed from taking the softmax of key column."
+    "A map between output dataframe columns, where the value column will be computed from taking " +
+      "the softmax of the key column. If the 'rawPrediction' column contains logits outputs, then one can " +
+      "set softMaxDict to `Map(\"rawPrediction\" -> \"probability\")` to obtain the probability outputs."
)

def setSoftMaxDict(value: Map[String, String]): this.type = set(softMaxDict, value)
@@ -64,7 +66,8 @@ trait ONNXModelParams extends Params with HasMiniBatcher with HasFeedFetchMaps {
  val argMaxDict: MapParam[String, String] = new MapParam[String, String](
    this,
    "argMaxDict",
-    " A between output dataframe columns, the value column will be computed from taking the argmax of key column."
+    "A map between output dataframe columns, where the value column will be computed from taking " +
+      "the argmax of the key column. This can be used to convert probability output to predicted label."
)

def setArgMaxDict(value: Map[String, String]): this.type = set(argMaxDict, value)
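The reworded descriptions above explain the two post-processing dictionaries. A short usage sketch: setSoftMaxDict and setArgMaxDict appear in this diff, while the no-arg constructor and the column names are assumed for illustration:

  // Hypothetical chain: logits -> probabilities -> predicted label index.
  val postProcessed = new ONNXModel()
    .setSoftMaxDict(Map("rawPrediction" -> "probability")) // softmax of the logits column
    .setArgMaxDict(Map("probability" -> "prediction"))     // argmax picks the label index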
@@ -301,10 +304,12 @@ class ONNXModel(override val uid: String)

override def transform(dataset: Dataset[_]): DataFrame = logTransform {
val inputSchema = dataset.schema
-    val outputSchema = transformValidateSchema(inputSchema, includePostProcessFields = false)
+    this.validateSchema(inputSchema)

+    val modelOutputSchema = getModelOutputSchema(inputSchema)

implicit val enc: Encoder[Row] = RowEncoder(
-      StructType(outputSchema.map(f => StructField(f.name, ArrayType(f.dataType))))
+      StructType(modelOutputSchema.map(f => StructField(f.name, ArrayType(f.dataType))))
)

val modelPayloadBC = dataset.sparkSession.sparkContext.broadcast(this.getModelPayload)
@@ -399,11 +404,24 @@ class ONNXModel(override val uid: String)
override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

override def transformSchema(schema: StructType): StructType = {
-    transformValidateSchema(schema, includePostProcessFields = true)
+    this.validateSchema(schema)

+    val modelOutputSchema = getModelOutputSchema(schema)

+    val softmaxFields = this.getSoftMaxDict.map {
+      case (inputCol, outputCol) =>
+        getSoftMaxOutputField(inputCol, outputCol, modelOutputSchema)
+    }

+    val argmaxFields = this.getArgMaxDict.map {
+      case (inputCol, outputCol) =>
+        getArgMaxOutputField(inputCol, outputCol, modelOutputSchema)
+    }

+    (softmaxFields ++ argmaxFields).foldLeft(modelOutputSchema)(_ add _)
}

//noinspection ScalaStyle
-  private def transformValidateSchema(schema: StructType, includePostProcessFields: Boolean): StructType = {
+  private def validateSchema(schema: StructType): Unit = {
// Validate that input schema matches with onnx model expected input types.
this.modelInput.foreach {
case (onnxInputName, onnxInputInfo) =>
@@ -431,7 +449,12 @@ class ONNXModel(override val uid: String)
throw new IllegalArgumentException(s"Output field $colName already exists in the input schema.")
}
}
+  }
+
+  /**
+    * Gets the output schema from the ONNX model
+    */
+  private def getModelOutputSchema(inputSchema: StructType): StructType = {
// Get ONNX model output cols.
val modelOutputFields = this.getFetchDict.map {
case (colName, onnxOutputName) =>
StructField(colName, dataType)
}

-    val allFields = schema.fields ++ modelOutputFields
-
-    // Include post processing fields if any.
-    if (includePostProcessFields) {
-      val softmaxFields = this.getSoftMaxDict.map {
-        case (inputCol, outputCol) =>
-          getSoftMaxOutputField(inputCol, outputCol, StructType(allFields))
-      }
-
-      val argmaxFields = this.getArgMaxDict.map {
-        case (inputCol, outputCol) =>
-          getArgMaxOutputField(inputCol, outputCol, StructType(allFields))
-      }
-
-      StructType(allFields ++ softmaxFields ++ argmaxFields)
-    } else {
-      StructType(allFields)
-    }
+    StructType(inputSchema.fields ++ modelOutputFields)
}

private def getSoftMaxOutputField(inputCol: String, outputCol: String, schema: StructType) = {
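Taken together, the refactor splits the old transformValidateSchema into validateSchema (checks only, returning Unit) and getModelOutputSchema (input fields plus fetched model outputs), with transformSchema then folding the softmax/argmax fields on top. A compact, self-contained sketch of that composition pattern using plain Spark types (all field names invented):

  import org.apache.spark.sql.types._

  object SchemaCompositionSketch {
    // Stand-in for getModelOutputSchema: input fields plus the fetched model output.
    def modelOutputSchema(input: StructType): StructType =
      StructType(input.fields :+ StructField("rawPrediction", ArrayType(FloatType)))

    // Stand-in for the new transformSchema: fold derived fields onto the model schema.
    def fullSchema(input: StructType): StructType = {
      val derived = Seq(
        StructField("probability", ArrayType(FloatType)), // softmax output
        StructField("prediction", DoubleType)             // argmax output
      )
      // StructType.add returns a new StructType each time, so the fold is side-effect free.
      derived.foldLeft(modelOutputSchema(input))(_ add _)
    }
  }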
