Skip to content

Commit

Permalink
fix: spark.executor.cores' default value based on master when counting workers (#855)

Browse files Browse the repository at this point in the history

* fix default values based on master
- https://spark.apache.org/docs/latest/configuration.html

* add unittest for clusterutils

* fix comment and add copyright

Co-authored-by: Ilya Matiach <ilmat@microsoft.com>
  • Loading branch information
Keunhyun Oh and imatiach-msft authored Jun 6, 2020
1 parent 4ae0fe8 commit 64481e9
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 20 deletions.
92 changes: 73 additions & 19 deletions src/main/scala/com/microsoft/ml/spark/core/utils/ClusterUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,91 @@ object ClusterUtil {
*/
/** Returns the number of cores per executor available for tasks.
  *
  * Reads spark.executor.cores and divides by spark.task.cpus; when
  * spark.executor.cores is not set, falls back to the master-dependent
  * default from getDefaultNumExecutorCores.
  *
  * @param dataset Dataset whose SparkSession supplies the configuration.
  * @param log Logger used to report which values were used.
  * @return The number of cores per executor divided by the task CPUs.
  */
def getNumCoresPerExecutor(dataset: Dataset[_], log: Logger): Int = {
  val spark = dataset.sparkSession
  val confTaskCpus = getTaskCpus(dataset, log)
  try {
    val confCores = spark.sparkContext.getConf.get("spark.executor.cores").toInt
    val coresPerExec = confCores / confTaskCpus
    log.info(s"ClusterUtils calculated num cores per executor as $coresPerExec from $confCores " +
      s"cores and $confTaskCpus task CPUs")
    coresPerExec
  } catch {
    case _: NoSuchElementException =>
      // If spark.executor.cores is not defined, get the default cores based on master
      val defaultNumCores = getDefaultNumExecutorCores(spark, log)
      val coresPerExec = defaultNumCores / confTaskCpus
      log.info(s"ClusterUtils calculated num cores per executor as $coresPerExec from " +
        s"default num cores($defaultNumCores) from master and $confTaskCpus task CPUs")
      coresPerExec
  }
}

/** Get number of default cores from sparkSession(required) or master(optional) for 1 executor.
  *
  * Defaults follow https://spark.apache.org/docs/latest/configuration.html:
  * standalone ("spark://") and Mesos coarse-grained modes use all available
  * machine cores, while YARN and Kubernetes default to 1 core per executor.
  *
  * @param spark The current spark session. If the master parameter is not set,
  *              the master in the spark session is used.
  * @param log Logger used to report which default was chosen.
  * @param master This param is needed for unittest. If set, the function returns
  *               the value for it; if not set, the master in the SparkSession is used.
  * @return The number of default cores per executor based on master.
  */
def getDefaultNumExecutorCores(spark: SparkSession, log: Logger, master: Option[String] = None): Int = {
  val masterOpt = master match {
    case Some(_) => master
    case None =>
      try {
        val masterConf = spark.sparkContext.getConf.getOption("spark.master")
        if (masterConf.isDefined) {
          log.info(s"ClusterUtils detected spark.master config (spark.master: ${masterConf.get})")
        } else {
          log.info("ClusterUtils did not detect spark.master config set")
        }
        masterConf
      } catch {
        case _: NoSuchElementException =>
          log.info("spark.master config not set")
          None
      }
  }

  // ref: https://spark.apache.org/docs/latest/configuration.html
  if (masterOpt.isEmpty) {
    val numMachineCores = getJVMCPUs(spark)
    // Note: trailing space added so the concatenated log message reads correctly
    log.info("ClusterUtils did not detect spark.master config set. " +
      s"So, the number of machine cores($numMachineCores) from JVM is used")
    numMachineCores
  } else if (masterOpt.get.startsWith("spark://") || masterOpt.get.startsWith("mesos://")) {
    // All the available cores on the worker in standalone and Mesos coarse-grained modes
    val numMachineCores = getJVMCPUs(spark)
    log.info(s"ClusterUtils detected the number of executor cores from $numMachineCores machine cores from JVM " +
      s"based on master address")
    numMachineCores
  } else if (masterOpt.get.startsWith("yarn") || masterOpt.get.startsWith("k8s://")) {
    // 1 in YARN mode (and Kubernetes)
    log.info(s"ClusterUtils detected 1 as the number of executor cores based on master address")
    1
  } else {
    // Unknown master (e.g. local[*]): fall back to the JVM-reported machine cores
    val numMachineCores = getJVMCPUs(spark)
    log.info("ClusterUtils did not detect a master that has a known default value. " +
      s"So, the number of machine cores($numMachineCores) from JVM is used")
    numMachineCores
  }
}

/** Reads spark.task.cpus from the session configuration, defaulting to 1 when unset.
  *
  * @param dataset Dataset whose SparkSession supplies the configuration.
  * @param log Logger used to report when the default is applied.
  * @return The configured spark.task.cpus value, or 1 if the config is absent.
  */
def getTaskCpus(dataset: Dataset[_], log: Logger): Int = {
  val session = dataset.sparkSession
  try {
    session.sparkContext.getConf.getOption("spark.task.cpus") match {
      case Some(taskCpus) => taskCpus.toInt
      case None =>
        log.info("ClusterUtils did not detect spark.task.cpus config set, using default 1 instead")
        1
    }
  } catch {
    case _: NoSuchElementException =>
      log.info("spark.task.cpus config not set, using default 1 instead")
      1
  }
}

def getDriverHost(dataset: Dataset[_]): String = {
val blockManager = BlockManagerUtils.getBlockManager(dataset)
blockManager.master.getMemoryStatus.toList.flatMap({ case (blockManagerId, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object SparkSessionFactory {
def currentDir(): String = System.getProperty("user.dir")

def getSession(name: String, logLevel: String = "WARN",
numRetries: Int, numCores: Option[Int] = None): SparkSession = {
numRetries: Int = 0, numCores: Option[Int] = None): SparkSession = {
val cores = numCores.map(_.toString).getOrElse("*")
val conf = new SparkConf()
.setAppName(name)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.utils

import com.microsoft.ml.spark.core.test.base.{SparkSessionFactory, TestBase}
import org.slf4j.LoggerFactory

/** Verifies ClusterUtil's master-based default executor-core calculation. */
class VerifyClusterUtil extends TestBase {
  test("Verify ClusterUtil can get default number of executor cores based on master") {
    val session = SparkSessionFactory.getSession("verifyClusterUtil-Session")
    val logger = LoggerFactory.getLogger("VerifyClusterUtil")

    // Expected defaults per https://spark.apache.org/docs/latest/configuration.html:
    // YARN defaults to 1 core; standalone mode uses all machine cores from the JVM.
    val jvmCores = ClusterUtil.getJVMCPUs(session)
    assert(ClusterUtil.getDefaultNumExecutorCores(session, logger, Some("yarn")) == 1)
    assert(ClusterUtil.getDefaultNumExecutorCores(session, logger, Some("spark://localhost:7077")) == jvmCores)
  }
}

0 comments on commit 64481e9

Please sign in to comment.