
Add TypedTransformer and TypedEstimator, towards a type-safe Spark ML API #206

Merged (12 commits, Dec 7, 2017)
13 changes: 13 additions & 0 deletions dataset/src/test/scala/frameless/XN.scala
@@ -64,3 +64,16 @@ object X5 {
implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] =
Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e))
}

case class X6[A, B, C, D, E, F](a: A, b: B, c: C, d: D, e: E, f: F)

object X6 {
implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary, E: Arbitrary, F: Arbitrary]: Arbitrary[X6[A, B, C, D, E, F]] =
Arbitrary(Arbitrary.arbTuple6[A, B, C, D, E, F].arbitrary.map((X6.apply[A, B, C, D, E, F] _).tupled))

implicit def cogen[A, B, C, D, E, F](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C], D: Cogen[D], E: Cogen[E], F: Cogen[F]): Cogen[X6[A, B, C, D, E, F]] =
Cogen.tuple6(A, B, C, D, E, F).contramap(x => (x.a, x.b, x.c, x.d, x.e, x.f))

implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering, F: Ordering]: Ordering[X6[A, B, C, D, E, F]] =
Ordering.Tuple6[A, B, C, D, E, F].on(x => (x.a, x.b, x.c, x.d, x.e, x.f))
}
28 changes: 28 additions & 0 deletions ml/src/main/scala/frameless/ml/TypedEstimator.scala
@@ -0,0 +1,28 @@
package frameless
package ml

import frameless.ops.SmartProject
import org.apache.spark.ml.{Estimator, Model}

/**
 * A TypedEstimator's `fit` method takes as input a TypedDataset containing `Inputs` and
 * returns an AppendTransformer with `Inputs` as inputs and `Outputs` as outputs.
 */
trait TypedEstimator[Inputs, Outputs, M <: Model[M]] {
val estimator: Estimator[M]

def fit[T, F[_]](ds: TypedDataset[T])(
implicit
smartProject: SmartProject[T, Inputs],
F: SparkDelay[F]
): F[AppendTransformer[Inputs, Outputs, M]] = {
implicit val sparkSession = ds.dataset.sparkSession
F.delay {
val inputDs = smartProject.apply(ds)
val model = estimator.fit(inputDs.dataset)
new AppendTransformer[Inputs, Outputs, M] {
val transformer: M = model
}
}
}
}
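The compile-time contract above can be illustrated without Spark. The following sketch (plain Scala; `SimpleDataset`, `SimpleEstimator`, and `MeanEstimator` are invented for illustration and are not part of frameless) shows the same shape: an estimator whose `fit` only accepts datasets of its declared `Inputs` type and returns a transformer that appends its `Outputs`:

```scala
final case class SimpleDataset[T](rows: Seq[T])

trait SimpleTransformer[Inputs, Outputs] {
  def transform(data: SimpleDataset[Inputs]): SimpleDataset[(Inputs, Outputs)]
}

trait SimpleEstimator[Inputs, Outputs] {
  // fit only accepts a dataset whose row type is exactly Inputs
  def fit(data: SimpleDataset[Inputs]): SimpleTransformer[Inputs, Outputs]
}

case class TrainRow(label: Double, feature: Double)

// A toy estimator: learns the mean label and appends it as the "prediction".
object MeanEstimator extends SimpleEstimator[TrainRow, Double] {
  def fit(data: SimpleDataset[TrainRow]): SimpleTransformer[TrainRow, Double] = {
    val mean = data.rows.map(_.label).sum / data.rows.size
    new SimpleTransformer[TrainRow, Double] {
      def transform(data: SimpleDataset[TrainRow]): SimpleDataset[(TrainRow, Double)] =
        SimpleDataset(data.rows.map(r => (r, mean)))
    }
  }
}

val training = SimpleDataset(Seq(TrainRow(1.0, 0.5), TrainRow(3.0, 1.5)))
val predictions = MeanEstimator.fit(training).transform(training)
println(predictions.rows.map(_._2)) // List(2.0, 2.0)
```

The real TypedEstimator adds two things this sketch omits: SmartProject, which lets `fit` accept any row type projectable to `Inputs`, and the SparkDelay effect wrapping the eager training step.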
42 changes: 42 additions & 0 deletions ml/src/main/scala/frameless/ml/TypedTransformer.scala
@@ -0,0 +1,42 @@
package frameless
package ml

import frameless.ops.SmartProject
import org.apache.spark.ml.Transformer
import shapeless.{Generic, HList}
import shapeless.ops.hlist.{Prepend, Tupler}

sealed trait TypedTransformer

/**
 * An AppendTransformer's `transform` method takes as input a TypedDataset containing `Inputs` and
 * returns a TypedDataset with the `Outputs` columns appended to the input TypedDataset.
 */
trait AppendTransformer[Inputs, Outputs, InnerTransformer <: Transformer] extends TypedTransformer {
[Review comment — Contributor]
@atamborrino do we need to have our own Append operation here? Shapeless provides an operation that may do exactly this.

[Reply — @atamborrino (author), Nov 24, 2017]
AppendTransformer does a little more than just appending: it first checks that T in .transform[T] can be projected to Inputs, and then it does indeed use shapeless' Prepend operation (see i3: Prepend.Aux[TVals, OutputsVals, OutVals]). The type-level append logic is basically the same as in TypedDataset.withColumn.

val transformer: InnerTransformer

def transform[T, TVals <: HList, OutputsVals <: HList, OutVals <: HList, Out, F[_]](ds: TypedDataset[T])(
[Review comment — Contributor]
@atamborrino to the best of my knowledge, transform is a lazy operation from Dataset to Dataset, so I don't think we need to enclose it in the effectful SparkDelay. transform should just take a TypedDataset[T] and return a TypedDataset[Out].

[Review comment — Contributor]
I sampled a few mllib Models and that looks right, although they may throw exceptions (require is used liberally).

[Reply — @atamborrino (author)]
@imarios you're right, I had tested that but my test was apparently incorrect... I'll remove the SparkDelay for Transformers. @iravid hopefully the conditions of the requires are checked at compile time when building the TypedTransformer.

implicit
i0: SmartProject[T, Inputs],
i1: Generic.Aux[T, TVals],
i2: Generic.Aux[Outputs, OutputsVals],
i3: Prepend.Aux[TVals, OutputsVals, OutVals],
i4: Tupler.Aux[OutVals, Out],
i5: TypedEncoder[Out],
F: SparkDelay[F]
): F[TypedDataset[Out]] = {
implicit val sparkSession = ds.dataset.sparkSession
F.delay {
val transformed = transformer.transform(ds.dataset).as[Out](TypedExpressionEncoder[Out])
TypedDataset.create[Out](transformed)
}
}

}

object AppendTransformer {
  // Random names for temp columns added by a TypedTransformer (the proper names are given later by the Tuple-based encoder)
private[ml] val tempColumnName = "I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI"
private[ml] val tempColumnName2 = "I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMJ"
private[ml] val tempColumnName3 = "I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMK"
}
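As the review thread notes, the type-level append in `transform` is shapeless' `Prepend`. The idea can be shown in plain Scala at a fixed arity (shapeless derives it generically over HLists; `TupleAppend` and `appended` here are invented names for illustration): the output type `Out` is computed by implicit resolution as the row type's fields followed by the `Outputs` fields.

```scala
// A fixed-arity sketch of the Prepend idea used by AppendTransformer.transform:
// given row type (A, B) and one appended output column C, implicit resolution
// computes the result type (A, B, C) and how to build its values.
trait TupleAppend[T, O] {
  type Out
  def apply(t: T, o: O): Out
}
object TupleAppend {
  type Aux[T, O, Out0] = TupleAppend[T, O] { type Out = Out0 }
  implicit def append2and1[A, B, C]: Aux[(A, B), Tuple1[C], (A, B, C)] =
    new TupleAppend[(A, B), Tuple1[C]] {
      type Out = (A, B, C)
      def apply(t: (A, B), o: Tuple1[C]): (A, B, C) = (t._1, t._2, o._1)
    }
}

def appended[T, O](t: T, o: O)(implicit ap: TupleAppend[T, O]): ap.Out = ap(t, o)

// ("a", 1) with an appended Double column becomes (String, Int, Double)
val out: (String, Int, Double) = appended(("a", 1), Tuple1(2.0))
```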
@@ -0,0 +1,45 @@
package frameless
package ml
package classification

import frameless.ml.internals.TreesInputsChecker
import frameless.ml.params.trees.FeatureSubsetStrategy
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vector

final class TypedRandomForestClassifier[Inputs] private[ml](
rf: RandomForestClassifier,
labelCol: String,
featuresCol: String
) extends TypedEstimator[Inputs, TypedRandomForestClassifier.Outputs, RandomForestClassificationModel] {

val estimator: RandomForestClassifier =
rf
.setLabelCol(labelCol)
.setFeaturesCol(featuresCol)
.setPredictionCol(AppendTransformer.tempColumnName)
.setRawPredictionCol(AppendTransformer.tempColumnName2)
.setProbabilityCol(AppendTransformer.tempColumnName3)

def setNumTrees(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setNumTrees(value))
def setMaxDepth(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxDepth(value))
def setMinInfoGain(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInfoGain(value))
def setMinInstancesPerNode(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInstancesPerNode(value))
def setMaxMemoryInMB(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxMemoryInMB(value))
def setSubsamplingRate(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setSubsamplingRate(value))
def setFeatureSubsetStrategy(value: FeatureSubsetStrategy): TypedRandomForestClassifier[Inputs] =
copy(rf.setFeatureSubsetStrategy(value.sparkValue))
def setMaxBins(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxBins(value))

private def copy(newRf: RandomForestClassifier): TypedRandomForestClassifier[Inputs] =
new TypedRandomForestClassifier[Inputs](newRf, labelCol, featuresCol)
}

object TypedRandomForestClassifier {
case class Outputs(rawPrediction: Vector, probability: Vector, prediction: Double)

def apply[Inputs](implicit inputsChecker: TreesInputsChecker[Inputs]): TypedRandomForestClassifier[Inputs] = {
new TypedRandomForestClassifier(new RandomForestClassifier(), inputsChecker.labelCol, inputsChecker.featuresCol)
}
}
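For orientation, here is a hedged usage sketch of the API this PR adds. It is not runnable standalone: it assumes a live SparkSession, frameless imports, and that both `fit` and `transform` run inside frameless' SparkDelay effect via the default `Job` and its `.run()` (which needs an implicit SparkSession), as in this PR's version of the code. The case-class and field names are invented for illustration.

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.SparkSession
import frameless.TypedDataset
import frameless.ml.classification.TypedRandomForestClassifier

// TreesInputsChecker requires exactly one Double field (label) and one
// Vector field (features); any other shape is a compile error.
case class RFInputs(isSold: Double, features: Vector)
case class HouseFeatures(squareFeet: Double, isSold: Double, features: Vector)

def train(ds: TypedDataset[HouseFeatures])(implicit spark: SparkSession) = {
  val rf = TypedRandomForestClassifier[RFInputs].setNumTrees(10)
  val model = rf.fit(ds).run()  // SmartProject projects ds down to RFInputs
  model.transform(ds).run()     // appends rawPrediction, probability, prediction
}
```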

25 changes: 25 additions & 0 deletions ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala
@@ -0,0 +1,25 @@
package frameless
package ml
package feature

import frameless.ml.internals.UnaryInputsChecker
import org.apache.spark.ml.feature.IndexToString

final class TypedIndexToString[Inputs] private[ml](indexToString: IndexToString, inputCol: String)
extends AppendTransformer[Inputs, TypedIndexToString.Outputs, IndexToString] {

val transformer: IndexToString =
indexToString
.setInputCol(inputCol)
.setOutputCol(AppendTransformer.tempColumnName)

}

object TypedIndexToString {
case class Outputs(originalOutput: String)

def apply[Inputs](labels: Array[String])
(implicit inputsChecker: UnaryInputsChecker[Inputs, Double]): TypedIndexToString[Inputs] = {
new TypedIndexToString[Inputs](new IndexToString().setLabels(labels), inputsChecker.inputCol)
}
}
35 changes: 35 additions & 0 deletions ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala
@@ -0,0 +1,35 @@
package frameless
package ml
package feature

import frameless.ml.feature.TypedStringIndexer.HandleInvalid
import frameless.ml.internals.UnaryInputsChecker
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}

final class TypedStringIndexer[Inputs] private[ml](stringIndexer: StringIndexer, inputCol: String)
extends TypedEstimator[Inputs, TypedStringIndexer.Outputs, StringIndexerModel] {

val estimator: StringIndexer = stringIndexer
.setInputCol(inputCol)
.setOutputCol(AppendTransformer.tempColumnName)

def setHandleInvalid(value: HandleInvalid): TypedStringIndexer[Inputs] = copy(stringIndexer.setHandleInvalid(value.sparkValue))

private def copy(newStringIndexer: StringIndexer): TypedStringIndexer[Inputs] =
new TypedStringIndexer[Inputs](newStringIndexer, inputCol)
}

object TypedStringIndexer {
case class Outputs(indexedOutput: Double)

sealed abstract class HandleInvalid(val sparkValue: String)
object HandleInvalid {
case object Error extends HandleInvalid("error")
case object Skip extends HandleInvalid("skip")
case object Keep extends HandleInvalid("keep")
}

def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, String]): TypedStringIndexer[Inputs] = {
new TypedStringIndexer[Inputs](new StringIndexer(), inputsChecker.inputCol)
}
}
66 changes: 66 additions & 0 deletions ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala
@@ -0,0 +1,66 @@
package frameless
package ml
package feature

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vector
import shapeless.{HList, HNil, LabelledGeneric}
import shapeless.ops.hlist.ToTraversable
import shapeless.ops.record.{Keys, Values}
import shapeless._
import scala.annotation.implicitNotFound

final class TypedVectorAssembler[Inputs] private[ml](vectorAssembler: VectorAssembler, inputCols: Array[String])
extends AppendTransformer[Inputs, TypedVectorAssembler.Output, VectorAssembler] {

val transformer: VectorAssembler = vectorAssembler
.setInputCols(inputCols)
.setOutputCol(AppendTransformer.tempColumnName)

}

object TypedVectorAssembler {
case class Output(vector: Vector)

def apply[Inputs](implicit inputsChecker: TypedVectorAssemblerInputsChecker[Inputs]): TypedVectorAssembler[Inputs] = {
new TypedVectorAssembler(new VectorAssembler(), inputsChecker.inputCols.toArray)
}
}

@implicitNotFound(
msg = "Cannot prove that ${Inputs} is a valid input type. Input type must only contain fields of numeric or boolean types."
)
private[ml] trait TypedVectorAssemblerInputsChecker[Inputs] {
val inputCols: Seq[String]
}

private[ml] object TypedVectorAssemblerInputsChecker {
implicit def checkInputs[Inputs, InputsRec <: HList, InputsKeys <: HList, InputsVals <: HList](
implicit
inputsGen: LabelledGeneric.Aux[Inputs, InputsRec],
inputsKeys: Keys.Aux[InputsRec, InputsKeys],
inputsKeysTraverse: ToTraversable.Aux[InputsKeys, Seq, Symbol],
inputsValues: Values.Aux[InputsRec, InputsVals],
inputsTypeCheck: TypedVectorAssemblerInputsValueChecker[InputsVals]
): TypedVectorAssemblerInputsChecker[Inputs] = new TypedVectorAssemblerInputsChecker[Inputs] {
val inputCols: Seq[String] = inputsKeys.apply.to[Seq].map(_.name)
}
}

private[ml] trait TypedVectorAssemblerInputsValueChecker[InputsVals]

private[ml] object TypedVectorAssemblerInputsValueChecker {
implicit def hnilCheckInputsValue: TypedVectorAssemblerInputsValueChecker[HNil] =
new TypedVectorAssemblerInputsValueChecker[HNil] {}

implicit def hlistCheckInputsValueNumeric[H, T <: HList](
implicit ch: CatalystNumeric[H],
tt: TypedVectorAssemblerInputsValueChecker[T]
): TypedVectorAssemblerInputsValueChecker[H :: T] = new TypedVectorAssemblerInputsValueChecker[H :: T] {}

implicit def hlistCheckInputsValueBoolean[T <: HList](
[Review comment — Contributor]
codecov says this case is never hit in tests.

[Reply — @atamborrino (author), Nov 21, 2017]
Right, I forgot to check the property...

[Review comment — Contributor]
@atamborrino isn't there already a typeclass in shapeless for checking that all fields of a type are subtypes of another type? This is related to TypedVectorAssemblerInputsValueChecker.

implicit tt: TypedVectorAssemblerInputsValueChecker[T]
): TypedVectorAssemblerInputsValueChecker[Boolean :: T] = new TypedVectorAssemblerInputsValueChecker[Boolean :: T] {}
}
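The inductive structure of TypedVectorAssemblerInputsValueChecker — an instance for HNil, plus instances that peel off one allowed head type at a time — can be mimicked in self-contained plain Scala. In this sketch `HL`, `HCons`, and `AllAllowed` are invented stand-ins, and `scala.math.Numeric` stands in for frameless' CatalystNumeric:

```scala
sealed trait HL
final case class HCons[H, T <: HL](head: H, tail: T) extends HL
case object HNil extends HL

// An instance exists only when every element type is numeric or Boolean,
// so implicit resolution itself performs the "all fields valid" check.
trait AllAllowed[L <: HL]
object AllAllowed {
  implicit val nil: AllAllowed[HNil.type] = new AllAllowed[HNil.type] {}
  implicit def consNumeric[H, T <: HL](
    implicit num: Numeric[H], rest: AllAllowed[T]
  ): AllAllowed[HCons[H, T]] = new AllAllowed[HCons[H, T]] {}
  implicit def consBoolean[T <: HL](
    implicit rest: AllAllowed[T]
  ): AllAllowed[HCons[Boolean, T]] = new AllAllowed[HCons[Boolean, T]] {}
}

val ok = implicitly[AllAllowed[HCons[Int, HCons[Boolean, HNil.type]]]] // resolves
// implicitly[AllAllowed[HCons[String, HNil.type]]]  // would not compile
```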


29 changes: 29 additions & 0 deletions ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala
@@ -0,0 +1,29 @@
package frameless
package ml
package internals

import shapeless.labelled.FieldType
import shapeless.{::, DepFn1, HList, Witness}

/**
* Typeclass supporting record selection by value type (returning the first key whose value is of type `Value`)
*/
trait SelectorByValue[L <: HList, Value] extends DepFn1[L] with Serializable { type Out <: Symbol }
[Review comment — Contributor]
I think this could be up-streamed to shapeless.

[Review comment — Contributor]
Since you use it at the type level, you might want to remove the apply altogether for this usage.

[Reply — @atamborrino (author), Nov 21, 2017]
Yes, upstreaming this to shapeless was my intention :) (that's why I also added the apply pattern). But I prefer to include it in this PR to avoid having to wait for a new version of shapeless to be published. Once I've done the shapeless PR and it has been merged, I'll remove it from here via a new PR.


object SelectorByValue {
type Aux[L <: HList, Value, Out0 <: Symbol] = SelectorByValue[L, Value] { type Out = Out0 }

implicit def select[K <: Symbol, T <: HList, Value](implicit wk: Witness.Aux[K]): Aux[FieldType[K, Value] :: T, Value, K] = {
new SelectorByValue[FieldType[K, Value] :: T, Value] {
type Out = K
def apply(l: FieldType[K, Value] :: T): Out = wk.value
}
}

implicit def recurse[H, T <: HList, Value](implicit st: SelectorByValue[T, Value]): Aux[H :: T, Value, st.Out] = {
new SelectorByValue[H :: T, Value] {
type Out = st.Out
def apply(l: H :: T): Out = st(l.tail)
}
}
}
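SelectorByValue's recursion — try the head, otherwise recurse on the tail — is a standard shapeless derivation pattern. Here is a self-contained plain-Scala analogue (`HL`, `HCons`, and `FirstOfType` are invented for illustration; the real typeclass returns the first matching key as a Symbol, while this sketch returns the value) with the recursive case at lower implicit priority so the head match wins:

```scala
sealed trait HL
final case class HCons[H, T <: HL](head: H, tail: T) extends HL
case object HNil extends HL

trait FirstOfType[L <: HL, V] { def apply(l: L): V }

// Lower-priority recursive case: skip a non-matching head.
trait LowPriorityFirstOfType {
  implicit def recurse[H, T <: HL, V](
    implicit rest: FirstOfType[T, V]
  ): FirstOfType[HCons[H, T], V] =
    new FirstOfType[HCons[H, T], V] { def apply(l: HCons[H, T]): V = rest(l.tail) }
}
object FirstOfType extends LowPriorityFirstOfType {
  // The head's type matches the requested value type: select it.
  implicit def found[V, T <: HL]: FirstOfType[HCons[V, T], V] =
    new FirstOfType[HCons[V, T], V] { def apply(l: HCons[V, T]): V = l.head }
}

def firstDouble[L <: HL](l: L)(implicit f: FirstOfType[L, Double]): Double = f(l)

val record = HCons("a-label", HCons(42, HCons(3.14, HNil)))
println(firstDouble(record)) // 3.14
```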
43 changes: 43 additions & 0 deletions ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala
@@ -0,0 +1,43 @@
package frameless.ml.internals

import shapeless.ops.hlist.Length
import shapeless.{HList, LabelledGeneric, Nat, Witness}
import org.apache.spark.ml.linalg._

import scala.annotation.implicitNotFound

/**
 * Can be used for all tree-based ML algorithms (decision tree, random forest, gradient-boosted trees)
*/
@implicitNotFound(
  msg = "Cannot prove that ${Inputs} is a valid input type. " +
"Input type must only contain a field of type Double (label) and a field of type " +
"org.apache.spark.ml.linalg.Vector (features)."
)
trait TreesInputsChecker[Inputs] {
val featuresCol: String
val labelCol: String
}

object TreesInputsChecker {

implicit def checkTreesInputs[
Inputs,
InputsRec <: HList,
LabelK <: Symbol,
FeaturesK <: Symbol](
implicit
i0: LabelledGeneric.Aux[Inputs, InputsRec],
i1: Length.Aux[InputsRec, Nat._2],
i2: SelectorByValue.Aux[InputsRec, Double, LabelK],
i3: Witness.Aux[LabelK],
i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
i5: Witness.Aux[FeaturesK]
): TreesInputsChecker[Inputs] = {
new TreesInputsChecker[Inputs] {
val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name
val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name
}
}

}
31 changes: 31 additions & 0 deletions ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala
@@ -0,0 +1,31 @@
package frameless.ml.internals

import shapeless.ops.hlist.Length
import shapeless.{HList, LabelledGeneric, Nat, Witness}

import scala.annotation.implicitNotFound

/**
 * Can be used for all unary transformers (i.e. almost all of them)
*/
@implicitNotFound(
msg = "Cannot prove that ${Inputs} is a valid input type. Input type must have only one field of type ${Expected}"
)
trait UnaryInputsChecker[Inputs, Expected] {
val inputCol: String
}

object UnaryInputsChecker {

implicit def checkUnaryInputs[Inputs, Expected, InputsRec <: HList, InputK <: Symbol](
implicit
i0: LabelledGeneric.Aux[Inputs, InputsRec],
i1: Length.Aux[InputsRec, Nat._1],
i2: SelectorByValue.Aux[InputsRec, Expected, InputK],
i3: Witness.Aux[InputK]
): UnaryInputsChecker[Inputs, Expected] = new UnaryInputsChecker[Inputs, Expected] {
val inputCol: String = implicitly[Witness.Aux[InputK]].value.name
}

}

@@ -0,0 +1,15 @@
package frameless
package ml
package params
package trees

sealed abstract class FeatureSubsetStrategy private[ml](val sparkValue: String)
object FeatureSubsetStrategy {
case object Auto extends FeatureSubsetStrategy("auto")
case object All extends FeatureSubsetStrategy("all")
case object OneThird extends FeatureSubsetStrategy("onethird")
case object Sqrt extends FeatureSubsetStrategy("sqrt")
case object Log2 extends FeatureSubsetStrategy("log2")
case class Ratio(value: Double) extends FeatureSubsetStrategy(value.toString)
case class NumberOfFeatures(value: Int) extends FeatureSubsetStrategy(value.toString)
}