Skip to content

Commit 7f6db51

Browse files
committed
Use multiple threads for prefetching.
Signed-off-by: Pascal Spörri <psp@zurich.ibm.com>
1 parent 18191d1 commit 7f6db51

File tree

3 files changed

+36
-22
lines changed

3 files changed

+36
-22
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Changing these values might have an impact on performance.
4040

4141
- `spark.shuffle.s3.bufferSize`: Default buffer size when writing (default: `8388608`)
4242
- `spark.shuffle.s3.maxBufferSizeTask`: Maximum size of the buffered output streams per task (default: `134217728`)
43+
- `spark.shuffle.s3.prefetchConcurrencyTask`: The per-task concurrency when prefetching (default: `2`).
4344
- `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
4445
- `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
4546
- `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)

src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class S3ShuffleDispatcher extends Logging {
3333
// Optional
3434
val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
3535
val maxBufferSizeTask: Int = conf.getInt("spark.shuffle.s3.maxBufferSizeTask", defaultValue = 128 * 1024 * 1024)
36+
val prefetchConcurrencyTask: Int = conf.getInt("spark.shuffle.s3.prefetchConcurrencyTask", defaultValue = 2)
3637
val cachePartitionLengths: Boolean = conf.getBoolean("spark.shuffle.s3.cachePartitionLengths", defaultValue = true)
3738
val cacheChecksums: Boolean = conf.getBoolean("spark.shuffle.s3.cacheChecksums", defaultValue = true)
3839
val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
@@ -60,6 +61,7 @@ class S3ShuffleDispatcher extends Logging {
6061
// Optional
6162
logInfo(s"- spark.shuffle.s3.bufferSize=${bufferSize}")
6263
logInfo(s"- spark.shuffle.s3.maxBufferSizeTask=${maxBufferSizeTask}")
64+
logInfo(s"- spark.shuffle.s3.prefetchConcurrencyTask=${prefetchConcurrencyTask}")
6365
logInfo(s"- spark.shuffle.s3.cachePartitionLengths=${cachePartitionLengths}")
6466
logInfo(s"- spark.shuffle.s3.cacheChecksums=${cacheChecksums}")
6567
logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
@@ -112,7 +114,7 @@ class S3ShuffleDispatcher extends Logging {
112114
def openBlock(blockId: BlockId): FSDataInputStream = {
113115
val status = getFileStatusCached(blockId)
114116
val builder = fs.openFile(status.getPath).withFileStatus(status)
115-
val stream = builder.build().get()
117+
val stream = builder.build().get()
116118
if (canSetReadahead) {
117119
stream.setReadahead(0)
118120
}

src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala

+32-21
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,46 @@
66
package org.apache.spark.storage
77

88
import org.apache.spark.internal.Logging
9+
import org.apache.spark.shuffle.helper.S3ShuffleDispatcher
910

1011
import java.io.{BufferedInputStream, InputStream}
1112
import java.util
1213

1314
class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long) extends Iterator[(BlockId, InputStream)] with Logging {
15+
16+
private val concurrencyTask = S3ShuffleDispatcher.get.prefetchConcurrencyTask
17+
private val startTime = System.nanoTime()
18+
1419
@volatile private var memoryUsage: Long = 0
1520
@volatile private var hasItem: Boolean = iter.hasNext
1621
private var timeWaiting: Long = 0
1722
private var timePrefetching: Long = 0
18-
private var timeNext: Long = 0
1923
private var numStreams: Long = 0
2024
private var bytesRead: Long = 0
2125

22-
private var nextElement: (BlockId, S3ShuffleBlockStream) = null
26+
private var activeTasks: Long = 0
2327

2428
private val completed = new util.LinkedList[(InputStream, BlockId, Long)]()
2529

2630
private def prefetchThread(): Unit = {
27-
while (iter.hasNext || nextElement != null) {
28-
if (nextElement == null) {
29-
val now = System.nanoTime()
30-
nextElement = iter.next()
31-
timeNext = System.nanoTime() - now
31+
var nextElement: (BlockId, S3ShuffleBlockStream) = null
32+
while (true) {
33+
synchronized {
34+
if (!iter.hasNext && nextElement == null) {
35+
hasItem = false
36+
return
37+
}
38+
if (nextElement == null) {
39+
nextElement = iter.next()
40+
activeTasks += 1
41+
hasItem = iter.hasNext
42+
}
3243
}
33-
val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
3444

3545
var fetchNext = false
46+
val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
3647
synchronized {
37-
if (memoryUsage + math.min(bsize, maxBufferSize) > maxBufferSize) {
48+
if (memoryUsage + bsize > maxBufferSize) {
3849
try {
3950
wait()
4051
}
@@ -43,6 +54,7 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
4354
}
4455
} else {
4556
fetchNext = true
57+
memoryUsage += bsize
4658
}
4759
}
4860

@@ -59,50 +71,49 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
5971
timePrefetching += System.nanoTime() - now
6072
bytesRead += bsize
6173
synchronized {
62-
memoryUsage += bsize
6374
completed.push((stream, block, bsize))
64-
hasItem = iter.hasNext
65-
notify()
75+
activeTasks -= 1
76+
notifyAll()
6677
}
6778
}
6879
}
6980
}
7081

7182
private val self = this
72-
private val thread = new Thread {
83+
private val threads = Array.fill[Thread](concurrencyTask)(new Thread {
7384
override def run(): Unit = {
7485
self.prefetchThread()
7586
}
76-
}
77-
thread.start()
87+
})
88+
threads.foreach(_.start())
7889

7990
private def printStatistics(): Unit = synchronized {
91+
val totalRuntime = System.nanoTime() - startTime
8092
try {
93+
val tR = totalRuntime / 1000000
94+
val wPer = 100 * timeWaiting / totalRuntime
8195
val tW = timeWaiting / 1000000
8296
val tP = timePrefetching / 1000000
83-
val tN = timeNext / 1000000
8497
val bR = bytesRead
8598
val r = numStreams
8699
// Average time per prefetch
87100
val atP = tP / r
88101
// Average time waiting
89102
val atW = tW / r
90-
// Average time next
91-
val atN = tN / r
92103
// Average read bandwidth
93104
val bW = bR.toDouble / (tP.toDouble / 1000) / (1024 * 1024)
94105
// Block size
95106
val bs = bR / r
96107
logInfo(s"Statistics: ${bR} bytes, ${tW} ms waiting (${atW} avg), " +
97-
s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
98-
s"${tN} ms for next (${atN} avg)")
108+
s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " +
109+
s"Total: ${tR} ms - ${wPer}% waiting")
99110
} catch {
100111
case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.")
101112
}
102113
}
103114

104115
override def hasNext: Boolean = synchronized {
105-
val result = hasItem || (completed.size() > 0)
116+
val result = hasItem || activeTasks > 0 || (completed.size() > 0)
106117
if (!result) {
107118
printStatistics()
108119
}

0 commit comments

Comments
 (0)