Skip to content

Commit 6f07891

Browse files
Roshrininswamy
authored andcommitted
adding context parameter to infer api- imageclassifier and objectdetector (apache#10252)
* adding context parameter * parameter description added
1 parent 59597e2 commit 6f07891

File tree

4 files changed

+56
-36
lines changed

4 files changed

+56
-36
lines changed

scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala

+11-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package ml.dmlc.mxnet.infer
1919

20-
import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
20+
import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
2121

2222
import scala.collection.mutable.ListBuffer
2323

@@ -37,13 +37,15 @@ import javax.imageio.ImageIO
3737
* file://model-dir/synset.txt
3838
* @param inputDescriptors Descriptors defining the input node names, shape,
3939
* layout and Type parameters
40+
* @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
41+
* @param epoch Model epoch to load, defaults to 0.
4042
*/
4143
class ImageClassifier(modelPathPrefix: String,
42-
inputDescriptors: IndexedSeq[DataDesc])
44+
inputDescriptors: IndexedSeq[DataDesc],
45+
contexts: Array[Context] = Context.cpu(),
46+
epoch: Option[Int] = Some(0))
4347
extends Classifier(modelPathPrefix,
44-
inputDescriptors) {
45-
46-
val classifier: Classifier = getClassifier(modelPathPrefix, inputDescriptors)
48+
inputDescriptors, contexts, epoch) {
4749

4850
protected[infer] val inputLayout = inputDescriptors.head.layout
4951

@@ -108,8 +110,10 @@ class ImageClassifier(modelPathPrefix: String,
108110
result
109111
}
110112

111-
def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): Classifier = {
112-
new Classifier(modelPathPrefix, inputDescriptors)
113+
def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
114+
contexts: Array[Context] = Context.cpu(),
115+
epoch: Option[Int] = Some(0)): Classifier = {
116+
new Classifier(modelPathPrefix, inputDescriptors, contexts, epoch)
113117
}
114118
}
115119

scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala

+24-13
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
*/
1717

