-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TypedTransformer and TypedEstimator, towards a type-safe Spark ML API #206
Changes from 7 commits
47ae51b
a8a9fb4
8875a0b
4808fab
a90c2f8
15ff0e3
74c221f
688a436
3d983da
0d4d123
28d915b
905deeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package frameless
package ml

import frameless.ops.SmartProject
import org.apache.spark.ml.{Estimator, Model}

/**
 * Typed wrapper around a Spark ML [[org.apache.spark.ml.Estimator]].
 *
 * `fit` takes a TypedDataset whose schema can be projected onto `Inputs` and
 * returns an [[AppendTransformer]] with `Inputs` as inputs and `Outputs` as outputs.
 */
trait TypedEstimator[Inputs, Outputs, M <: Model[M]] {
  val estimator: Estimator[M]

  def fit[T, F[_]](ds: TypedDataset[T])(
    implicit
    smartProject: SmartProject[T, Inputs],
    F: SparkDelay[F]
  ): F[AppendTransformer[Inputs, Outputs, M]] = {
    // The SparkDelay effect needs the session of the dataset being fitted.
    implicit val sparkSession = ds.dataset.sparkSession
    F.delay {
      // Restrict the dataset to the columns the underlying estimator expects.
      val projected = smartProject.apply(ds)
      val fittedModel = estimator.fit(projected.dataset)
      new AppendTransformer[Inputs, Outputs, M] {
        val transformer: M = fittedModel
      }
    }
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package frameless
package ml

import frameless.ops.SmartProject
import org.apache.spark.ml.Transformer
import shapeless.{Generic, HList}
import shapeless.ops.hlist.{Prepend, Tupler}

/** Marker trait for typed wrappers around Spark ML transformers. */
sealed trait TypedTransformer

/**
 * An AppendTransformer's `transform` method takes a TypedDataset whose schema
 * can be projected onto `Inputs` and returns the same rows with the `Outputs`
 * columns appended: the result type `Out` is the tuple of all of `T`'s fields
 * followed by all of `Outputs`' fields.
 */
trait AppendTransformer[Inputs, Outputs, InnerTransformer <: Transformer] extends TypedTransformer {
  val transformer: InnerTransformer

  def transform[T, TVals <: HList, OutputsVals <: HList, OutVals <: HList, Out, F[_]](ds: TypedDataset[T])(
    implicit
    i0: SmartProject[T, Inputs],
    i1: Generic.Aux[T, TVals],
    i2: Generic.Aux[Outputs, OutputsVals],
    i3: Prepend.Aux[TVals, OutputsVals, OutVals],
    i4: Tupler.Aux[OutVals, Out],
    i5: TypedEncoder[Out],
    F: SparkDelay[F]
  ): F[TypedDataset[Out]] = {
    implicit val sparkSession = ds.dataset.sparkSession
    F.delay {
      // The inner transformer appends its columns under the temp names below;
      // re-encoding as `Out` renames them via the tuple-based encoder.
      val appended = transformer.transform(ds.dataset).as[Out](TypedExpressionEncoder[Out])
      TypedDataset.create[Out](appended)
    }
  }
}

object AppendTransformer {
  // Random names for temp columns added by a TypedTransformer
  // (the proper names are given later by the Tuple-based encoder).
  private[ml] val tempColumnName = "I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI"
  private[ml] val tempColumnName2 = "I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMJ"
  private[ml] val tempColumnName3 = "I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMK"
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package frameless
package ml
package classification

import frameless.ml.internals.TreesInputsChecker
import frameless.ml.params.trees.FeatureSubsetStrategy
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vector

/**
 * Typed wrapper for Spark's RandomForestClassifier. `Inputs` must contain
 * exactly a Double label field and a Vector features field (checked by
 * [[TreesInputsChecker]]); fitting appends the `Outputs` columns.
 */
final class TypedRandomForestClassifier[Inputs] private[ml](
  rf: RandomForestClassifier,
  labelCol: String,
  featuresCol: String
) extends TypedEstimator[Inputs, TypedRandomForestClassifier.Outputs, RandomForestClassificationModel] {

  // Output columns are routed to temp names; the tuple-based encoder renames them later.
  val estimator: RandomForestClassifier =
    rf
      .setLabelCol(labelCol)
      .setFeaturesCol(featuresCol)
      .setPredictionCol(AppendTransformer.tempColumnName)
      .setRawPredictionCol(AppendTransformer.tempColumnName2)
      .setProbabilityCol(AppendTransformer.tempColumnName3)

  // Builder-style setters: each returns a fresh instance, keeping this wrapper immutable.
  def setNumTrees(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setNumTrees(value))
  def setMaxDepth(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxDepth(value))
  def setMinInfoGain(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInfoGain(value))
  def setMinInstancesPerNode(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInstancesPerNode(value))
  def setMaxMemoryInMB(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxMemoryInMB(value))
  def setSubsamplingRate(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setSubsamplingRate(value))
  def setFeatureSubsetStrategy(value: FeatureSubsetStrategy): TypedRandomForestClassifier[Inputs] =
    copy(rf.setFeatureSubsetStrategy(value.sparkValue))
  def setMaxBins(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxBins(value))

  private def copy(newRf: RandomForestClassifier): TypedRandomForestClassifier[Inputs] =
    new TypedRandomForestClassifier[Inputs](newRf, labelCol, featuresCol)
}

object TypedRandomForestClassifier {
  /** Columns appended by the fitted model. */
  case class Outputs(rawPrediction: Vector, probability: Vector, prediction: Double)

  def apply[Inputs](implicit inputsChecker: TreesInputsChecker[Inputs]): TypedRandomForestClassifier[Inputs] = {
    new TypedRandomForestClassifier(new RandomForestClassifier(), inputsChecker.labelCol, inputsChecker.featuresCol)
  }
}
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package frameless
package ml
package feature

import frameless.ml.internals.UnaryInputsChecker
import org.apache.spark.ml.feature.IndexToString

/**
 * Typed wrapper for Spark's IndexToString: maps a Double index column back to
 * its original String label. `Inputs` must have exactly one Double field.
 */
final class TypedIndexToString[Inputs] private[ml](indexToString: IndexToString, inputCol: String)
  extends AppendTransformer[Inputs, TypedIndexToString.Outputs, IndexToString] {

  // Output goes to a temp column; the tuple-based encoder renames it later.
  val transformer: IndexToString =
    indexToString
      .setInputCol(inputCol)
      .setOutputCol(AppendTransformer.tempColumnName)
}

object TypedIndexToString {
  /** Column appended by `transform`. */
  case class Outputs(originalOutput: String)

  def apply[Inputs](labels: Array[String])
    (implicit inputsChecker: UnaryInputsChecker[Inputs, Double]): TypedIndexToString[Inputs] = {
    new TypedIndexToString[Inputs](new IndexToString().setLabels(labels), inputsChecker.inputCol)
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package frameless
package ml
package feature

import frameless.ml.feature.TypedStringIndexer.HandleInvalid
import frameless.ml.internals.UnaryInputsChecker

import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}

/**
 * Typed wrapper for Spark's StringIndexer: encodes a String column into label
 * indices. `Inputs` must have exactly one String field.
 */
final class TypedStringIndexer[Inputs] private[ml](stringIndexer: StringIndexer, inputCol: String)
  extends TypedEstimator[Inputs, TypedStringIndexer.Outputs, StringIndexerModel] {

  // Output goes to a temp column; the tuple-based encoder renames it later.
  val estimator: StringIndexer = stringIndexer
    .setInputCol(inputCol)
    .setOutputCol(AppendTransformer.tempColumnName)

  /** Returns a fresh indexer with the given policy for unseen labels. */
  def setHandleInvalid(value: HandleInvalid): TypedStringIndexer[Inputs] = copy(stringIndexer.setHandleInvalid(value.sparkValue))

  private def copy(newStringIndexer: StringIndexer): TypedStringIndexer[Inputs] =
    new TypedStringIndexer[Inputs](newStringIndexer, inputCol)
}

object TypedStringIndexer {
  /** Column appended by the fitted model. */
  case class Outputs(indexedOutput: Double)

  /** Type-safe mirror of StringIndexer's `handleInvalid` string parameter. */
  sealed abstract class HandleInvalid(val sparkValue: String)
  object HandleInvalid {
    case object Error extends HandleInvalid("error")
    case object Skip extends HandleInvalid("skip")
    case object Keep extends HandleInvalid("keep")
  }

  def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, String]): TypedStringIndexer[Inputs] = {
    new TypedStringIndexer[Inputs](new StringIndexer(), inputsChecker.inputCol)
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package frameless
package ml
package feature

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vector
import shapeless.{HList, HNil, LabelledGeneric}
import shapeless.ops.hlist.ToTraversable
import shapeless.ops.record.{Keys, Values}
import shapeless._
import scala.annotation.implicitNotFound

/**
 * Typed wrapper for Spark's VectorAssembler: assembles all fields of `Inputs`
 * (numeric or boolean only) into a single Vector column.
 */
final class TypedVectorAssembler[Inputs] private[ml](vectorAssembler: VectorAssembler, inputCols: Array[String])
  extends AppendTransformer[Inputs, TypedVectorAssembler.Output, VectorAssembler] {

  // Output goes to a temp column; the tuple-based encoder renames it later.
  val transformer: VectorAssembler = vectorAssembler
    .setInputCols(inputCols)
    .setOutputCol(AppendTransformer.tempColumnName)
}

object TypedVectorAssembler {
  /** Column appended by `transform`. */
  case class Output(vector: Vector)

  def apply[Inputs](implicit inputsChecker: TypedVectorAssemblerInputsChecker[Inputs]): TypedVectorAssembler[Inputs] = {
    new TypedVectorAssembler(new VectorAssembler(), inputsChecker.inputCols.toArray)
  }
}

@implicitNotFound(
  msg = "Cannot prove that ${Inputs} is a valid input type. Input type must only contain fields of numeric or boolean types."
)
private[ml] trait TypedVectorAssemblerInputsChecker[Inputs] {
  val inputCols: Seq[String]
}

private[ml] object TypedVectorAssemblerInputsChecker {
  // Derives the list of input column names from the record keys of `Inputs`,
  // while the value checker validates every field type.
  implicit def checkInputs[Inputs, InputsRec <: HList, InputsKeys <: HList, InputsVals <: HList](
    implicit
    inputsGen: LabelledGeneric.Aux[Inputs, InputsRec],
    inputsKeys: Keys.Aux[InputsRec, InputsKeys],
    inputsKeysTraverse: ToTraversable.Aux[InputsKeys, Seq, Symbol],
    inputsValues: Values.Aux[InputsRec, InputsVals],
    inputsTypeCheck: TypedVectorAssemblerInputsValueChecker[InputsVals]
  ): TypedVectorAssemblerInputsChecker[Inputs] = new TypedVectorAssemblerInputsChecker[Inputs] {
    val inputCols: Seq[String] = inputsKeys.apply.to[Seq].map(_.name)
  }
}

/** Evidence that every element of `InputsVals` is numeric or Boolean. */
private[ml] trait TypedVectorAssemblerInputsValueChecker[InputsVals]

private[ml] object TypedVectorAssemblerInputsValueChecker {
  // Base case: the empty record is trivially valid.
  implicit def hnilCheckInputsValue: TypedVectorAssemblerInputsValueChecker[HNil] =
    new TypedVectorAssemblerInputsValueChecker[HNil] {}

  // Inductive case: a numeric head is valid if the tail is.
  implicit def hlistCheckInputsValueNumeric[H, T <: HList](
    implicit ch: CatalystNumeric[H],
    tt: TypedVectorAssemblerInputsValueChecker[T]
  ): TypedVectorAssemblerInputsValueChecker[H :: T] = new TypedVectorAssemblerInputsValueChecker[H :: T] {}

  // Inductive case: a Boolean head is valid if the tail is.
  implicit def hlistCheckInputsValueBoolean[T <: HList](
    implicit tt: TypedVectorAssemblerInputsValueChecker[T]
  ): TypedVectorAssemblerInputsValueChecker[Boolean :: T] = new TypedVectorAssemblerInputsValueChecker[Boolean :: T] {}
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package frameless
package ml
package internals

import shapeless.labelled.FieldType
import shapeless.{::, DepFn1, HList, Witness}

/**
 * Typeclass selecting, from a labelled record `L`, the first field whose value
 * is of type `Value`. `Out` is that field's key (a singleton Symbol type).
 */
trait SelectorByValue[L <: HList, Value] extends DepFn1[L] with Serializable { type Out <: Symbol }

object SelectorByValue {
  type Aux[L <: HList, Value, Out0 <: Symbol] = SelectorByValue[L, Value] { type Out = Out0 }

  // Head field has the wanted value type: select its key.
  implicit def select[K <: Symbol, T <: HList, Value](implicit wk: Witness.Aux[K]): Aux[FieldType[K, Value] :: T, Value, K] =
    new SelectorByValue[FieldType[K, Value] :: T, Value] {
      type Out = K
      def apply(l: FieldType[K, Value] :: T): Out = wk.value
    }

  // Head field does not match: keep searching in the tail.
  implicit def recurse[H, T <: HList, Value](implicit st: SelectorByValue[T, Value]): Aux[H :: T, Value, st.Out] =
    new SelectorByValue[H :: T, Value] {
      type Out = st.Out
      def apply(l: H :: T): Out = st(l.tail)
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package frameless.ml.internals

import shapeless.ops.hlist.Length
import shapeless.{HList, LabelledGeneric, Nat, Witness}
import org.apache.spark.ml.linalg._

import scala.annotation.implicitNotFound

/**
 * Evidence that `Inputs` is a valid input type for tree-based ML algorithms
 * (decision tree, random forest, gradient-boosted trees): exactly two fields,
 * one of type Double (label) and one of type org.apache.spark.ml.linalg.Vector
 * (features). Exposes the discovered field names as column names.
 */
@implicitNotFound(
  // NOTE: a trailing space is required before closing each concatenated part,
  // otherwise the message reads "…valid input type.Input type must…".
  msg = "Cannot prove that ${Inputs} is a valid input type. " +
    "Input type must only contain a field of type Double (label) and a field of type " +
    "org.apache.spark.ml.linalg.Vector (features)."
)
trait TreesInputsChecker[Inputs] {
  val featuresCol: String
  val labelCol: String
}

object TreesInputsChecker {

  implicit def checkTreesInputs[
    Inputs,
    InputsRec <: HList,
    LabelK <: Symbol,
    FeaturesK <: Symbol](
    implicit
    i0: LabelledGeneric.Aux[Inputs, InputsRec],            // view Inputs as a labelled record
    i1: Length.Aux[InputsRec, Nat._2],                     // exactly two fields
    i2: SelectorByValue.Aux[InputsRec, Double, LabelK],    // the Double field is the label
    i3: Witness.Aux[LabelK],
    i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], // the Vector field is the features
    i5: Witness.Aux[FeaturesK]
  ): TreesInputsChecker[Inputs] = {
    new TreesInputsChecker[Inputs] {
      // Use the already-bound witnesses directly instead of re-summoning via `implicitly`.
      val labelCol: String = i3.value.name
      val featuresCol: String = i5.value.name
    }
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package frameless.ml.internals

import shapeless.ops.hlist.Length
import shapeless.{HList, LabelledGeneric, Nat, Witness}

import scala.annotation.implicitNotFound

/**
 * Evidence that `Inputs` has exactly one field, of type `Expected`.
 * Usable for all unary transformers (i.e. almost all of them); exposes the
 * discovered field name as the input column name.
 */
@implicitNotFound(
  msg = "Cannot prove that ${Inputs} is a valid input type. Input type must have only one field of type ${Expected}"
)
trait UnaryInputsChecker[Inputs, Expected] {
  val inputCol: String
}

object UnaryInputsChecker {

  implicit def checkUnaryInputs[Inputs, Expected, InputsRec <: HList, InputK <: Symbol](
    implicit
    i0: LabelledGeneric.Aux[Inputs, InputsRec],           // view Inputs as a labelled record
    i1: Length.Aux[InputsRec, Nat._1],                    // exactly one field
    i2: SelectorByValue.Aux[InputsRec, Expected, InputK], // and it has type Expected
    i3: Witness.Aux[InputK]
  ): UnaryInputsChecker[Inputs, Expected] = new UnaryInputsChecker[Inputs, Expected] {
    // Read the field name straight from the bound witness.
    val inputCol: String = i3.value.name
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package frameless
package ml
package params
package trees

/**
 * Type-safe mirror of Spark's `featureSubsetStrategy` string parameter for
 * tree-based algorithms: the number of features to consider per split.
 */
sealed abstract class FeatureSubsetStrategy private[ml](val sparkValue: String)
object FeatureSubsetStrategy {
  /** Let Spark pick a default based on the task. */
  case object Auto extends FeatureSubsetStrategy("auto")
  /** Consider every feature at each split. */
  case object All extends FeatureSubsetStrategy("all")
  /** Consider one third of the features. */
  case object OneThird extends FeatureSubsetStrategy("onethird")
  /** Consider sqrt(numFeatures) features. */
  case object Sqrt extends FeatureSubsetStrategy("sqrt")
  /** Consider log2(numFeatures) features. */
  case object Log2 extends FeatureSubsetStrategy("log2")
  /** Consider a fraction of the features (expected in (0, 1]). */
  case class Ratio(value: Double) extends FeatureSubsetStrategy(value.toString)
  /** Consider an absolute number of features. */
  case class NumberOfFeatures(value: Int) extends FeatureSubsetStrategy(value.toString)
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@atamborrino do we need to have our own Append operation here? Shapeless provides an operation that may do exactly this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`AppendTransformer` does a little more than just appending: it first checks that `T` in `.transform[T]` can be projected to `Inputs`, and then indeed it uses shapeless' `Prepend` operation (see `i3: Prepend.Aux[TVals, OutputsVals, OutVals]`). The type-level append logic is basically the same as in `TypedDataset.withColumn`.