@@ -30,12 +30,9 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter,
 import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
 import org.apache.spark.util.{CompletionIterator, ThreadUtils}
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, TaskContext}
 
-import java.io.{BufferedInputStream, InputStream}
-import java.util.zip.{CheckedInputStream, Checksum}
-import scala.concurrent.duration.Duration
-import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.ExecutionContext
 
 /**
  * This class was adapted from Apache Spark: BlockStoreShuffleReader.
@@ -55,7 +52,7 @@ class S3ShuffleReader[K, C](
 
   private val dispatcher = S3ShuffleDispatcher.get
   private val dep = handle.dependency
-  private val bufferInputSize = dispatcher.bufferInputSize
+  private val maxBufferSizeTask = dispatcher.maxBufferSizeTask
 
   private val fetchContinousBlocksInBatch: Boolean = {
     val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
@@ -77,17 +74,6 @@ class S3ShuffleReader[K, C](
     doBatchFetch
   }
 
-  // Source: Cassandra connector for Apache Spark (https://github.com/datastax/spark-cassandra-connector)
-  // com.datastax.spark.connector.datasource.JoinHelper
-  // License: Apache 2.0
-  // See here for an explanation: http://www.russellspitzer.com/2017/02/27/Concurrency-In-Spark/
-  def slidingPrefetchIterator[T](it: Iterator[Future[T]], batchSize: Int): Iterator[T] = {
-    val (firstElements, lastElement) = it.grouped(batchSize)
-      .sliding(2)
-      .span(_ => it.hasNext)
-    (firstElements.map(_.head) ++ lastElement.flatten).flatten.map(Await.result(_, Duration.Inf))
-  }
-
   override def read(): Iterator[Product2[K, C]] = {
     val serializerInstance = dep.serializer.newInstance()
     val blocks = computeShuffleBlocks(handle.shuffleId,
@@ -98,35 +84,24 @@
 
     val wrappedStreams = new S3ShuffleBlockIterator(blocks)
 
-    // Create a key/value iterator for each stream
-    val recordIterPromise = wrappedStreams.filterNot(_._2.maxBytes == 0).map { case (blockId, wrappedStream) =>
-      readMetrics.incRemoteBytesRead(wrappedStream.maxBytes) // increase byte count.
+    val filteredStream = wrappedStreams.filterNot(_._2.maxBytes == 0).map(f => {
+      readMetrics.incRemoteBytesRead(f._2.maxBytes) // increase byte count.
       readMetrics.incRemoteBlocksFetched(1)
-      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
-      // NextIterator. The NextIterator makes sure that close() is called on the
-      // underlying InputStream when all records have been read.
-      Future {
-        val bufferSize = scala.math.min(wrappedStream.maxBytes, bufferInputSize).toInt
-        val stream = new BufferedInputStream(wrappedStream, bufferSize)
-
-        // Fill the buffered input stream by reading and then resetting the stream.
-        stream.mark(bufferSize)
-        stream.read()
-        stream.reset()
-
+      f
+    })
+    val recordIter = new S3BufferedPrefetchIterator(filteredStream, maxBufferSizeTask)
+      .flatMap(s => {
+        val stream = s._2
+        val blockId = s._1
         val checkedStream = if (dispatcher.checksumEnabled) {
           new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
         } else {
           stream
         }
-
         serializerInstance
           .deserializeStream(serializerManager.wrapStream(blockId, checkedStream))
           .asKeyValueIterator
-      }(S3ShuffleReader.asyncExecutionContext)
-    }
-
-    val recordIter = slidingPrefetchIterator(recordIterPromise, dispatcher.prefetchBatchSize).flatten
+      })
 
     // Update the context task metrics for each record read.
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
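Note (not part of the diff): the change above replaces the hand-rolled Future/BufferedInputStream prefetching and slidingPrefetchIterator with an S3BufferedPrefetchIterator bounded by dispatcher.maxBufferSizeTask. That iterator's implementation is not shown in this commit, so the sketch below only illustrates the general pattern such a class can follow: a background thread drains each block stream fully into memory ahead of the consumer, while a byte budget caps how much data may sit in the prefetch queue at once. All names in the sketch (BoundedPrefetchIterator, permits, cost) are illustrative and do not come from the repository.

// Illustrative sketch only, not the actual S3BufferedPrefetchIterator from this project.
import java.io.{ByteArrayInputStream, InputStream}
import java.util.concurrent.{LinkedBlockingQueue, Semaphore}

class BoundedPrefetchIterator[K](
    upstream: Iterator[(K, InputStream)],
    maxBufferSize: Long
) extends Iterator[(K, InputStream)] {

  // One permit per byte that may sit in the prefetch queue at the same time.
  private val totalPermits = math.min(maxBufferSize, Int.MaxValue.toLong).toInt
  private val permits = new Semaphore(totalPermits)

  // A fully buffered block plus the number of permits it holds.
  private case class Buffered(key: K, bytes: Array[Byte], cost: Int)

  // None acts as the end-of-stream marker.
  private val queue = new LinkedBlockingQueue[Option[Buffered]]()

  private val task: Runnable = () => {
    try {
      upstream.foreach { case (key, stream) =>
        val bytes =
          try stream.readAllBytes() // Java 9+; read the whole block eagerly
          finally stream.close()
        // A block larger than the whole budget still has to fit, so cap its cost.
        val cost = math.min(bytes.length, totalPermits)
        permits.acquire(cost) // wait until enough budget is free
        queue.put(Some(Buffered(key, bytes, cost)))
      }
    } finally {
      queue.put(None)
    }
  }
  private val prefetcher = new Thread(task, "prefetch-sketch")
  prefetcher.setDaemon(true)
  prefetcher.start()

  private var nextItem: Option[Buffered] = None
  private var done = false

  override def hasNext: Boolean = {
    if (nextItem.isEmpty && !done) {
      queue.take() match {
        case some @ Some(_) => nextItem = some
        case None           => done = true
      }
    }
    nextItem.isDefined
  }

  override def next(): (K, InputStream) = {
    if (!hasNext) throw new NoSuchElementException("empty iterator")
    val item = nextItem.get
    nextItem = None
    permits.release(item.cost) // the block leaves the queue, so free its budget
    (item.key, new ByteArrayInputStream(item.bytes))
  }
}

A consumer would use it much like the new read() body does with S3BufferedPrefetchIterator: iterate the (blockId, stream) pairs and flatMap each one into a deserialization stream.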