Skip to content

Commit

Permalink
feat: add number of threads parameter (#1055)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Jun 3, 2021
1 parent 63ce4ef commit 2a716c1
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
* @return ExecutionParams object containing parameters related to LightGBM execution.
*/
protected def getExecutionParams(): ExecutionParams = {
ExecutionParams(getChunkSize, getMatrixType)
ExecutionParams(getChunkSize, getMatrixType, getNumThreads)
}

/**
Expand Down
18 changes: 11 additions & 7 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,22 @@ object LightGBMUtils {
numRowsForChunks
}

/** Assembles the space-separated `key=value` string of dataset-construction settings.
  *
  * The string always carries `max_bin`, `is_pre_partition=True`, `bin_construct_sample_cnt`
  * and `num_threads`, each followed by a single space. `categorical_feature` is appended
  * only when the training parameters list at least one categorical feature.
  *
  * @param trainParams Training parameters supplying the bin, sampling, thread and
  *                    categorical-feature settings.
  * @return The parameter string handed to the dataset builders in this object.
  */
def getDatasetParams(trainParams: TrainParams): String = {
  // Settings that are always present; mkString adds the separating and the trailing space.
  val alwaysPresent = Seq(
    s"max_bin=${trainParams.maxBin}",
    "is_pre_partition=True",
    s"bin_construct_sample_cnt=${trainParams.binSampleCount}",
    s"num_threads=${trainParams.executionParams.numThreads}"
  ).mkString("", " ", " ")
  // Optional suffix: comma-joined categorical features, or nothing at all.
  val categoricalSuffix =
    if (trainParams.categoricalFeatures.isEmpty) {
      ""
    } else {
      "categorical_feature=" + trainParams.categoricalFeatures.mkString(",")
    }
  alwaysPresent + categoricalSuffix
}

def generateDenseDataset(numRows: Int, numCols: Int, featuresArray: doubleChunkedArray,
referenceDataset: Option[LightGBMDataset],
featureNamesOpt: Option[Array[String]],
trainParams: TrainParams, chunkSize: Int): LightGBMDataset = {
val isRowMajor = 1
val datasetOutPtr = lightgbmlib.voidpp_handle()
val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
s"bin_construct_sample_cnt=${trainParams.binSampleCount} " +
(if (trainParams.categoricalFeatures.isEmpty) ""
else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
val datasetParams = getDatasetParams(trainParams)
val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
var data: Option[(SWIGTYPE_p_void, SWIGTYPE_p_double)] = None
val numRowsForChunks = getNumRowsForChunksArray(numRows, chunkSize)
Expand Down Expand Up @@ -270,9 +276,7 @@ object LightGBMUtils {
val numCols = sparseRows(0).size

val datasetOutPtr = lightgbmlib.voidpp_handle()
val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
(if (trainParams.categoricalFeatures.isEmpty) ""
else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
val datasetParams = getDatasetParams(trainParams)
// Generate the dataset for features
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark(
sparseRows.asInstanceOf[Array[Object]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ trait LightGBMExecutionParams extends Wrappable {

def getMatrixType: String = $(matrixType)
def setMatrixType(value: String): this.type = set(matrixType, value)

// Integer param controlling how many threads LightGBM is asked to use; read through getNumThreads.
val numThreads = new IntParam(this, "numThreads",
"Number of threads for LightGBM. For the best speed, set this to the number of real CPU cores.")
// Default value is 0.
// NOTE(review): LightGBM's parameter docs describe num_threads=0 as "use the OpenMP default" - confirm.
setDefault(numThreads -> 0)

/** Returns the current value of the numThreads param. */
def getNumThreads: Int = $(numThreads)
/** Sets the numThreads param; the `this.type` return allows setter calls to be chained. */
def setNumThreads(value: Int): this.type = set(numThreads, value)
}

/** Defines common parameters across all LightGBM learners related to learning score evolution.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ abstract class TrainParams extends Serializable {
s"max_delta_step=$maxDeltaStep min_data_in_leaf=$minDataInLeaf ${objectiveParams.toString()} " +
(if (categoricalFeatures.isEmpty) "" else s"categorical_feature=${categoricalFeatures.mkString(",")} ") +
(if (maxBinByFeature.isEmpty) "" else s"max_bin_by_feature=${maxBinByFeature.mkString(",")} ") +
(if (boostingType == "dart") s"${dartModeParams.toString()}" else "")
(if (boostingType == "dart") s"${dartModeParams.toString()} " else "") +
executionParams.toString()
}
}

Expand Down Expand Up @@ -151,8 +152,14 @@ case class DartModeParams(dropRate: Double, maxDrop: Int, skipDrop: Double,
* @param chunkSize Advanced parameter to specify the chunk size for copying Java data to native.
* @param matrixType Advanced parameter to specify whether the native lightgbm matrix
* constructed should be sparse or dense.
* @param numThreads The number of threads to run the native lightgbm training with on each worker.
*/
case class ExecutionParams(chunkSize: Int, matrixType: String) extends Serializable
case class ExecutionParams(chunkSize: Int, matrixType: String, numThreads: Int) extends Serializable {
  /** Renders only the thread setting, as `num_threads=<n>` followed by one trailing space.
    *
    * `chunkSize` and `matrixType` are deliberately absent from the rendered text.
    *
    * @return The `num_threads` fragment, ready to be concatenated with other parameter fragments.
    */
  override def toString(): String = "num_threads=" + numThreads + " "
}


/** Defines parameters related to the lightgbm objective function.
*
Expand Down

0 comments on commit 2a716c1

Please sign in to comment.