Skip to content

Commit 793f9c6

Browse files
committed
Fix Batch input issue with Scala Benchmark (apache#12848)
* add initial change * add fix * improved usage of Shape as well as warning message on performance * change into parallel * drop dropBack * apply Andrew's comments * remove add dim inside img 2 pixel * addressed Naveen's comment * update comments
1 parent afe7264 commit 793f9c6

File tree

10 files changed

+63
-33
lines changed

10 files changed

+63
-33
lines changed

scala-package/core/src/test/scala/org/apache/mxnet/ShapeSuite.scala

+15
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,19 @@ class ShapeSuite extends FunSuite with BeforeAndAfterAll {
2929
assert(Shape(1, 2, 3) === Shape(1, 2, 3))
3030
assert(Shape(1, 2) != Shape(1, 2, 3))
3131
}
32+
33+
test("drop") {
34+
val s = Shape(1, 2, 3)
35+
val s2 = s.drop(1)
36+
assert(s == Shape(1, 2, 3))
37+
assert(s2 == Shape(2, 3))
38+
val s3 = s.drop(2)
39+
assert(s3 == Shape(3))
40+
}
41+
42+
test("slice") {
43+
val s = Shape(1, 2, 3)
44+
val s2 = s.slice(0, 1)
45+
assert(s2 == Shape(1))
46+
}
3247
}

scala-package/examples/src/main/scala/org/apache/mxnetexamples/InferBase.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.mxnet._
2121

