Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

adding context parameter to infer api- imageclassifier and objectdetector #10252

Merged
merged 2 commits into from
Mar 26, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package ml.dmlc.mxnet.infer

import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}

import scala.collection.mutable.ListBuffer

Expand All @@ -39,11 +39,11 @@ import javax.imageio.ImageIO
* layout and Type parameters
*/
class ImageClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc])
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0))
extends Classifier(modelPathPrefix,
inputDescriptors) {

val classifier: Classifier = getClassifier(modelPathPrefix, inputDescriptors)
inputDescriptors, contexts, epoch) {

protected[infer] val inputLayout = inputDescriptors.head.layout

Expand Down Expand Up @@ -108,8 +108,10 @@ class ImageClassifier(modelPathPrefix: String,
result
}

def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): Classifier = {
new Classifier(modelPathPrefix, inputDescriptors)
def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)): Classifier = {
new Classifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
*/

package ml.dmlc.mxnet.infer

// scalastyle:off
import java.awt.image.BufferedImage
// scalastyle:on
import ml.dmlc.mxnet.NDArray
import ml.dmlc.mxnet.DataDesc

import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
import scala.collection.mutable.ListBuffer

/**
* A class for object detection tasks
*
Expand All @@ -34,9 +36,12 @@ import scala.collection.mutable.ListBuffer
* layout and Type parameters
*/
class ObjectDetector(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc]) {
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)) {

val imgClassifier: ImageClassifier = getImageClassifier(modelPathPrefix, inputDescriptors)
val imgClassifier: ImageClassifier =
getImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)

val inputShape = imgClassifier.inputShape

Expand All @@ -54,7 +59,7 @@ class ObjectDetector(modelPathPrefix: String,
* To Detect bounding boxes and corresponding labels
*
* @param inputImage : PathPrefix of the input image
* @param topK : Get top k elements with maximum probability
* @param topK : Get top k elements with maximum probability
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
*/
def imageObjectDetect(inputImage: BufferedImage,
Expand All @@ -71,9 +76,10 @@ class ObjectDetector(modelPathPrefix: String,
/**
* Takes input images as NDArrays. Useful when you want to perform multiple operations on
* the input Array, or when you want to pass a batch of input images.
*
* @param input : Indexed Sequence of NDArrays
* @param topK : (Optional) How many top_k(sorting will be based on the last axis)
* elements to return. If not passed, returns all unsorted output.
* @param topK : (Optional) How many top_k(sorting will be based on the last axis)
* elements to return. If not passed, returns all unsorted output.
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
*/
def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int])
Expand All @@ -90,10 +96,10 @@ class ObjectDetector(modelPathPrefix: String,
batchResult.toIndexedSeq
}

private def sortAndReformat(predictResultND : NDArray, topK: Option[Int])
private def sortAndReformat(predictResultND: NDArray, topK: Option[Int])
: IndexedSeq[(String, Array[Float])] = {
val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]()
val accuracy : ListBuffer[Float] = ListBuffer[Float]()
val accuracy: ListBuffer[Float] = ListBuffer[Float]()

// iterating over the all the predictions
val length = predictResultND.shape(0)
Expand All @@ -110,7 +116,7 @@ class ObjectDetector(modelPathPrefix: String,
handler.execute(r.dispose())
}
var result = IndexedSeq[(String, Array[Float])]()
if(topK.isDefined) {
if (topK.isDefined) {
var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2)
sortedIndices = sortedIndices.take(topK.get)
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
Expand All @@ -127,8 +133,9 @@ class ObjectDetector(modelPathPrefix: String,

/**
* To classify batch of input images according to the provided model
*
* @param inputBatch Input batch of Buffered images
* @param topK Get top k elements with maximum probability
* @param topK Get top k elements with maximum probability
* @return List of list of tuples of (class, probability)
*/
def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
Expand All @@ -148,9 +155,11 @@ class ObjectDetector(modelPathPrefix: String,
result
}

def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)):
ImageClassifier = {
new ImageClassifier(modelPathPrefix, inputDescriptors)
new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package ml.dmlc.mxnet.infer

import ml.dmlc.mxnet.{DType, DataDesc, Shape, NDArray}

import ml.dmlc.mxnet._
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll}
import org.scalatest.BeforeAndAfterAll

// scalastyle:off
import java.awt.image.BufferedImage
Expand All @@ -33,15 +32,16 @@ import java.awt.image.BufferedImage
class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {

class MyImageClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc])
inputDescriptors: IndexedSeq[DataDesc])
extends ImageClassifier(modelPathPrefix, inputDescriptors) {

override def getPredictor(): MyClassyPredictor = {
Mockito.mock(classOf[MyClassyPredictor])
}

override def getClassifier(modelPathPrefix: String, inputDescriptors:
IndexedSeq[DataDesc]): Classifier = {
IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)): Classifier = {
Mockito.mock(classOf[Classifier])
}