1818
package ml.dmlc.mxnet.infer
19+
1920
// scalastyle:off
2021
import java.awt.image.BufferedImage
2122
// scalastyle:on
22-
import ml.dmlc.mxnet.NDArray
23-
import ml.dmlc.mxnet.DataDesc
23+
24+
import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
2425
import scala.collection.mutable.ListBuffer
26+
2527
/**
2628
* A class for object detection tasks
2729
*
@@ -32,11 +34,16 @@ import scala.collection.mutable.ListBuffer
3234
* file://model-dir/synset.txt
3335
* @param inputDescriptors Descriptors defining the input node names, shape,
3436
* layout and Type parameters
37+
* @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
38+
* @param epoch Model epoch to load, defaults to 0.
3539
*/
3640
class ObjectDetector(modelPathPrefix: String,
37-
inputDescriptors: IndexedSeq[DataDesc]) {
41+
inputDescriptors: IndexedSeq[DataDesc],
42+
contexts: Array[Context] = Context.cpu(),
43+
epoch: Option[Int] = Some(0)) {
3844

39-
val imgClassifier: ImageClassifier = getImageClassifier(modelPathPrefix, inputDescriptors)
45+
val imgClassifier: ImageClassifier =
46+
getImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
4047

4148
val inputShape = imgClassifier.inputShape
4249

@@ -54,7 +61,7 @@ class ObjectDetector(modelPathPrefix: String,
5461
* To Detect bounding boxes and corresponding labels
5562
*
5663
* @param inputImage : PathPrefix of the input image
57-
* @param topK : Get top k elements with maximum probability
64+
* @param topK : Get top k elements with maximum probability
5865
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
5966
*/
6067
def imageObjectDetect(inputImage: BufferedImage,
@@ -71,9 +78,10 @@ class ObjectDetector(modelPathPrefix: String,
7178
/**
7279
* Takes input images as NDArrays. Useful when you want to perform multiple operations on
7380
* the input Array, or when you want to pass a batch of input images.
81+
*
7482
* @param input : Indexed Sequence of NDArrays
75-
* @param topK : (Optional) How many top_k(sorting will be based on the last axis)
76-
* elements to return. If not passed, returns all unsorted output.
83+
* @param topK : (Optional) How many top_k(sorting will be based on the last axis)
84+
* elements to return. If not passed, returns all unsorted output.
7785
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
7886
*/
7987
def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int])
@@ -90,10 +98,10 @@ class ObjectDetector(modelPathPrefix: String,
9098
batchResult.toIndexedSeq
9199
}
92100

93-
private def sortAndReformat(predictResultND : NDArray, topK: Option[Int])
101+
private def sortAndReformat(predictResultND: NDArray, topK: Option[Int])
94102
: IndexedSeq[(String, Array[Float])] = {
95103
val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]()
96-
val accuracy : ListBuffer[Float] = ListBuffer[Float]()
104+
val accuracy: ListBuffer[Float] = ListBuffer[Float]()
97105

98106
// iterating over the all the predictions
99107
val length = predictResultND.shape(0)
@@ -110,7 +118,7 @@ class ObjectDetector(modelPathPrefix: String,
110118
handler.execute(r.dispose())
111119
}
112120
var result = IndexedSeq[(String, Array[Float])]()
113-
if(topK.isDefined) {
121+
if (topK.isDefined) {
114122
var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2)
115123
sortedIndices = sortedIndices.take(topK.get)
116124
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
@@ -127,8 +135,9 @@ class ObjectDetector(modelPathPrefix: String,
127135

128136
/**
129137
* To classify batch of input images according to the provided model
138+
*
130139
* @param inputBatch Input batch of Buffered images
131-
* @param topK Get top k elements with maximum probability
140+
* @param topK Get top k elements with maximum probability
132141
* @return List of list of tuples of (class, probability)
133142
*/
134143
def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
@@ -148,9 +157,11 @@ class ObjectDetector(modelPathPrefix: String,
148157
result
149158
}
150159

151-
def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
160+
def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
161+
contexts: Array[Context] = Context.cpu(),
162+
epoch: Option[Int] = Some(0)):
152163
ImageClassifier = {
153-
new ImageClassifier(modelPathPrefix, inputDescriptors)
164+
new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
154165
}
155166

156167
}

scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala

+14-12
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717

1818
package ml.dmlc.mxnet.infer
1919

20-
import ml.dmlc.mxnet.{DType, DataDesc, Shape, NDArray}
21-
20+
import ml.dmlc.mxnet._
2221
import org.mockito.Matchers._
2322
import org.mockito.Mockito
24-
import org.scalatest.{BeforeAndAfterAll}
23+
import org.scalatest.BeforeAndAfterAll
2524

2625
// scalastyle:off
2726
import java.awt.image.BufferedImage
@@ -33,15 +32,16 @@ import java.awt.image.BufferedImage
3332
class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
3433

3534
class MyImageClassifier(modelPathPrefix: String,
36-
inputDescriptors: IndexedSeq[DataDesc])
35+
inputDescriptors: IndexedSeq[DataDesc])
3736
extends ImageClassifier(modelPathPrefix, inputDescriptors) {
3837

3938
override def getPredictor(): MyClassyPredictor = {
4039
Mockito.mock(classOf[MyClassyPredictor])
4140
}
4241

4342
override def getClassifier(modelPathPrefix: String, inputDescriptors:
44-
IndexedSeq[DataDesc]): Classifier = {
43+
IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
44+
epoch: Option[Int] = Some(0)): Classifier = {
4545
Mockito.mock(classOf[Classifier])
4646
}
4747

@@ -84,7 +84,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
8484

8585
val synset = testImageClassifier.synset
8686

87-
val predictExpectedOp : List[(String, Float)] =
87+
val predictExpectedOp: List[(String, Float)] =
8888
List[(String, Float)]((synset(1), .98f), (synset(2), .97f),
8989
(synset(3), .96f), (synset(0), .99f))
9090

@@ -93,13 +93,14 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
9393
Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
9494
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
9595

96-
Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
96+
Mockito.doReturn(IndexedSeq(predictExpectedOp))
97+
.when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
9798
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))
9899