2222
trait InferBase {
2323

24-
def loadModel(context: Array[Context]): Any
24+
def loadModel(context: Array[Context], batchInference : Boolean): Any
2525
def loadSingleData(): Any
2626
def loadBatchFileList(batchSize: Int): List[Any]
2727
def loadInputBatch(source: Any): Any

scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ object ScalaInferenceBenchmark {
3131

3232
private val logger = LoggerFactory.getLogger(classOf[CLIParserBase])
3333

34-
def loadModel(objectToRun: InferBase, context: Array[Context]):
34+
def loadModel(objectToRun: InferBase, context: Array[Context], batchInference : Boolean):
3535
Any = {
36-
objectToRun.loadModel(context)
36+
objectToRun.loadModel(context, batchInference)
3737
}
3838

3939
def loadDataSet(objectToRun: InferBase):
@@ -134,7 +134,7 @@ object ScalaInferenceBenchmark {
134134
logger.info("Running single inference call")
135135
// Benchmarking single inference call
136136
NDArrayCollector.auto().withScope {
137-
val loadedModel = loadModel(exampleToBenchmark, context)
137+
val loadedModel = loadModel(exampleToBenchmark, context, false)
138138
val dataSet = loadDataSet(exampleToBenchmark)
139139
val inferenceTimes = runInference(exampleToBenchmark, loadedModel, dataSet, baseCLI.count)
140140
printStatistics(inferenceTimes, "single_inference")
@@ -144,7 +144,7 @@ object ScalaInferenceBenchmark {
144144
logger.info("Running for batch inference call")
145145
// Benchmarking batch inference call
146146
NDArrayCollector.auto().withScope {
147-
val loadedModel = loadModel(exampleToBenchmark, context)
147+
val loadedModel = loadModel(exampleToBenchmark, context, true)
148148
val batchDataSet = loadBatchDataSet(exampleToBenchmark, baseCLI.batchSize)
149149
val inferenceTimes = runBatchInference(exampleToBenchmark, loadedModel, batchDataSet)
150150
printStatistics(inferenceTimes, "batch_inference")

scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,11 @@ class CLIParser extends CLIParserBase{
175175

176176
class ImageClassifierExample(CLIParser: CLIParser) extends InferBase{
177177

178-
override def loadModel(context: Array[Context]): Classifier = {
178+
override def loadModel(context: Array[Context],
179+
batchInference : Boolean = false): Classifier = {
179180
val dType = DType.Float32
180-
val inputShape = Shape(1, 3, 224, 224)
181+
val batchSize = if (batchInference) CLIParser.batchSize else 1
182+
val inputShape = Shape(batchSize, 3, 224, 224)
181183

182184
val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
183185

scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,10 @@ class CLIParser extends CLIParserBase {
203203

204204
class SSDClassifierExample(CLIParser: CLIParser)
205205
extends InferBase {
206-
override def loadModel(context: Array[Context]): Any = {
206+
override def loadModel(context: Array[Context], batchInference: Boolean = false): Any = {
207207
val dType = DType.Float32
208-
val inputShape = Shape(1, 3, 512, 512)
208+
val batchSize = if (batchInference) CLIParser.batchSize else 1
209+
val inputShape = Shape(batchSize, 3, 512, 512)
209210
val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
210211
new ObjectDetector(CLIParser.modelPathPrefix, inputDescriptors, context)
211212
}

scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class TestCharRnn(CLIParser: CLIParser) extends InferBase {
115115

116116
private var vocab : Map[String, Int] = null
117117

118-
override def loadModel(context: Array[Context]): Any = {
118+
override def loadModel(context: Array[Context], batchInference : Boolean = false): Any = {
119119
val batchSize = 32
120120
val buckets = List(129)
121121
val numHidden = 512

scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala

+17-12
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,16 @@ class ImageClassifier(modelPathPrefix: String,
7676
topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] = {
7777

7878
val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height)
79-
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
79+
val imageShape = inputShape.drop(1)
80+
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
81+
val imgWithBatchNum = NDArray.api.expand_dims(pixelsNDArray, 0)
8082
inputImage.flush()
8183
scaledImage.flush()
84+
handler.execute(pixelsNDArray.dispose())
8285

83-
val output = super.classifyWithNDArray(IndexedSeq(pixelsNDArray), topK)
86+
val output = super.classifyWithNDArray(IndexedSeq(imgWithBatchNum), topK)
8487

85-
handler.execute(pixelsNDArray.dispose())
88+
handler.execute(imgWithBatchNum.dispose())
8689

8790
IndexedSeq(output(0))
8891
}
@@ -97,14 +100,16 @@ class ImageClassifier(modelPathPrefix: String,
97100
def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
98101
IndexedSeq[IndexedSeq[(String, Float)]] = {
99102

100-
val imageBatch = ListBuffer[NDArray]()
101-
for (image <- inputBatch) {
102-
val scaledImage = ImageClassifier.reshapeImage(image, width, height)
103-
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
104-
imageBatch += pixelsNDArray
105-
}
103+
val inputBatchSeq = inputBatch.toIndexedSeq
104+
val imageBatch = inputBatchSeq.indices.par.map(idx => {
105+
val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height)
106+
val imageShape = inputShape.drop(1)
107+
val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
108+
val imgWithBatch = NDArray.api.expand_dims(imgND, 0).get
109+
handler.execute(imgND.dispose())
110+
imgWithBatch
111+
}).toList
106112
val op = NDArray.concatenate(imageBatch)
107-
108113
val result = super.classifyWithNDArray(IndexedSeq(op), topK)
109114
handler.execute(op.dispose())
110115
handler.execute(imageBatch.foreach(_.dispose()))
@@ -147,9 +152,9 @@ object ImageClassifier {
147152
* returned by this method after the use.
148153
* </p>
149154
* @param resizedImage BufferedImage to get pixels from
150-
* @param inputImageShape Input shape; for example for resnet it is (1,3,224,224).
155+
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
151156
Should be same as inputDescriptor shape.
152-
* @return NDArray pixels array
157+
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
153158
*/
154159
def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
155160
// Get height and width of the image

scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala

+15-9
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ class ObjectDetector(modelPathPrefix: String,
7272
: IndexedSeq[IndexedSeq[(String, Array[Float])]] = {
7373

7474
val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height)
75-
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
76-
val output = objectDetectWithNDArray(IndexedSeq(pixelsNDArray), topK)
75+
val imageShape = inputShape.drop(1)
76+
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
77+
val pixelsNDWithBatch = NDArray.api.expand_dims(pixelsNDArray, 0)
7778
handler.execute(pixelsNDArray.dispose())
79+
val output = objectDetectWithNDArray(IndexedSeq(pixelsNDWithBatch), topK)
80+
handler.execute(pixelsNDWithBatch.dispose())
7881
output
7982
}
8083

@@ -147,13 +150,16 @@ class ObjectDetector(modelPathPrefix: String,
147150
def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
148151
IndexedSeq[IndexedSeq[(String, Array[Float])]] = {
149152

150-
val imageBatch = ListBuffer[NDArray]()
151-
for (image <- inputBatch) {
152-
val scaledImage = ImageClassifier.reshapeImage(image, width, height)
153-
val pixelsNdarray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
154-
imageBatch += pixelsNdarray
155-
}
156-
val op = NDArray.concatenate(imageBatch)
153+
val inputBatchSeq = inputBatch.toIndexedSeq
154+
val imageBatch = inputBatchSeq.indices.par.map(idx => {
155+
val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height)
156+
val imageShape = inputShape.drop(1)
157+
val pixelsND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
158+
val pixelsNDWithBatch = NDArray.api.expand_dims(pixelsND, 0).get
159+
handler.execute(pixelsND.dispose())
160+
pixelsNDWithBatch
161+
})
162+
val op = NDArray.concatenate(imageBatch.toList)
157163

158164
val result = objectDetectWithNDArray(IndexedSeq(op), topK)
159165
handler.execute(op.dispose())

scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala

+1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class Predictor(modelPathPrefix: String,
181181

182182
// rebind with the new batchSize
183183
if (batchSize != inputBatchSize) {
184+
logger.info(s"Latency increased due to batchSize mismatch $batchSize vs $inputBatchSize")
184185
val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
185186
Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)), f.dtype, f.layout) )
186187
mxNetHandler.execute(mod.bind(desc, forceRebind = true,

scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
6565
val image1 = new BufferedImage(100, 200, BufferedImage.TYPE_BYTE_GRAY)
6666
val image2 = ImageClassifier.reshapeImage(image1, 2, 2)
6767

68-
val result = ImageClassifier.bufferedImageToPixels(image2, Shape(1, 3, 2, 2))
68+
val result = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2))
6969

70-
assert(result.shape == inputDescriptor(0).shape)
70+
assert(result.shape == inputDescriptor(0).shape.drop(1))
7171
}
7272

7373
test("ImageClassifierSuite-testWithInputImage") {

0 commit comments

Comments
 (0)