Expand Down Expand Up @@ -84,7 +84,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {

val synset = testImageClassifier.synset

val predictExpectedOp : List[(String, Float)] =
val predictExpectedOp: List[(String, Float)] =
List[(String, Float)]((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f))

Expand All @@ -93,13 +93,14 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))

Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
Mockito.doReturn(IndexedSeq(predictExpectedOp))
.when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))

val predictResult: IndexedSeq[IndexedSeq[(String, Float)]] =
testImageClassifier.classifyImage(inputImage, Some(4))

for(i <- predictExpected.indices) {
for (i <- predictExpected.indices) {
assertResult(predictExpected(i).sortBy(-_)) {
predictResult(i).map(_._2).toArray
}
Expand All @@ -119,23 +120,24 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {

val predictExpected: IndexedSeq[Array[Array[Float]]] =
IndexedSeq[Array[Array[Float]]](Array(Array(.98f, 0.97f, 0.96f, 0.99f),
Array(.98f, 0.97f, 0.96f, 0.99f)))
Array(.98f, 0.97f, 0.96f, 0.99f)))

val synset = testImageClassifier.synset

val predictExpectedOp : List[List[(String, Float)]] =
val predictExpectedOp: List[List[(String, Float)]] =
List[List[(String, Float)]](List((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f)),
List((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f)))
(synset(3), .96f), (synset(0), .99f)))

val predictExpectedND: NDArray = NDArray.array(predictExpected.flatten.flatten.toArray,
Shape(2, 4))

Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))

Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
Mockito.doReturn(IndexedSeq(predictExpectedOp))
.when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))

val result: IndexedSeq[IndexedSeq[(String, Float)]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.awt.image.BufferedImage
// scalastyle:on
import ml.dmlc.mxnet.Context
import ml.dmlc.mxnet.DataDesc
import ml.dmlc.mxnet.{NDArray, Shape}
import ml.dmlc.mxnet.{Context, NDArray, Shape}
import org.mockito.Matchers.any
import org.mockito.Mockito
import org.scalatest.BeforeAndAfterAll
Expand All @@ -36,21 +36,24 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll {
extends ObjectDetector(modelPathPrefix, inputDescriptors) {

override def getImageClassifier(modelPathPrefix: String, inputDescriptors:
IndexedSeq[DataDesc]): ImageClassifier = {
IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)): ImageClassifier = {
new MyImageClassifier(modelPathPrefix, inputDescriptors)
}

}

class MyImageClassifier(modelPathPrefix: String,
protected override val inputDescriptors: IndexedSeq[DataDesc])
extends ImageClassifier(modelPathPrefix, inputDescriptors) {
extends ImageClassifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {

override def getPredictor(): MyClassyPredictor = {
Mockito.mock(classOf[MyClassyPredictor])
}

override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)):
Classifier = {
new MyClassifier(modelPathPrefix, inputDescriptors)
}
Expand Down