
Commit f43a02d

Buffer streams up to a threshold.
Signed-off-by: Pascal Spörri <psp@zurich.ibm.com>
1 parent ee7e983 commit f43a02d

4 files changed: +147 −44 lines changed


README.md

+2 −4
```diff
@@ -35,10 +35,8 @@ These configuration values need to be passed to Spark to load and configure the
 
 Changing these values might have an impact on performance.
 
-- `spark.shuffle.s3.bufferSize`: Default size of the buffered output streams (default: `32768`,
-  uses `spark.shuffle.file.buffer` as default)
-- `spark.shuffle.s3.bufferInputSize`: Maximum size of buffered input streams (default: `209715200`,
-  uses `spark.network.maxRemoteBlockSizeFetchToMem` as default)
+- `spark.shuffle.s3.bufferSize`: Default buffer size when writing (default: `8388608`)
+- `spark.shuffle.s3.maxBufferSizeTask`: Maximum size of the buffered output streams per task (default: `134217728`)
 - `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
 - `spark.shuffle.s3.folderPrefixes`: The number of prefixes to use when storing files on S3
   (default: `10`, minimum: `1`).
```
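For illustration, a minimal sketch of how the new options might be set when building a `SparkConf` (the values shown are just the defaults above, expressed in bytes; the shuffle-manager wiring itself is assumed from the rest of the README):

```scala
import org.apache.spark.SparkConf

// Hypothetical configuration sketch: 8 MiB write buffer and a
// 128 MiB per-task cap on buffered input streams (the new defaults).
val conf = new SparkConf()
  .set("spark.shuffle.s3.bufferSize", (8 * 1024 * 1024).toString)          // 8388608
  .set("spark.shuffle.s3.maxBufferSizeTask", (128 * 1024 * 1024).toString) // 134217728
```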

src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala

+3 −3
```diff
@@ -34,8 +34,8 @@ class S3ShuffleDispatcher extends Logging {
   private val isS3A = rootDir.startsWith("s3a://")
 
   // Optional
-  val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = conf.get(SHUFFLE_FILE_BUFFER_SIZE).toInt * 1024)
-  val bufferInputSize: Int = conf.getInt("spark.shuffle.s3.bufferInputSize", defaultValue = conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM).toInt)
+  val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
+  val maxBufferSizeTask: Int = conf.getInt("spark.shuffle.s3.maxBufferSizeTask", defaultValue = 128 * 1024 * 1024)
   val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
   val folderPrefixes: Int = conf.getInt("spark.shuffle.s3.folderPrefixes", defaultValue = 10)
   val prefetchBatchSize: Int = conf.getInt("spark.shuffle.s3.prefetchBatchSize", defaultValue = 25)
@@ -60,7 +60,7 @@ class S3ShuffleDispatcher extends Logging {
 
   // Optional
   logInfo(s"- spark.shuffle.s3.bufferSize=${bufferSize}")
-  logInfo(s"- spark.shuffle.s3.bufferInputSize=${bufferInputSize}")
+  logInfo(s"- spark.shuffle.s3.maxBufferSizeTask=${maxBufferSizeTask}")
   logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
   logInfo(s"- spark.shuffle.s3.folderPrefixes=${folderPrefixes}")
   logInfo(s"- spark.shuffle.s3.prefetchBlockSize=${prefetchBatchSize}")
```
src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala

+130 −0 (new file)

```scala
/**
 * Copyright 2023- IBM Inc. All rights reserved
 * SPDX-License-Identifier: Apache2.0
 */

package org.apache.spark.storage

import org.apache.spark.internal.Logging

import java.io.{BufferedInputStream, InputStream}
import java.util

class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long) extends Iterator[(BlockId, InputStream)] with Logging {
  @volatile private var memoryUsage: Long = 0
  @volatile private var hasItem: Boolean = iter.hasNext
  private var timeWaiting: Long = 0
  private var timePrefetching: Long = 0
  private var timeNext: Long = 0
  private var numStreams: Long = 0
  private var bytesRead: Long = 0

  private var nextElement: (BlockId, S3ShuffleBlockStream) = null

  private val completed = new util.LinkedList[(InputStream, BlockId, Long)]()

  private def prefetchThread(): Unit = {
    while (iter.hasNext || nextElement != null) {
      if (nextElement == null) {
        val now = System.nanoTime()
        nextElement = iter.next()
        timeNext = System.nanoTime() - now
      }
      val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt

      var fetchNext = false
      synchronized {
        if (memoryUsage + math.min(bsize, maxBufferSize) > maxBufferSize) {
          try {
            wait()
          } catch {
            case _: InterruptedException =>
              Thread.currentThread.interrupt()
          }
        } else {
          fetchNext = true
        }
      }

      if (fetchNext) {
        val block = nextElement._1
        val s = nextElement._2
        nextElement = null
        val now = System.nanoTime()
        val stream = new BufferedInputStream(s, bsize)
        // Fill the buffered input stream by reading and then resetting the stream.
        stream.mark(bsize)
        stream.read()
        stream.reset()
        timePrefetching += System.nanoTime() - now
        bytesRead += bsize
        synchronized {
          memoryUsage += bsize
          completed.push((stream, block, bsize))
          hasItem = iter.hasNext
          notifyAll()
        }
      }
    }
  }

  private val self = this
  private val thread = new Thread {
    override def run(): Unit = {
      self.prefetchThread()
    }
  }
  thread.start()

  private def printStatistics(): Unit = synchronized {
    try {
      val tW = timeWaiting / 1000000
      val tP = timePrefetching / 1000000
      val tN = timeNext / 1000000
      val bR = bytesRead
      val r = numStreams
      // Average time per prefetch
      val atP = tP / r
      // Average time waiting
      val atW = tW / r
      // Average time next
      val atN = tN / r
      // Average read bandwidth
      val bW = bR.toDouble / (tP.toDouble / 1000) / (1024 * 1024)
      // Block size
      val bs = bR / r
      logInfo(s"Statistics: ${bR} bytes, ${tW} ms waiting (${atW} avg), " +
              s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
              s"${tN} ms for next (${atN} avg)")
    } catch {
      case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.")
    }
  }

  override def hasNext: Boolean = synchronized {
    val result = hasItem || (completed.size() > 0)
    if (!result) {
      printStatistics()
    }
    result
  }

  override def next(): (BlockId, InputStream) = synchronized {
    val now = System.nanoTime()
    while (completed.isEmpty) {
      try {
        wait()
      } catch {
        case _: InterruptedException =>
          Thread.currentThread.interrupt()
      }
    }
    timeWaiting += System.nanoTime() - now
    numStreams += 1
    val result = completed.pop()
    memoryUsage -= result._3
    notifyAll()
    return (result._2, result._1)
  }
}
```
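The iterator above is a bounded producer/consumer: a background thread wraps each block stream in a `BufferedInputStream` and force-fills it (mark/read/reset), accumulating `memoryUsage` until the per-task cap would be exceeded, at which point it `wait()`s; every `next()` releases one buffer's worth of memory and `notifyAll()`s the producer. Below is a self-contained sketch of the same back-pressure pattern; the names (`BoundedPrefetcher`, `cost`) are illustrative and not part of the plugin:

```scala
import java.util

// Illustrative sketch: a producer thread prefetches items from `source`
// until their combined cost exceeds `maxBytes`, then blocks until the
// consumer frees space, mirroring the wait/notifyAll back-pressure above.
class BoundedPrefetcher[A](source: Iterator[A], cost: A => Long, maxBytes: Long)
  extends Iterator[A] {

  private val queue = new util.LinkedList[A]()
  private var usage: Long = 0
  @volatile private var pending: Boolean = source.hasNext // items not yet queued?

  private def produce(): Unit = {
    while (source.hasNext) {
      val item = source.next() // potentially slow I/O, done outside the lock
      synchronized {
        // Back-pressure: wait until the item fits. An oversize item passes
        // whenever the buffer is empty, so it cannot block forever.
        while (usage + cost(item) > maxBytes && usage > 0) wait()
        usage += cost(item)
        queue.add(item)
        pending = source.hasNext // only flips to false after the last add
        notifyAll()
      }
    }
  }

  private val producer = new Thread { override def run(): Unit = produce() }
  producer.setDaemon(true)
  producer.start()

  override def hasNext: Boolean = synchronized { pending || !queue.isEmpty }

  override def next(): A = synchronized {
    while (queue.isEmpty) wait() // safe: hasNext guaranteed another item is coming
    val item = queue.poll()
    usage -= cost(item)
    notifyAll() // let the producer re-check the cap
    item
  }
}

// Example usage: 32 chunks of 256 KiB flow through a 1 MiB window,
// so at most four chunks are buffered at any time.
// val chunks = Iterator.fill(32)(new Array[Byte](256 * 1024))
// val it = new BoundedPrefetcher[Array[Byte]](chunks, _.length.toLong, 1024 * 1024)
// assert(it.map(_.length).sum == 8 * 1024 * 1024)
```

Unlike the plugin class, the sketch hands items out in FIFO order and takes an explicit cost function instead of sizing `BufferedInputStream`s; the synchronization structure is otherwise the same.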

src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala

+12 −37
```diff
@@ -30,12 +30,9 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter,
 import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
 import org.apache.spark.util.{CompletionIterator, ThreadUtils}
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, TaskContext}
 
-import java.io.{BufferedInputStream, InputStream}
-import java.util.zip.{CheckedInputStream, Checksum}
-import scala.concurrent.duration.Duration
-import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.ExecutionContext
 
 /**
  * This class was adapted from Apache Spark: BlockStoreShuffleReader.
@@ -55,7 +52,7 @@ class S3ShuffleReader[K, C](
 
   private val dispatcher = S3ShuffleDispatcher.get
   private val dep = handle.dependency
-  private val bufferInputSize = dispatcher.bufferInputSize
+  private val maxBufferSizeTask = dispatcher.maxBufferSizeTask
 
   private val fetchContinousBlocksInBatch: Boolean = {
     val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
@@ -77,17 +74,6 @@
     doBatchFetch
   }
 
-  // Source: Cassandra connector for Apache Spark (https://github.com/datastax/spark-cassandra-connector)
-  // com.datastax.spark.connector.datasource.JoinHelper
-  // License: Apache 2.0
-  // See here for an explanation: http://www.russellspitzer.com/2017/02/27/Concurrency-In-Spark/
-  def slidingPrefetchIterator[T](it: Iterator[Future[T]], batchSize: Int): Iterator[T] = {
-    val (firstElements, lastElement) = it.grouped(batchSize)
-      .sliding(2)
-      .span(_ => it.hasNext)
-    (firstElements.map(_.head) ++ lastElement.flatten).flatten.map(Await.result(_, Duration.Inf))
-  }
-
   override def read(): Iterator[Product2[K, C]] = {
     val serializerInstance = dep.serializer.newInstance()
     val blocks = computeShuffleBlocks(handle.shuffleId,
@@ -98,35 +84,24 @@
 
     val wrappedStreams = new S3ShuffleBlockIterator(blocks)
 
-    // Create a key/value iterator for each stream
-    val recordIterPromise = wrappedStreams.filterNot(_._2.maxBytes == 0).map { case (blockId, wrappedStream) =>
-      readMetrics.incRemoteBytesRead(wrappedStream.maxBytes) // increase byte count.
+    val filteredStream = wrappedStreams.filterNot(_._2.maxBytes == 0).map(f => {
+      readMetrics.incRemoteBytesRead(f._2.maxBytes) // increase byte count.
       readMetrics.incRemoteBlocksFetched(1)
-      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
-      // NextIterator. The NextIterator makes sure that close() is called on the
-      // underlying InputStream when all records have been read.
-      Future {
-        val bufferSize = scala.math.min(wrappedStream.maxBytes, bufferInputSize).toInt
-        val stream = new BufferedInputStream(wrappedStream, bufferSize)
-
-        // Fill the buffered input stream by reading and then resetting the stream.
-        stream.mark(bufferSize)
-        stream.read()
-        stream.reset()
-
+      f
+    })
+    val recordIter = new S3BufferedPrefetchIterator(filteredStream, maxBufferSizeTask)
+      .flatMap(s => {
+        val stream = s._2
+        val blockId = s._1
         val checkedStream = if (dispatcher.checksumEnabled) {
           new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
         } else {
           stream
         }
-
         serializerInstance
          .deserializeStream(serializerManager.wrapStream(blockId, checkedStream))
          .asKeyValueIterator
-      }(S3ShuffleReader.asyncExecutionContext)
-    }
-
-    val recordIter = slidingPrefetchIterator(recordIterPromise, dispatcher.prefetchBatchSize).flatten
+      })
 
     // Update the context task metrics for each record read.
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
```
