Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: base the spark.executor.cores default value on the master when counting workers #855

Merged
merged 7 commits into from
Jun 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,91 @@ object ClusterUtil {
*/
def getNumCoresPerExecutor(dataset: Dataset[_], log: Logger): Int = {
  val spark = dataset.sparkSession
  // spark.task.cpus (default 1) divides the per-executor core count below.
  val confTaskCpus = getTaskCpus(dataset, log)
  try {
    // Prefer the explicitly configured executor core count when present.
    val confCores = spark.sparkContext.getConf.get("spark.executor.cores").toInt
    val coresPerExec = confCores / confTaskCpus
    log.info(s"ClusterUtils calculated num cores per executor as $coresPerExec from $confCores " +
      s"cores and $confTaskCpus task CPUs")
    coresPerExec
  } catch {
    case _: NoSuchElementException =>
      // If spark.executor.cores is not defined, get the cores based on master
      val defaultNumCores = getDefaultNumExecutorCores(spark, log)
      val coresPerExec = defaultNumCores / confTaskCpus
      log.info(s"ClusterUtils calculated num cores per executor as $coresPerExec from " +
        s"default num cores($defaultNumCores) from master and $confTaskCpus task CPUs")
      coresPerExec
  }
}

/** Get the default number of cores for one executor, from the spark session (required)
  * or master (optional).
  * @param spark The current spark session. If the master parameter is not set, the master
  *              in the spark session is used.
  * @param log Logger used to report which default was chosen and why.
  * @param master This param is needed for unit tests. If set, the function returns the
  *               default for that master address; if not set, the master from the spark
  *               session is used.
  * @return The number of default cores per executor based on master.
  */
def getDefaultNumExecutorCores(spark: SparkSession, log: Logger, master: Option[String] = None): Int = {
  // Resolve the master address: prefer the explicit parameter, otherwise read spark.master.
  val masterOpt = master match {
    case Some(_) => master
    case None =>
      try {
        val masterConf = spark.sparkContext.getConf.getOption("spark.master")
        if (masterConf.isDefined) {
          log.info(s"ClusterUtils detected spark.master config (spark.master: ${masterConf.get})")
        } else {
          log.info("ClusterUtils did not detect spark.master config set")
        }

        masterConf
      } catch {
        case _: NoSuchElementException => {
          log.info("spark.master config not set")
          None
        }
      }
  }

  // Default executor-core counts per master type.
  // ref: https://spark.apache.org/docs/latest/configuration.html
  if (masterOpt.isEmpty) {
    val numMachineCores = getJVMCPUs(spark)
    log.info("ClusterUtils did not detect spark.master config set. " +
      s"So, the number of machine cores($numMachineCores) from JVM is used")
    numMachineCores
  } else if (masterOpt.get.startsWith("spark://") || masterOpt.get.startsWith("mesos://")) {
    // all the available cores on the worker in standalone and Mesos coarse-grained modes
    val numMachineCores = getJVMCPUs(spark)
    log.info(s"ClusterUtils detected the number of executor cores from $numMachineCores machine cores from JVM " +
      s"based on master address")
    numMachineCores
  } else if (masterOpt.get.startsWith("yarn") || masterOpt.get.startsWith("k8s://")) {
    // 1 in YARN and Kubernetes modes
    log.info("ClusterUtils detected 1 as the number of executor cores based on master address")
    1
  } else {
    // Unknown master type: fall back to the JVM's visible core count.
    val numMachineCores = getJVMCPUs(spark)
    log.info("ClusterUtils did not detect master that has known default value. " +
      s"So, the number of machine cores($numMachineCores) from JVM is used")
    numMachineCores
  }
}

/** Reads spark.task.cpus from the dataset's spark session configuration.
  * Falls back to 1 (Spark's default) when the setting is absent.
  */
def getTaskCpus(dataset: Dataset[_], log: Logger): Int = {
  val sparkSession = dataset.sparkSession
  try {
    sparkSession.sparkContext.getConf.getOption("spark.task.cpus") match {
      case Some(configuredValue) => configuredValue.toInt
      case None =>
        log.info("ClusterUtils did not detect spark.task.cpus config set, using default 1 instead")
        1
    }
  } catch {
    case _: NoSuchElementException =>
      log.info("spark.task.cpus config not set, using default 1 instead")
      1
  }
}

def getDriverHost(dataset: Dataset[_]): String = {
val blockManager = BlockManagerUtils.getBlockManager(dataset)
blockManager.master.getMemoryStatus.toList.flatMap({ case (blockManagerId, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object SparkSessionFactory {
def currentDir(): String = System.getProperty("user.dir")

def getSession(name: String, logLevel: String = "WARN",
numRetries: Int, numCores: Option[Int] = None): SparkSession = {
numRetries: Int = 0, numCores: Option[Int] = None): SparkSession = {
val cores = numCores.map(_.toString).getOrElse("*")
val conf = new SparkConf()
.setAppName(name)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.utils

import com.microsoft.ml.spark.core.test.base.{SparkSessionFactory, TestBase}
import org.slf4j.LoggerFactory

class VerifyClusterUtil extends TestBase {
  test("Verify ClusterUtil can get default number of executor cores based on master") {
    val session = SparkSessionFactory.getSession("verifyClusterUtil-Session")
    val logger = LoggerFactory.getLogger("VerifyClusterUtil")

    // Expected defaults per https://spark.apache.org/docs/latest/configuration.html:
    // YARN defaults to 1 core per executor; standalone uses all cores visible to the JVM.
    val yarnCores = ClusterUtil.getDefaultNumExecutorCores(session, logger, Some("yarn"))
    assert(yarnCores == 1)

    val standaloneCores =
      ClusterUtil.getDefaultNumExecutorCores(session, logger, Some("spark://localhost:7077"))
    assert(standaloneCores == ClusterUtil.getJVMCPUs(session))
  }
}