Skip to content

Commit

Permalink
Adding params for optimization level and device type
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryz committed Aug 9, 2021
1 parent e71d41e commit c467152
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

package com.microsoft.ml.spark

import org.apache.spark.ml.param.{MapParam, Params}
import org.apache.spark.ml.param.{MapParam, Param, Params}
import spray.json.DefaultJsonProtocol._

trait HasFeedFetchDicts extends Params {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import com.microsoft.ml.spark.core.schema.DatasetExtensions
import com.microsoft.ml.spark.core.utils.BreezeUtils._
import com.microsoft.ml.spark.logging.BasicLogging
import com.microsoft.ml.spark.stages._
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.injections.UDFUtils
import org.apache.spark.internal.Logging
Expand All @@ -24,10 +23,11 @@ import org.apache.spark.ml.linalg.SQLDataTypes._
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.{SparkContext, TaskContext}
import spray.json.DefaultJsonProtocol._

import java.nio._
Expand Down Expand Up @@ -73,7 +73,33 @@ trait ONNXModelParams extends Params with HasMiniBatcher with HasFeedFetchDicts

def getArgMaxDict: Map[String, String] = get(argMaxDict).getOrElse(Map.empty)

val deviceType: Param[String] = new Param[String](
this,
"deviceType",
"Specify a device type the model inference runs on. Supported types are: CPU or CUDA." +
"If not specified, auto detection will be used.",
ParamValidators.inArray(Array("CPU", "CUDA"))
)

def getDeviceType: String = $(deviceType)

def setDeviceType(value: String): this.type = set(deviceType, value)

val optimizationLevel: Param[String] = new Param[String](
this,
"optimizationLevel",
"Specify the optimization level for the ONNX graph optimizations. Details at " +
"https://onnxruntime.ai/docs/resources/graph-optimizations.html#graph-optimization-levels. " +
"Supported values are: NO_OPT; BASIC_OPT; EXTENDED_OPT; ALL_OPT. Default: ALL_OPT.",
ParamValidators.inArray(Array("NO_OPT", "BASIC_OPT", "EXTENDED_OPT", "ALL_OPT"))
)

def getOptimizationLevel: String = $(optimizationLevel)

def setOptimizationLevel(value: String): this.type = set(optimizationLevel, value)

setDefault(
optimizationLevel -> "ALL_OPT",
miniBatcher -> new FixedMiniBatchTransformer().setBatchSize(10) //scalastyle:ignore magic.number
)
}
Expand Down Expand Up @@ -142,7 +168,10 @@ private class ClosableIterator[+T](delegate: Iterator[T], cleanup: => Unit) exte
* MapInfo defines keyType, valueType and size. It is usually used inside SequenceInfo.
*/
object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging {
private[onnx] def initializeOrt(modelContent: Array[Byte], ortEnv: OrtEnvironment, gpuDeviceId: Option[Int] = None)
private[onnx] def initializeOrt(modelContent: Array[Byte],
ortEnv: OrtEnvironment,
optLevel: OptLevel = OptLevel.ALL_OPT,
gpuDeviceId: Option[Int] = None)
: OrtSession = {
val options = new SessionOptions()

Expand All @@ -157,7 +186,7 @@ object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging {
logError(err)
}

options.setOptimizationLevel(OptLevel.BASIC_OPT)
options.setOptimizationLevel(optLevel)
ortEnv.createSession(modelContent, options)
}

Expand Down Expand Up @@ -243,17 +272,24 @@ object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging {
}
}

//noinspection ScalaStyle
private[onnx] def applyModel(modelPayloadBC: Broadcast[Array[Byte]],
private[onnx] def selectGpuDevice(deviceType: Option[String]): Option[Int] = {
deviceType match {
case None | Some("CUDA") =>
val gpuNum = TaskContext.get().resources().get("gpu").flatMap(_.addresses.map(_.toInt).headOption)
gpuNum
case Some("CPU") =>
None
case _ =>
None
}
}

private[onnx] def applyModel(session: OrtSession,
env: OrtEnvironment,
feedMap: Map[String, String],
fetchMap: Map[String, String],
inputSchema: StructType
)(rows: Iterator[Row]): Iterator[Row] = {

val payload = modelPayloadBC.value
val env = OrtEnvironment.getEnvironment
val session = initializeOrt(payload, env, GpuDeviceId)

val results = rows.map {
row =>
// Each row contains a batch
Expand All @@ -270,7 +306,7 @@ object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging {
case other =>
throw new NotImplementedError(s"Only tensor input type is supported, but got $other instead.")
}
}.toMap
}

// Run the tensors through the ONNX runtime.
val outputBatches: Seq[Seq[Any]] = using(session.run(inputTensors.asJava)) {
Expand Down Expand Up @@ -352,9 +388,6 @@ object ONNXModel extends ComplexParamsReadable[ONNXModel] with Logging {
case (fromDataType, toDataType) => fromDataType == toDataType
}
}

@transient
lazy val GpuDeviceId: Option[Int] = TaskContext.get().resources().get("gpu").map(_.addresses.head.toInt)
}

class ONNXModel(override val uid: String)
Expand Down Expand Up @@ -424,10 +457,18 @@ class ONNXModel(override val uid: String)
val batchedDF = getMiniBatcher.transform(dataset)
val batchedCache = if (batchedDF.isStreaming) batchedDF else batchedDF.cache().unpersist()
val (coerced, feedDict) = coerceBatchedDf(batchedCache)
val fetchDict = getFetchDict
val modelBc = broadcastedModelPayload.getOrElse(rebroadcastModelPayload(dataset.sparkSession))
val modelFunc = applyModel(modelBc, feedDict, fetchDict, inputSchema) _
val outputDf = coerced.mapPartitions(modelFunc)
val (fetchDicts, devType, optLevel) = (getFetchDict, get(deviceType), OptLevel.valueOf(getOptimizationLevel))
val outputDf = coerced.mapPartitions {
rows =>
val payload = modelBc.value
val taskId = TaskContext.get().taskAttemptId()
val gpuDeviceId = selectGpuDevice(devType)
val env = OrtEnvironment.getEnvironment
logInfo(s"Task:$taskId;DeviceType=$devType;DeviceId=$gpuDeviceId;OptimizationLevel=$optLevel")
val session = initializeOrt(payload, env, optLevel, gpuDeviceId)
applyModel(session, env, feedDict, fetchDicts, inputSchema)(rows)
}

// The cache call is a workaround for GH issue 1075:
// https://github.com/Azure/mmlspark/issues/1075
Expand Down

0 comments on commit c467152

Please sign in to comment.