Skip to content

Commit

Permalink
Create Coverage.java (deepjavalibrary#1653)
Browse files Browse the repository at this point in the history
For regression problems, you'd like to know how often you're over-estimating the target. Sometimes you deliberately want to over-forecast — say, 75% of the time — e.g. to make sure you have enough milk in stock to satisfy your customers.

Co-authored-by: KexinFeng <fengx463@umn.edu>
  • Loading branch information
2 people authored and patins1 committed Aug 26, 2022
1 parent d1f9566 commit 94147de
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
58 changes: 58 additions & 0 deletions api/src/main/java/ai/djl/training/evaluator/Coverage.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.util.Pair;

/**
* Coverage for a Regression problem: it measures the percent of predictions greater than the actual
* target, to determine whether the predictor is over-forecasting or under-forecasting. e.g. 0.50 if
* we predict near the median of the distribution.
*
* <pre>
* def coverage(target, forecast):
* return (np.mean((target &lt; forecast)))
* </pre>
*
* <a href="https://bibinmjose.github.io/2021/03/08/errorblog.html">...</a>
*/
/**
 * An {@link ai.djl.training.evaluator.Evaluator} for regression problems that reports the
 * fraction of predictions that are strictly greater than the actual targets. It indicates
 * whether a model tends to over-forecast or under-forecast; a value near 0.50 means the model
 * predicts close to the median of the target distribution.
 *
 * <pre>
 * def coverage(target, forecast):
 *     return (np.mean((target &lt; forecast)))
 * </pre>
 *
 * <a href="https://bibinmjose.github.io/2021/03/08/errorblog.html">...</a>
 */
public class Coverage extends AbstractAccuracy {

    /**
     * Creates an evaluator that measures the percent of predictions greater than the actual target.
     */
    public Coverage() {
        this("Coverage", 1);
    }

    /**
     * Creates an evaluator that measures the percent of predictions greater than the actual target.
     *
     * @param name the name of the evaluator, default is "Coverage"
     * @param axis the axis along which to count the correct prediction, default is 1
     */
    public Coverage(String name, int axis) {
        super(name, axis);
    }

    /** {@inheritDoc} */
    @Override
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        NDArray target = labels.head();
        NDArray forecast = predictions.head();
        // Element-wise forecast > target; the accumulator averages this into the coverage rate.
        return new Pair<>(target.size(), forecast.gt(target));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.training.evaluator.AbstractAccuracy;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Coverage;
import ai.djl.training.evaluator.TopKAccuracy;

import org.testng.Assert;
Expand Down Expand Up @@ -58,6 +59,25 @@ public void testAccuracy() {
}
}

@Test
public void testCoverage() {
    try (NDManager manager = NDManager.newBaseManager()) {
        // One of the three forecasts (0.7) exceeds its target (0.5) -> coverage of 1/3.
        NDArray forecasts = manager.create(new float[] {0.3f, 0.7f, 0.0f}, new Shape(3));
        NDArray targets = manager.create(new float[] {0.5f, 0.5f, 0.5f}, new Shape(3));

        Coverage evaluator = new Coverage();
        evaluator.addAccumulator("");
        evaluator.updateAccumulator("", new NDList(targets), new NDList(forecasts));
        float accuracy = evaluator.getAccumulator("");
        float expectedAccuracy = 1.f / 3;
        Assert.assertEquals(
                accuracy,
                expectedAccuracy,
                "Wrong accuracy, expected: " + expectedAccuracy + ", actual: " + accuracy);
    }
}

@Test
public void testTopKAccuracy() {
try (NDManager manager = NDManager.newBaseManager()) {
Expand Down

0 comments on commit 94147de

Please sign in to comment.