16
16
*/
17
17
18
18
package ml .dmlc .mxnet .infer
19
+
19
20
// scalastyle:off
20
21
import java .awt .image .BufferedImage
21
22
// scalastyle:on
22
- import ml . dmlc . mxnet . NDArray
23
- import ml .dmlc .mxnet .DataDesc
23
+
24
+ import ml .dmlc .mxnet .{ Context , DataDesc , NDArray }
24
25
import scala .collection .mutable .ListBuffer
26
+
25
27
/**
26
28
* A class for object detection tasks
27
29
*
@@ -32,11 +34,16 @@ import scala.collection.mutable.ListBuffer
32
34
* file://model-dir/synset.txt
33
35
* @param inputDescriptors Descriptors defining the input node names, shape,
34
36
* layout and Type parameters
37
+ * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
38
+ * @param epoch Model epoch to load, defaults to 0.
35
39
*/
36
40
class ObjectDetector (modelPathPrefix : String ,
37
- inputDescriptors : IndexedSeq [DataDesc ]) {
41
+ inputDescriptors : IndexedSeq [DataDesc ],
42
+ contexts : Array [Context ] = Context .cpu(),
43
+ epoch : Option [Int ] = Some (0 )) {
38
44
39
- val imgClassifier : ImageClassifier = getImageClassifier(modelPathPrefix, inputDescriptors)
45
+ val imgClassifier : ImageClassifier =
46
+ getImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
40
47
41
48
val inputShape = imgClassifier.inputShape
42
49
@@ -54,7 +61,7 @@ class ObjectDetector(modelPathPrefix: String,
54
61
* To Detect bounding boxes and corresponding labels
55
62
*
56
63
* @param inputImage : PathPrefix of the input image
57
- * @param topK : Get top k elements with maximum probability
64
+ * @param topK : Get top k elements with maximum probability
58
65
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
59
66
*/
60
67
def imageObjectDetect (inputImage : BufferedImage ,
@@ -71,9 +78,10 @@ class ObjectDetector(modelPathPrefix: String,
71
78
/**
72
79
* Takes input images as NDArrays. Useful when you want to perform multiple operations on
73
80
* the input Array, or when you want to pass a batch of input images.
81
+ *
74
82
* @param input : Indexed Sequence of NDArrays
75
- * @param topK : (Optional) How many top_k(sorting will be based on the last axis)
76
- * elements to return. If not passed, returns all unsorted output.
83
+ * @param topK : (Optional) How many top_k(sorting will be based on the last axis)
84
+ * elements to return. If not passed, returns all unsorted output.
77
85
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
78
86
*/
79
87
def objectDetectWithNDArray (input : IndexedSeq [NDArray ], topK : Option [Int ])
@@ -90,10 +98,10 @@ class ObjectDetector(modelPathPrefix: String,
90
98
batchResult.toIndexedSeq
91
99
}
92
100
93
- private def sortAndReformat (predictResultND : NDArray , topK : Option [Int ])
101
+ private def sortAndReformat (predictResultND : NDArray , topK : Option [Int ])
94
102
: IndexedSeq [(String , Array [Float ])] = {
95
103
val predictResult : ListBuffer [Array [Float ]] = ListBuffer [Array [Float ]]()
96
- val accuracy : ListBuffer [Float ] = ListBuffer [Float ]()
104
+ val accuracy : ListBuffer [Float ] = ListBuffer [Float ]()
97
105
98
106
// iterating over the all the predictions
99
107
val length = predictResultND.shape(0 )
@@ -110,7 +118,7 @@ class ObjectDetector(modelPathPrefix: String,
110
118
handler.execute(r.dispose())
111
119
}
112
120
var result = IndexedSeq [(String , Array [Float ])]()
113
- if (topK.isDefined) {
121
+ if (topK.isDefined) {
114
122
var sortedIndices = accuracy.zipWithIndex.sortBy(- _._1).map(_._2)
115
123
sortedIndices = sortedIndices.take(topK.get)
116
124
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
@@ -127,8 +135,9 @@ class ObjectDetector(modelPathPrefix: String,
127
135
128
136
/**
129
137
* To classify batch of input images according to the provided model
138
+ *
130
139
* @param inputBatch Input batch of Buffered images
131
- * @param topK Get top k elements with maximum probability
140
+ * @param topK Get top k elements with maximum probability
132
141
* @return List of list of tuples of (class, probability)
133
142
*/
134
143
def imageBatchObjectDetect (inputBatch : Traversable [BufferedImage ], topK : Option [Int ] = None ):
@@ -148,9 +157,11 @@ class ObjectDetector(modelPathPrefix: String,
148
157
result
149
158
}
150
159
151
- def getImageClassifier (modelPathPrefix : String , inputDescriptors : IndexedSeq [DataDesc ]):
160
+ def getImageClassifier (modelPathPrefix : String , inputDescriptors : IndexedSeq [DataDesc ],
161
+ contexts : Array [Context ] = Context .cpu(),
162
+ epoch : Option [Int ] = Some (0 )):
152
163
ImageClassifier = {
153
- new ImageClassifier (modelPathPrefix, inputDescriptors)
164
+ new ImageClassifier (modelPathPrefix, inputDescriptors, contexts, epoch )
154
165
}
155
166
156
167
}
0 commit comments