@@ -27,9 +27,15 @@ object S3ShuffleHelper extends Logging {
27
27
*
28
28
* @param shuffleIndex
29
29
*/
30
- def purgeCachedShuffleIndices (shuffleIndex : Int ): Unit = {
31
- val blockFilter = (block : ShuffleIndexBlockId ) => block.shuffleId == shuffleIndex
32
- cachedArrayLengths.remove(blockFilter, None )
30
+ def purgeCachedDataForShuffle (shuffleIndex : Int ): Unit = {
31
+ if (dispatcher.cachePartitionLengths) {
32
+ val filter = (block : ShuffleIndexBlockId ) => block.shuffleId == shuffleIndex
33
+ cachedArrayLengths.remove(filter, None )
34
+ }
35
+ if (dispatcher.cacheChecksums) {
36
+ val filter = (block : ShuffleChecksumBlockId ) => block.shuffleId == shuffleIndex
37
+ cachedChecksums.remove(filter, None )
38
+ }
33
39
}
34
40
35
41
/**
@@ -59,6 +65,7 @@ object S3ShuffleHelper extends Logging {
59
65
def listShuffleIndices (shuffleId : Int ): Array [ShuffleIndexBlockId ] = {
60
66
val shuffleIndexFilter : PathFilter = new PathFilter () {
61
67
private val prefix = f"shuffle_${shuffleId}_"
68
+
62
69
override def accept (path : Path ): Boolean = {
63
70
val name = path.getName
64
71
name.startsWith(prefix) && name.endsWith("_0.index")
@@ -85,8 +92,8 @@ object S3ShuffleHelper extends Logging {
85
92
* @param mapId
86
93
* @return
87
94
*/
88
- def getPartitionLengthsCached (shuffleId : Int , mapId : Long ): Array [Long ] = {
89
- getPartitionLengthsCached (ShuffleIndexBlockId (shuffleId, mapId, NOOP_REDUCE_ID ))
95
+ def getPartitionLengths (shuffleId : Int , mapId : Long ): Array [Long ] = {
96
+ getPartitionLengths (ShuffleIndexBlockId (shuffleId, mapId, NOOP_REDUCE_ID ))
90
97
}
91
98
92
99
/**
@@ -95,19 +102,21 @@ object S3ShuffleHelper extends Logging {
95
102
* @param blockId
96
103
* @return
97
104
*/
98
- def getPartitionLengthsCached (blockId : ShuffleIndexBlockId ): Array [Long ] = {
99
- cachedArrayLengths.getOrElsePut(blockId, readBlockAsArray)
100
- }
101
-
102
- def getChecksumsCached (shuffleId : Int , mapId : Long ): Array [Long ] = {
103
- cachedChecksums.getOrElsePut(ShuffleChecksumBlockId (shuffleId, mapId, 0 ), readBlockAsArray)
105
+ def getPartitionLengths (blockId : ShuffleIndexBlockId ): Array [Long ] = {
106
+ if (dispatcher.cachePartitionLengths) {
107
+ return cachedArrayLengths.getOrElsePut(blockId, readBlockAsArray)
108
+ }
109
+ readBlockAsArray(blockId)
104
110
}
105
111
106
112
def getChecksums (shuffleId : Int , mapId : Long ): Array [Long ] = {
107
- getChecksums(ShuffleChecksumBlockId (shuffleId = shuffleId , mapId = mapId, reduceId = 0 ))
113
+ getChecksums(ShuffleChecksumBlockId (shuffleId, mapId, 0 ))
108
114
}
109
115
110
116
def getChecksums (blockId : ShuffleChecksumBlockId ): Array [Long ] = {
117
+ if (dispatcher.cacheChecksums) {
118
+ return cachedChecksums.getOrElsePut(blockId, readBlockAsArray)
119
+ }
111
120
readBlockAsArray(blockId)
112
121
}
113
122
0 commit comments