diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala
index d4f1265db6c..3eb793dd294 100644
--- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala
+++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala
@@ -4,6 +4,7 @@
 package com.microsoft.ml.spark.lightgbm
 
 import com.microsoft.ml.spark.core.utils.ClusterUtil
+import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.ml.param.shared.{HasFeaturesCol => HasFeaturesColSpark, HasLabelCol => HasLabelColSpark}
 import org.apache.spark.ml.{Estimator, Model}
@@ -14,6 +15,7 @@ import scala.concurrent.Await
 import scala.concurrent.duration.{Duration, SECONDS}
 import scala.language.existentials
 import scala.math.min
+import scala.util.matching.Regex
 
 trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[TrainedModel]
   with LightGBMParams with HasFeaturesColSpark with HasLabelColSpark {
@@ -156,6 +158,34 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
       categoricalSlotIndexesArr, categoricalSlotNamesArr)
   }
 
+  /**
+    * Fails fast when a slot name of the features column contains one of the
+    * characters matched by the validation pattern: " , : [ ] { }
+    * Validation is skipped when the features column has no ML attribute metadata.
+    *
+    * @param df The dataset whose schema holds the features column.
+    * @param columnParams The column parameters; supplies the features column name.
+    * @param trainParams The training parameters, forwarded to TrainUtils.getSlotNames.
+    * @throws IllegalArgumentException if any slot name contains one of those characters.
+    */
+  private def validateSlotNames(df: DataFrame, columnParams: ColumnParams, trainParams: TrainParams): Unit = {
+    val schema = df.schema
+    val featuresSchema = schema(columnParams.featuresColumn)
+    val metadata = AttributeGroup.fromStructField(featuresSchema)
+    // attributes is None when the features column has no ML attribute metadata; mapping over
+    // the Option (rather than calling .get) avoids a NoSuchElementException in that case
+    val slotNamesOpt = metadata.attributes.flatMap(attributes =>
+      TrainUtils.getSlotNames(schema, columnParams.featuresColumn, attributes.length, trainParams))
+    val pattern = new Regex("[\",:\\[\\]{}]")
+    slotNamesOpt.foreach { slotNames =>
+      val badSlotNames = slotNames.filter(slotName => pattern.findFirstIn(slotName).isDefined)
+      if (badSlotNames.nonEmpty) {
+        val errorMsg = s"Invalid slot names detected in features column: ${badSlotNames.mkString(",")}"
+        throw new IllegalArgumentException(errorMsg)
+      }
+    }
+  }
+
   /**
     * Inner train method for LightGBM learners. Calculates the number of workers,
     * creates a driver thread, and runs mapPartitions on the dataset.
@@ -199,6 +229,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     val preprocessedDF = preprocessData(trainingData)
     val schema = preprocessedDF.schema
     val columnParams = ColumnParams(getLabelCol, getFeaturesCol, get(weightCol), get(initScoreCol), getOptGroupCol)
+    validateSlotNames(preprocessedDF, columnParams, trainParams)
     val mapPartitionsFunc = TrainUtils.trainLightGBM(batchIndex, networkParams, columnParams, validationData,
       log, trainParams, numTasksPerExec, schema)(_)
     val lightGBMBooster =
diff --git a/src/test/scala/com/microsoft/ml/spark/lightgbm/split2/VerifyLightGBMRegressor.scala b/src/test/scala/com/microsoft/ml/spark/lightgbm/split2/VerifyLightGBMRegressor.scala
index da7923a5da7..962ca971fca 100644
--- a/src/test/scala/com/microsoft/ml/spark/lightgbm/split2/VerifyLightGBMRegressor.scala
+++ b/src/test/scala/com/microsoft/ml/spark/lightgbm/split2/VerifyLightGBMRegressor.scala
@@ -129,6 +129,12 @@ class VerifyLightGBMRegressor extends Benchmarks
     assert(metric < 0.6)
   }
 
+  test("Verify LightGBM Regressor with bad column names fails early") {
+    val baseModelWithBadSlots = baseModel.setSlotNames(Range(0, 22).map(i =>
+      "Invalid characters \",:[]{} " + i).toArray)
+    interceptWithoutLogging[IllegalArgumentException]{baseModelWithBadSlots.fit(flareDF).transform(flareDF).collect()}
+  }
+
   test("Verify LightGBM Regressor with tweedie distribution") {
     val model = baseModel.setObjective("tweedie").setTweedieVariancePower(1.5)
 