6
6
package org .apache .spark .storage
7
7
8
8
import org .apache .spark .internal .Logging
9
+ import org .apache .spark .shuffle .helper .S3ShuffleDispatcher
9
10
10
11
import java .io .{BufferedInputStream , InputStream }
11
12
import java .util
12
13
13
14
class S3BufferedPrefetchIterator (iter : Iterator [(BlockId , S3ShuffleBlockStream )], maxBufferSize : Long ) extends Iterator [(BlockId , InputStream )] with Logging {
15
+
16
+ private val concurrencyTask = S3ShuffleDispatcher .get.prefetchConcurrencyTask
17
+ private val startTime = System .nanoTime()
18
+
14
19
@ volatile private var memoryUsage : Long = 0
15
20
@ volatile private var hasItem : Boolean = iter.hasNext
16
21
private var timeWaiting : Long = 0
17
22
private var timePrefetching : Long = 0
18
- private var timeNext : Long = 0
19
23
private var numStreams : Long = 0
20
24
private var bytesRead : Long = 0
21
25
22
- private var nextElement : ( BlockId , S3ShuffleBlockStream ) = null
26
+ private var activeTasks : Long = 0
23
27
24
28
private val completed = new util.LinkedList [(InputStream , BlockId , Long )]()
25
29
26
30
private def prefetchThread (): Unit = {
27
- while (iter.hasNext || nextElement != null ) {
28
- if (nextElement == null ) {
29
- val now = System .nanoTime()
30
- nextElement = iter.next()
31
- timeNext = System .nanoTime() - now
31
+ var nextElement : (BlockId , S3ShuffleBlockStream ) = null
32
+ while (true ) {
33
+ synchronized {
34
+ if (! iter.hasNext && nextElement == null ) {
35
+ hasItem = false
36
+ return
37
+ }
38
+ if (nextElement == null ) {
39
+ nextElement = iter.next()
40
+ activeTasks += 1
41
+ hasItem = iter.hasNext
42
+ }
32
43
}
33
- val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
34
44
35
45
var fetchNext = false
46
+ val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
36
47
synchronized {
37
- if (memoryUsage + math.min( bsize, maxBufferSize) > maxBufferSize) {
48
+ if (memoryUsage + bsize > maxBufferSize) {
38
49
try {
39
50
wait()
40
51
}
@@ -43,6 +54,7 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
43
54
}
44
55
} else {
45
56
fetchNext = true
57
+ memoryUsage += bsize
46
58
}
47
59
}
48
60
@@ -59,50 +71,49 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
59
71
timePrefetching += System .nanoTime() - now
60
72
bytesRead += bsize
61
73
synchronized {
62
- memoryUsage += bsize
63
74
completed.push((stream, block, bsize))
64
- hasItem = iter.hasNext
65
- notify ()
75
+ activeTasks -= 1
76
+ notifyAll ()
66
77
}
67
78
}
68
79
}
69
80
}
70
81
71
82
private val self = this
72
- private val thread = new Thread {
83
+ private val threads = Array .fill[ Thread ](concurrencyTask)( new Thread {
73
84
override def run (): Unit = {
74
85
self.prefetchThread()
75
86
}
76
- }
77
- thread. start()
87
+ })
88
+ threads.foreach(_. start() )
78
89
79
90
private def printStatistics (): Unit = synchronized {
91
+ val totalRuntime = System .nanoTime() - startTime
80
92
try {
93
+ val tR = totalRuntime / 1000000
94
+ val wPer = 100 * timeWaiting / totalRuntime
81
95
val tW = timeWaiting / 1000000
82
96
val tP = timePrefetching / 1000000
83
- val tN = timeNext / 1000000
84
97
val bR = bytesRead
85
98
val r = numStreams
86
99
// Average time per prefetch
87
100
val atP = tP / r
88
101
// Average time waiting
89
102
val atW = tW / r
90
- // Average time next
91
- val atN = tN / r
92
103
// Average read bandwidth
93
104
val bW = bR.toDouble / (tP.toDouble / 1000 ) / (1024 * 1024 )
94
105
// Block size
95
106
val bs = bR / r
96
107
logInfo(s " Statistics: ${bR} bytes, ${tW} ms waiting ( ${atW} avg), " +
97
- s " ${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
98
- s " ${tN } ms for next ( ${atN} avg) " )
108
+ s " ${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " +
109
+ s " Total: ${tR } ms - ${wPer} % waiting " )
99
110
} catch {
100
111
case e : Exception => logError(f " Unable to print statistics: ${e.getMessage}. " )
101
112
}
102
113
}
103
114
104
115
override def hasNext : Boolean = synchronized {
105
- val result = hasItem || (completed.size() > 0 )
116
+ val result = hasItem || activeTasks > 0 || (completed.size() > 0 )
106
117
if (! result) {
107
118
printStatistics()
108
119
}
0 commit comments