Skip to content

Commit ce5a234

Browse files
committed
added parallel validation
1 parent 8eddaba commit ce5a234

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package com.madhukaraphatak.examples.sparktwo.ml
2+
3+
import org.apache.spark.ml.Pipeline
4+
import org.apache.spark.ml.classification.LogisticRegression
5+
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
6+
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer, VectorAssembler}
7+
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
8+
import org.apache.spark.sql.SparkSession
9+
10+
object ParallelCrossValidation {
11+
12+
def main(args: Array[String]): Unit = {
13+
14+
15+
val sparkSession = SparkSession.builder.
16+
master("local[*]")
17+
.appName("example")
18+
.getOrCreate()
19+
20+
21+
val salaryDf = sparkSession.read.format("csv")
22+
.option("header", "true")
23+
.option("inferSchema", "true")
24+
.load("src/main/resources/adult.csv")
25+
26+
val stringColumns = Array("workclass", "occupation", "sex", "education", "martial_status", "relationship",
27+
"race", "native_country")
28+
29+
val numericalColumns = Array("age", "fnlwgt", "capital_loss", "capital_gain")
30+
31+
val labelColumn = "salary"
32+
val outputColumns = stringColumns.map(_ + "_onehot")
33+
34+
val indexers = stringColumns.map(column => {
35+
val indexer = new StringIndexer()
36+
indexer.setInputCol(column)
37+
indexer.setHandleInvalid("keep")
38+
indexer.setOutputCol(column + "_index")
39+
})
40+
41+
val singleOneHotEncoder = new OneHotEncoderEstimator()
42+
singleOneHotEncoder.setInputCols(stringColumns.map(_ + "_index"))
43+
singleOneHotEncoder.setOutputCols(outputColumns)
44+
45+
val vectorAssembler = new VectorAssembler()
46+
vectorAssembler.setInputCols(outputColumns ++ numericalColumns)
47+
vectorAssembler.setOutputCol("features")
48+
49+
val labelIndexer = new StringIndexer()
50+
labelIndexer.setInputCol("salary")
51+
labelIndexer.setOutputCol("label")
52+
53+
val logisticRegression = new LogisticRegression()
54+
55+
56+
val pipeline = new Pipeline()
57+
pipeline.setStages(indexers ++ Array(singleOneHotEncoder)
58+
++ Array(vectorAssembler) ++ Array(labelIndexer) ++ Array(logisticRegression))
59+
60+
val paramMap = new ParamGridBuilder()
61+
.addGrid(logisticRegression.maxIter, Array(1, 2, 3)).build()
62+
63+
64+
val crossValidator = new CrossValidator()
65+
crossValidator.setEstimator(pipeline)
66+
crossValidator.setEvaluator(new BinaryClassificationEvaluator())
67+
crossValidator.setEstimatorParamMaps(paramMap)
68+
crossValidator.setParallelism(3)
69+
70+
crossValidator.fit(salaryDf)
71+
72+
}
73+
74+
}

0 commit comments

Comments
 (0)