Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Aug 25, 2022
1 parent 21803b6 commit 3e74fa7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
13 changes: 7 additions & 6 deletions api/src/main/java/ai/djl/training/evaluator/Coverage.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import ai.djl.ndarray.NDList;
import ai.djl.util.Pair;

/** Coverage for a Regression problem:
* it measures the percent of predictions greater than the actual target,
* to determine whether the predictor is over-forecasting or under-forecasting.
* e.g. 0.50 if we predict near the median of the distribution.
/**
* Coverage for a Regression problem: it measures the percent of predictions greater than the actual
* target, to determine whether the predictor is over-forecasting or under-forecasting. e.g. 0.50 if
* we predict near the median of the distribution.
*
* <pre>
* def coverage(target, forecast):
* return (np.mean((target < forecast)))
* </pre>
*
* https://bibinmjose.github.io/2021/03/08/errorblog.html
*/
public class Coverage extends AbstractAccuracy {
Expand All @@ -28,7 +30,6 @@ public Coverage(String name, int axis) {
protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
NDArray labl = labels.head();
NDArray pred = predictions.head();
return new Pair<Long,NDArray>((Long) labl.size(), labl.lt(pred));
return new Pair<>(labl.size(), labl.lt(pred));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.training.evaluator.AbstractAccuracy;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Coverage;
import ai.djl.training.evaluator.TopKAccuracy;
import org.testng.Assert;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -57,6 +58,25 @@ public void testAccuracy() {
}
}

@Test
public void testCoverage() {
try (NDManager manager = NDManager.newBaseManager()) {

NDArray predictions = manager.create(new float[] {0.3f, 0.7f, 0.0f}, new Shape(3));
NDArray labels = manager.create(new float[] {0.5f, 0.5f, 0.5f}, new Shape(3));

Coverage coverage = new Coverage();
coverage.addAccumulator("");
coverage.updateAccumulator("", new NDList(labels), new NDList(predictions));
float accuracy = coverage.getAccumulator("");
float expectedAccuracy = 1.f / 3;
Assert.assertEquals(
accuracy,
expectedAccuracy,
"Wrong accuracy, expected: " + expectedAccuracy + ", actual: " + accuracy);
}
}

@Test
public void testTopKAccuracy() {
try (NDManager manager = NDManager.newBaseManager()) {
Expand Down

0 comments on commit 3e74fa7

Please sign in to comment.