99100
val predictResult: IndexedSeq[IndexedSeq[(String, Float)]] =
100101
testImageClassifier.classifyImage(inputImage, Some(4))
101102

102-
for(i <- predictExpected.indices) {
103+
for (i <- predictExpected.indices) {
103104
assertResult(predictExpected(i).sortBy(-_)) {
104105
predictResult(i).map(_._2).toArray
105106
}
@@ -119,23 +120,24 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
119120

120121
val predictExpected: IndexedSeq[Array[Array[Float]]] =
121122
IndexedSeq[Array[Array[Float]]](Array(Array(.98f, 0.97f, 0.96f, 0.99f),
122-
Array(.98f, 0.97f, 0.96f, 0.99f)))
123+
Array(.98f, 0.97f, 0.96f, 0.99f)))
123124

124125
val synset = testImageClassifier.synset
125126

126-
val predictExpectedOp : List[List[(String, Float)]] =
127+
val predictExpectedOp: List[List[(String, Float)]] =
127128
List[List[(String, Float)]](List((synset(1), .98f), (synset(2), .97f),
128129
(synset(3), .96f), (synset(0), .99f)),
129130
List((synset(1), .98f), (synset(2), .97f),
130-
(synset(3), .96f), (synset(0), .99f)))
131+
(synset(3), .96f), (synset(0), .99f)))
131132

132133
val predictExpectedND: NDArray = NDArray.array(predictExpected.flatten.flatten.toArray,
133134
Shape(2, 4))
134135

135136
Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
136137
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
137138

138-
Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
139+
Mockito.doReturn(IndexedSeq(predictExpectedOp))
140+
.when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
139141
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))
140142

141143
val result: IndexedSeq[IndexedSeq[(String, Float)]] =

scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala

+7-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import java.awt.image.BufferedImage
2323
// scalastyle:on
2424
import ml.dmlc.mxnet.Context
2525
import ml.dmlc.mxnet.DataDesc
26-
import ml.dmlc.mxnet.{NDArray, Shape}
26+
import ml.dmlc.mxnet.{Context, NDArray, Shape}
2727
import org.mockito.Matchers.any
2828
import org.mockito.Mockito
2929
import org.scalatest.BeforeAndAfterAll
@@ -36,21 +36,24 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll {
3636
extends ObjectDetector(modelPathPrefix, inputDescriptors) {
3737

3838
override def getImageClassifier(modelPathPrefix: String, inputDescriptors:
39-
IndexedSeq[DataDesc]): ImageClassifier = {
39+
IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
40+
epoch: Option[Int] = Some(0)): ImageClassifier = {
4041
new MyImageClassifier(modelPathPrefix, inputDescriptors)
4142
}
4243

4344
}
4445

4546
class MyImageClassifier(modelPathPrefix: String,
4647
protected override val inputDescriptors: IndexedSeq[DataDesc])
47-
extends ImageClassifier(modelPathPrefix, inputDescriptors) {
48+
extends ImageClassifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {
4849

4950
override def getPredictor(): MyClassyPredictor = {
5051
Mockito.mock(classOf[MyClassyPredictor])
5152
}
5253

53-
override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
54+
override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
55+
contexts: Array[Context] = Context.cpu(),
56+
epoch: Option[Int] = Some(0)):
5457
Classifier = {
5558
new MyClassifier(modelPathPrefix, inputDescriptors)
5659
}

0 commit comments

Comments
 (0)