Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add featuresShapCol to LightGBMClassifierModel #863

Merged
merged 9 commits into from
May 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ protected class BoosterHandler(model: String) {
/** Per-thread native output buffer that receives SHAP (feature contribution) values.
  *
  * The buffer holds `numFeatures + 1` doubles: one slot per feature plus one extra slot.
  * NOTE(review): the extra slot is presumably the expected-value/bias term that LightGBM
  * appends to contribution output; confirm against the LightGBM prediction docs.
  */
val shapDataOutPtr: ThreadLocal[DoubleNativePtrHandler] = {
  new ThreadLocal[DoubleNativePtrHandler] {
    override def initialValue(): DoubleNativePtrHandler = {
      // Allocate exactly once: an additional, discarded `new_doubleArray(numFeatures)` call
      // here would leak native memory on every thread that touches this handler.
      new DoubleNativePtrHandler(lightgbmlib.new_doubleArray(numFeatures + 1))
    }
  }
}
Expand All @@ -89,7 +89,7 @@ protected class BoosterHandler(model: String) {
new ThreadLocal[LongLongNativePtrHandler] {
override def initialValue(): LongLongNativePtrHandler = {
val dataLongLengthPtr = lightgbmlib.new_int64_tp()
lightgbmlib.int64_tp_assign(dataLongLengthPtr, numFeatures)
lightgbmlib.int64_tp_assign(dataLongLengthPtr, (numFeatures + 1))
new LongLongNativePtrHandler(dataLongLengthPtr)
}
}
Expand Down Expand Up @@ -330,7 +330,7 @@ class LightGBMBooster(val model: String) extends Serializable {
}

/** Copies the SHAP values out of the native double array into a JVM array.
  *
  * Reads `numFeatures + 1` entries, starting at index 0, via `doubleArray_getitem`.
  *
  * @param shapDataOutPtr native double array holding the contribution values
  * @return a new array of length `numFeatures + 1`, in native order
  */
private def shapToArray(shapDataOutPtr: SWIGTYPE_p_double): Array[Double] = {
  // Array.tabulate fills the result directly, avoiding the intermediate IndexedSeq
  // produced by mapping over a Range and then calling toArray.
  Array.tabulate(numFeatures + 1) { featNum =>
    lightgbmlib.doubleArray_getitem(shapDataOutPtr, featNum)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object LightGBMClassifier extends DefaultParamsReadable[LightGBMClassifier]
class LightGBMClassifier(override val uid: String)
extends ProbabilisticClassifier[Vector, LightGBMClassifier, LightGBMClassificationModel]
with LightGBMBase[LightGBMClassificationModel]
with HasLeafPredictionCol {
with HasLeafPredictionCol with HasFeaturesShapCol {
def this() = this(Identifiable.randomUID("LightGBMClassifier"))

// Set default objective to be binary classification
Expand Down Expand Up @@ -57,7 +57,7 @@ class LightGBMClassifier(override val uid: String)
def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
val classifierTrainParams = trainParams.asInstanceOf[ClassifierTrainParams]
new LightGBMClassificationModel(uid, lightGBMBooster, getLabelCol, getFeaturesCol,
getPredictionCol, getProbabilityCol, getRawPredictionCol, getLeafPredictionCol,
getPredictionCol, getProbabilityCol, getRawPredictionCol, getLeafPredictionCol, getFeaturesShapCol,
if (isDefined(thresholds)) Some(getThresholds) else None, classifierTrainParams.numClass)
}

Expand Down Expand Up @@ -86,16 +86,27 @@ trait HasLeafPredictionCol extends Params {
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
}

/** Mix-in that adds the `featuresShapCol` parameter: the name of the output column
  * holding the SHAP (feature contribution) vector produced at prediction time.
  */
trait HasFeaturesShapCol extends Params {
  /** Name of the SHAP output column. The default is the empty string. */
  val featuresShapCol = new Param[String](this, "featuresShapCol",
    "Output SHAP vector column name after prediction containing the feature contribution values")
  setDefault(featuresShapCol, "")

  /** Returns the configured SHAP output column name, or the default when unset. */
  def getFeaturesShapCol: String = getOrDefault(featuresShapCol)

  /** Sets the SHAP output column name and returns this instance for chaining. */
  def setFeaturesShapCol(value: String): this.type = set(featuresShapCol, value)
}

/** Model produced by [[LightGBMClassifier]]. */
@InternalWrapper
class LightGBMClassificationModel(
override val uid: String, override val model: LightGBMBooster, labelColName: String,
featuresColName: String, predictionColName: String, probColName: String,
rawPredictionColName: String, leafIndicesColName: String, thresholdValues: Option[Array[Double]],
rawPredictionColName: String, leafPredictionColName: String, featuresShapColName: String,
thresholdValues: Option[Array[Double]],
actualNumClasses: Int)
extends ProbabilisticClassificationModel[Vector, LightGBMClassificationModel]
with HasFeatureImportanceGetters
with HasLeafPredictionCol
with HasFeaturesShapCol
with ConstructorWritable[LightGBMClassificationModel] {

// Update the underlying Spark ML com.microsoft.ml.spark.core.serialize.params
Expand All @@ -105,7 +116,8 @@ class LightGBMClassificationModel(
set(predictionCol, predictionColName)
set(probabilityCol, probColName)
set(rawPredictionCol, rawPredictionColName)
set(leafPredictionCol, leafIndicesColName)
set(leafPredictionCol, leafPredictionColName)
set(featuresShapCol, featuresShapColName)

if (thresholdValues.isDefined) set(thresholds, thresholdValues.get)

Expand Down Expand Up @@ -139,14 +151,7 @@ class LightGBMClassificationModel(
numColsOutput += 1
}
if (getPredictionCol.nonEmpty) {
val predUDF = if (getRawPredictionCol.nonEmpty && !isDefined(thresholds)) {
// Note: Only call raw2prediction if thresholds not defined
udf(raw2prediction _).apply(col(getRawPredictionCol))
} else if (getProbabilityCol.nonEmpty) {
udf(probability2prediction _).apply(col(getProbabilityCol))
} else {
udf(predict _).apply(col(getFeaturesCol))
}
val predUDF = predictColumn
outputData = outputData.withColumn(getPredictionCol, predUDF)
numColsOutput += 1
}
Expand All @@ -155,6 +160,11 @@ class LightGBMClassificationModel(
outputData = outputData.withColumn(getLeafPredictionCol, predLeafUDF(col(getFeaturesCol)))
numColsOutput += 1
}
if (getFeaturesShapCol.nonEmpty) {
val featureShapUDF = udf(featuresShap _)
outputData = outputData.withColumn(getFeaturesShapCol, featureShapUDF(col(getFeaturesCol)))
numColsOutput += 1
}

if (numColsOutput == 0) {
this.logWarning(s"$uid: LightGBMClassificationModel.transform() was called as NOOP" +
Expand All @@ -180,19 +190,35 @@ class LightGBMClassificationModel(

/** Creates a copy of this model with the same uid and booster.
  *
  * The copy is first built from the constructor arguments, then `copyValues` carries over
  * the parameter values currently set on this instance and applies the overrides in
  * `extra`, so neither is silently dropped.
  *
  * @param extra additional parameter values to apply to the copy
  * @return the new model instance
  */
override def copy(extra: ParamMap): LightGBMClassificationModel = {
  val duplicate = new LightGBMClassificationModel(uid, model, labelColName, featuresColName,
    predictionColName, probColName, rawPredictionColName, leafPredictionColName, featuresShapColName,
    thresholdValues, actualNumClasses)
  copyValues(duplicate, extra)
}

/** Runtime type tag for this model class.
  * NOTE(review): presumably consumed by the ConstructorWritable mix-in; confirm there.
  */
override val ttag: TypeTag[LightGBMClassificationModel] = typeTag[LightGBMClassificationModel]

/** Values persisted for this model, listed in the same order as the primary constructor's
  * parameters (uid, booster, column names, SHAP column, thresholds, class count).
  * NOTE(review): presumably ConstructorWritable/ConstructorReadable replays this list into
  * the constructor, so the order and arity must stay in sync with it; confirm there.
  */
override def objectsToSave: List[Any] =
  List(uid, model, getLabelCol, getFeaturesCol, getPredictionCol,
    getProbabilityCol, getRawPredictionCol, getLeafPredictionCol,
    getFeaturesShapCol, thresholdValues, actualNumClasses)

/** Builds the column expression that yields the predicted label.
  *
  * The cheapest available source is chosen, in this order: the raw prediction column
  * (only when no thresholds are defined), then the probability column, and finally the
  * features column itself.
  */
protected def predictColumn: Column = {
  // Raw scores may only drive the prediction when thresholds are not defined.
  val rawScoresUsable = getRawPredictionCol.nonEmpty && !isDefined(thresholds)
  if (rawScoresUsable) {
    val fromRaw = udf(raw2prediction _)
    fromRaw(col(getRawPredictionCol))
  } else if (getProbabilityCol.nonEmpty) {
    val fromProbability = udf(probability2prediction _)
    fromProbability(col(getProbabilityCol))
  } else {
    val fromFeatures = udf(predict _)
    fromFeatures(col(getFeaturesCol))
  }
}

/** Runs the booster's leaf prediction for one feature vector and wraps the result
  * in a dense vector.
  */
protected def predictLeaf(features: Vector): Vector = {
  val leafValues = model.predictLeaf(features)
  Vectors.dense(leafValues)
}

/** Computes the booster's SHAP (feature contribution) values for one feature vector
  * and wraps them in a dense vector.
  */
protected def featuresShap(features: Vector): Vector = {
  val contributions = model.featuresShap(features)
  Vectors.dense(contributions)
}

def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
val session = SparkSession.builder().getOrCreate()
model.saveNativeModel(session, filename, overwrite)
Expand All @@ -206,26 +232,30 @@ object LightGBMClassificationModel extends ConstructorReadable[LightGBMClassific
featuresColName: String = "features", predictionColName: String = "prediction",
probColName: String = "probability",
rawPredictionColName: String = "rawPrediction",
leafPredictionColName: String = "leafPrediction"): LightGBMClassificationModel = {
leafPredictionColName: String = "leafPrediction",
featuresShapColName: String = "featuresShap"): LightGBMClassificationModel = {
val uid = Identifiable.randomUID("LightGBMClassifier")
val session = SparkSession.builder().getOrCreate()
val textRdd = session.read.text(filename)
val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n")
val lightGBMBooster = new LightGBMBooster(text)
val actualNumClasses = lightGBMBooster.numClasses
new LightGBMClassificationModel(uid, lightGBMBooster, labelColName, featuresColName,
predictionColName, probColName, rawPredictionColName, leafPredictionColName, None, actualNumClasses)
predictionColName, probColName, rawPredictionColName, leafPredictionColName, featuresShapColName,
None, actualNumClasses)
}

/** Builds a classification model from a native LightGBM model held in a string.
  *
  * A fresh random uid is generated, and the number of classes is read from the booster.
  * No thresholds are set on the returned model.
  *
  * NOTE(review): `featuresShapColName` defaults to "featuresShap" here, whereas the
  * `featuresShapCol` param itself defaults to "", so models loaded this way emit the SHAP
  * column unless the caller passes "" explicitly; confirm this is intended.
  *
  * @param model                 the native model text
  * @param labelColName          label column name
  * @param featuresColName       features column name
  * @param predictionColName     prediction output column name
  * @param probColName           probability output column name
  * @param rawPredictionColName  raw prediction output column name
  * @param leafPredictionColName leaf prediction output column name
  * @param featuresShapColName   SHAP values output column name
  * @return the constructed model
  */
def loadNativeModelFromString(model: String, labelColName: String = "label",
                              featuresColName: String = "features", predictionColName: String = "prediction",
                              probColName: String = "probability",
                              rawPredictionColName: String = "rawPrediction",
                              leafPredictionColName: String = "leafPrediction",
                              featuresShapColName: String = "featuresShap"): LightGBMClassificationModel = {
  val uid = Identifiable.randomUID("LightGBMClassifier")
  val lightGBMBooster = new LightGBMBooster(model)
  val actualNumClasses = lightGBMBooster.numClasses
  new LightGBMClassificationModel(uid, lightGBMBooster, labelColName, featuresColName,
    predictionColName, probColName, rawPredictionColName, leafPredictionColName, featuresShapColName,
    None, actualNumClasses)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ trait LightGBMTestUtils extends TestBase {
/** Asserts that the SHAP output for `features` has one more entry than the feature vector
  * found in the first row of `df`.
  *
  * @param fitModel fitted model exposing `getFeatureShaps`
  * @param features feature vector to compute SHAP values for
  * @param df       data frame whose first row supplies the reference feature vector size
  */
def assertFeatureShapLengths(fitModel: Model[_] with HasFeatureShapGetters, features: Vector, df: DataFrame): Unit = {
  val shapLength = fitModel.getFeatureShaps(features).length
  val featuresLength = df.select(featuresCol).first().getAs[Vector](featuresCol).size
  // Only the "+ 1" form is checked; also asserting equality with featuresLength
  // would contradict it and always fail.
  assert(shapLength == featuresLength + 1)
}

lazy val numPartitions = 2
Expand All @@ -126,6 +126,7 @@ trait LightGBMTestUtils extends TestBase {
val labelCol = "labels"
val rawPredCol = "rawPrediction"
val leafPredCol = "leafPrediction"
val featuresShapCol = "featuresShap"
val initScoreCol = "initScore"
val predCol = "prediction"
val probCol = "probability"
Expand Down Expand Up @@ -219,6 +220,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
.setObjective(binaryObjective)
.setLabelCol(labelCol)
.setLeafPredictionCol(leafPredCol)
.setFeaturesShapCol(featuresShapCol)
}

test("Verify LightGBM Classifier can be run with TrainValidationSplit") {
Expand Down Expand Up @@ -271,7 +273,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
val convertUDF = udf((vector: DenseVector) => vector(1))
val scoredDF1 = baseModel.fit(pimaDF).transform(pimaDF)
val df2 = scoredDF1.withColumn(initScoreCol, convertUDF(col(rawPredCol)))
.drop(predCol, rawPredCol, probCol, leafPredCol)
.drop(predCol, rawPredCol, probCol, leafPredCol, featuresShapCol)
val scoredDF2 = baseModel.setInitScoreCol(initScoreCol).fit(df2).transform(df2)

assertBinaryImprovement(scoredDF1, scoredDF2)
Expand All @@ -280,7 +282,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
ignore("Verify LightGBM Multiclass Classifier with vector initial score") {
val scoredDF1 = baseModel.fit(breastTissueDF).transform(breastTissueDF)
val df2 = scoredDF1.withColumn(initScoreCol, col(rawPredCol))
.drop(predCol, rawPredCol, probCol, leafPredCol)
.drop(predCol, rawPredCol, probCol, leafPredCol, featuresShapCol)
val scoredDF2 = baseModel.setInitScoreCol(initScoreCol).fit(df2).transform(df2)

assertMulticlassImprovement(scoredDF1, scoredDF2)
Expand Down Expand Up @@ -416,6 +418,25 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
assert(!evaluatedDf2.columns.contains(leafPredCol))
}

test("Verify LightGBM Classifier features shap") {
  // Hold out 20% of the rows for scoring; the seed keeps the split reproducible.
  val Array(trainDF, testDF) = indexedBankTrainDF.randomSplit(Array(0.8, 0.2), seed)
  val categoricalSlots = indexedBankTrainDF.columns.filter(_.startsWith("c_"))
  val fitModel = baseModel
    .setCategoricalSlotNames(categoricalSlots)
    .fit(trainDF)

  // With the SHAP column configured, the first scored row must carry numFeatures + 1 values.
  val scoredDF = fitModel.transform(testDF)
  val firstShap: Array[Double] = scoredDF.select(featuresShapCol).rdd
    .map { case Row(v: Vector) => v }
    .first
    .toArray
  assert(firstShap.length == (fitModel.getModel.numFeatures + 1))

  // Clearing the column name switches the SHAP output off.
  val scoredWithoutShap = fitModel.setFeaturesShapCol("").transform(testDF)
  assert(!scoredWithoutShap.columns.contains(featuresShapCol))
}

test("Verify LightGBM Classifier with slot names parameter") {

val originalDf = readCSV(DatasetUtils.binaryTrainFile("PimaIndian.csv").toString).repartition(numPartitions)
Expand Down