Skip to content

Commit ad234de

Browse files
committed
Enable/disable caching of partition lengths with a configuration variable.
Signed-off-by: Pascal Spörri <psp@zurich.ibm.com>
1 parent 83a59b3 commit ad234de

File tree

6 files changed

+32
-16
lines changed

6 files changed

+32
-16
lines changed

README.md

+2
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@ These configuration values need to be passed to Spark to load and configure the
3838

3939
Changing these values might have an impact on performance.
4040

41+
- `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
42+
- `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
4143
- `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
4244
- `spark.shuffle.s3.folderPrefixes`: The number of prefixes to use when storing files on S3
4345
(default: `10`, minimum: `1`).

src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala

+4
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,8 @@ class S3ShuffleDispatcher extends Logging {
3333
private val isS3A = rootDir.startsWith("s3a://")
3434

3535
// Optional
36+
val cachePartitionLengths: Boolean = conf.getBoolean("spark.shuffle.s3.cachePartitionLengths", defaultValue = true)
37+
val cacheChecksums: Boolean = conf.getBoolean("spark.shuffle.s3.cacheChecksums", defaultValue = true)
3638
val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
3739
val folderPrefixes: Int = conf.getInt("spark.shuffle.s3.folderPrefixes", defaultValue = 10)
3840
val prefetchBatchSize: Int = conf.getInt("spark.shuffle.s3.prefetchBatchSize", defaultValue = 25)
@@ -57,6 +59,8 @@ class S3ShuffleDispatcher extends Logging {
5759
logInfo(s"- spark.shuffle.s3.rootDir=${rootDir} (app dir: ${appDir})")
5860

5961
// Optional
62+
logInfo(s"- spark.shuffle.s3.cachePartitionLengths=${cachePartitionLengths}")
63+
logInfo(s"- spark.shuffle.s3.cacheChecksums=${cacheChecksums}")
6064
logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
6165
logInfo(s"- spark.shuffle.s3.folderPrefixes=${folderPrefixes}")
6266
logInfo(s"- spark.shuffle.s3.prefetchBlockSize=${prefetchBatchSize}")

src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala

+21-12
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,15 @@ object S3ShuffleHelper extends Logging {
2727
*
2828
* @param shuffleIndex
2929
*/
30-
def purgeCachedShuffleIndices(shuffleIndex: Int): Unit = {
31-
val blockFilter = (block: ShuffleIndexBlockId) => block.shuffleId == shuffleIndex
32-
cachedArrayLengths.remove(blockFilter, None)
30+
def purgeCachedDataForShuffle(shuffleIndex: Int): Unit = {
31+
if (dispatcher.cachePartitionLengths) {
32+
val filter = (block: ShuffleIndexBlockId) => block.shuffleId == shuffleIndex
33+
cachedArrayLengths.remove(filter, None)
34+
}
35+
if (dispatcher.cacheChecksums) {
36+
val filter = (block: ShuffleChecksumBlockId) => block.shuffleId == shuffleIndex
37+
cachedChecksums.remove(filter, None)
38+
}
3339
}
3440

3541
/**
@@ -59,6 +65,7 @@ object S3ShuffleHelper extends Logging {
5965
def listShuffleIndices(shuffleId: Int): Array[ShuffleIndexBlockId] = {
6066
val shuffleIndexFilter: PathFilter = new PathFilter() {
6167
private val prefix = f"shuffle_${shuffleId}_"
68+
6269
override def accept(path: Path): Boolean = {
6370
val name = path.getName
6471
name.startsWith(prefix) && name.endsWith("_0.index")
@@ -85,8 +92,8 @@ object S3ShuffleHelper extends Logging {
8592
* @param mapId
8693
* @return
8794
*/
88-
def getPartitionLengthsCached(shuffleId: Int, mapId: Long): Array[Long] = {
89-
getPartitionLengthsCached(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
95+
def getPartitionLengths(shuffleId: Int, mapId: Long): Array[Long] = {
96+
getPartitionLengths(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
9097
}
9198

9299
/**
@@ -95,19 +102,21 @@ object S3ShuffleHelper extends Logging {
95102
* @param blockId
96103
* @return
97104
*/
98-
def getPartitionLengthsCached(blockId: ShuffleIndexBlockId): Array[Long] = {
99-
cachedArrayLengths.getOrElsePut(blockId, readBlockAsArray)
100-
}
101-
102-
def getChecksumsCached(shuffleId: Int, mapId: Long): Array[Long] = {
103-
cachedChecksums.getOrElsePut(ShuffleChecksumBlockId(shuffleId, mapId, 0), readBlockAsArray)
105+
def getPartitionLengths(blockId: ShuffleIndexBlockId): Array[Long] = {
106+
if (dispatcher.cachePartitionLengths) {
107+
return cachedArrayLengths.getOrElsePut(blockId, readBlockAsArray)
108+
}
109+
readBlockAsArray(blockId)
104110
}
105111

106112
def getChecksums(shuffleId: Int, mapId: Long): Array[Long] = {
107-
getChecksums(ShuffleChecksumBlockId(shuffleId = shuffleId, mapId = mapId, reduceId = 0))
113+
getChecksums(ShuffleChecksumBlockId(shuffleId, mapId, 0))
108114
}
109115

110116
def getChecksums(blockId: ShuffleChecksumBlockId): Array[Long] = {
117+
if (dispatcher.cacheChecksums) {
118+
return cachedChecksums.getOrElsePut(blockId, readBlockAsArray)
119+
}
111120
readBlockAsArray(blockId)
112121
}
113122

src/main/scala/org/apache/spark/shuffle/sort/S3ShuffleManager.scala

+1-1
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,7 @@ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager wi
133133
// Remove and close all input streams.
134134
dispatcher.closeCachedBlocks(shuffleId)
135135
// Remove metadata.
136-
S3ShuffleHelper.purgeCachedShuffleIndices(shuffleId)
136+
S3ShuffleHelper.purgeCachedDataForShuffle(shuffleId)
137137
}
138138

139139
/** Remove a shuffle's metadata from the ShuffleManager. */

src/main/scala/org/apache/spark/storage/S3ChecksumValidationStream.scala

+3-2
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,13 @@ class S3ChecksumValidationStream(
2727
}
2828

2929
private val checksum: Checksum = S3ShuffleHelper.createChecksumAlgorithm(checksumAlgorithm)
30-
private val lengths: Array[Long] = S3ShuffleHelper.getPartitionLengthsCached(shuffleId, mapId)
31-
private val referenceChecksums: Array[Long] = S3ShuffleHelper.getChecksumsCached(shuffleId, mapId)
30+
private val lengths: Array[Long] = S3ShuffleHelper.getPartitionLengths(shuffleId, mapId)
31+
private val referenceChecksums: Array[Long] = S3ShuffleHelper.getChecksums(shuffleId, mapId)
3232

3333
private var pos: Long = 0
3434
private var reduceId: Int = startReduceId
3535
private var blockLength: Long = lengths(reduceId)
36+
3637
private def eof(): Boolean = reduceId > endReduceId
3738

3839
validateChecksum()

src/main/scala/org/apache/spark/storage/S3ShuffleBlockIterator.scala

+1-1
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ class S3ShuffleBlockIterator(
5757
}
5858

5959
def getAccumulatedLengths(shuffleId: Int, mapId: Long): Array[Long] = {
60-
val lengths = S3ShuffleHelper.getPartitionLengthsCached(shuffleId, mapId)
60+
val lengths = S3ShuffleHelper.getPartitionLengths(shuffleId, mapId)
6161
Array[Long](0) ++ lengths.tail.scan(lengths.head)(_ + _)
6262
}
6363
}

0 commit comments

Comments (0)