From df7b963dcbdba77b0b456572574f5f88e0d7bea5 Mon Sep 17 00:00:00 2001 From: gforman44 Date: Sat, 14 May 2022 00:12:58 -0700 Subject: [PATCH 1/3] Create Coverage.java For regression problems, you'd like to know how often you're over-estimating the target. Sometimes you want to overestimate, say, 75% of the time, e.g. to make sure you have enough milk in stock to satisfy your customers. --- .../ai/djl/training/evaluator/Coverage.java | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 api/src/main/java/ai/djl/training/evaluator/Coverage.java diff --git a/api/src/main/java/ai/djl/training/evaluator/Coverage.java b/api/src/main/java/ai/djl/training/evaluator/Coverage.java new file mode 100644 index 00000000000..3633f13c7c3 --- /dev/null +++ b/api/src/main/java/ai/djl/training/evaluator/Coverage.java @@ -0,0 +1,34 @@ +package ai.djl.training.evaluator; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.util.Pair; + +/** Coverage for a Regression problem: + * it measures the percent of predictions greater than the actual target, + * to determine whether the predictor is over-forecasting or under-forecasting. + * e.g. 0.50 if we predict near the median of the distribution. + *
+ *  def coverage(target, forecast):
+ *     return (np.mean((target < forecast)))
+ * 
+ * https://bibinmjose.github.io/2021/03/08/errorblog.html + */ +public class Coverage extends AbstractAccuracy { + + public Coverage() { + this("Coverage", 1); + } + + public Coverage(String name, int axis) { + super(name, axis); + } + + @Override + protected Pair accuracyHelper(NDList labels, NDList predictions) { + NDArray labl = labels.head(); + NDArray pred = predictions.head(); + return new Pair((Long) labl.size(), labl.lt(pred)); + } + +} From d6a2b0606ac1cb433479aaaed7f76914821616dd Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 25 Aug 2022 00:53:31 -0700 Subject: [PATCH 2/3] Add test --- .../ai/djl/training/evaluator/Coverage.java | 17 ++++++++-------- .../tests/training/EvaluatorTest.java | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/api/src/main/java/ai/djl/training/evaluator/Coverage.java b/api/src/main/java/ai/djl/training/evaluator/Coverage.java index 3633f13c7c3..59415a817fd 100644 --- a/api/src/main/java/ai/djl/training/evaluator/Coverage.java +++ b/api/src/main/java/ai/djl/training/evaluator/Coverage.java @@ -4,15 +4,17 @@ import ai.djl.ndarray.NDList; import ai.djl.util.Pair; -/** Coverage for a Regression problem: - * it measures the percent of predictions greater than the actual target, - * to determine whether the predictor is over-forecasting or under-forecasting. - * e.g. 0.50 if we predict near the median of the distribution. +/** + * Coverage for a Regression problem: it measures the percent of predictions greater than the actual + * target, to determine whether the predictor is over-forecasting or under-forecasting. e.g. 0.50 if + * we predict near the median of the distribution. + * *
  *  def coverage(target, forecast):
- *     return (np.mean((target < forecast)))
+ *     return (np.mean((target < forecast)))
  * 
- * https://bibinmjose.github.io/2021/03/08/errorblog.html + * + * ... */ public class Coverage extends AbstractAccuracy { @@ -28,7 +30,6 @@ public Coverage(String name, int axis) { protected Pair accuracyHelper(NDList labels, NDList predictions) { NDArray labl = labels.head(); NDArray pred = predictions.head(); - return new Pair((Long) labl.size(), labl.lt(pred)); + return new Pair<>(labl.size(), labl.lt(pred)); } - } diff --git a/integration/src/main/java/ai/djl/integration/tests/training/EvaluatorTest.java b/integration/src/main/java/ai/djl/integration/tests/training/EvaluatorTest.java index 597ebb54c79..0135f9cfbbc 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/EvaluatorTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/EvaluatorTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.training.evaluator.AbstractAccuracy; import ai.djl.training.evaluator.Accuracy; +import ai.djl.training.evaluator.Coverage; import ai.djl.training.evaluator.TopKAccuracy; import org.testng.Assert; @@ -58,6 +59,25 @@ public void testAccuracy() { } } + @Test + public void testCoverage() { + try (NDManager manager = NDManager.newBaseManager()) { + + NDArray predictions = manager.create(new float[] {0.3f, 0.7f, 0.0f}, new Shape(3)); + NDArray labels = manager.create(new float[] {0.5f, 0.5f, 0.5f}, new Shape(3)); + + Coverage coverage = new Coverage(); + coverage.addAccumulator(""); + coverage.updateAccumulator("", new NDList(labels), new NDList(predictions)); + float accuracy = coverage.getAccumulator(""); + float expectedAccuracy = 1.f / 3; + Assert.assertEquals( + accuracy, + expectedAccuracy, + "Wrong accuracy, expected: " + expectedAccuracy + ", actual: " + accuracy); + } + } + @Test public void testTopKAccuracy() { try (NDManager manager = NDManager.newBaseManager()) { From 097f44564091c84bcded8d0bb43b53ac6b1f6947 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 25 Aug 2022 03:11:13 -0700 Subject: [PATCH 3/3] javadoc --- .../ai/djl/training/evaluator/Coverage.java | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/api/src/main/java/ai/djl/training/evaluator/Coverage.java b/api/src/main/java/ai/djl/training/evaluator/Coverage.java index 59415a817fd..2ee61c09e6b 100644 --- a/api/src/main/java/ai/djl/training/evaluator/Coverage.java +++ b/api/src/main/java/ai/djl/training/evaluator/Coverage.java @@ -1,3 +1,16 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + package ai.djl.training.evaluator; import ai.djl.ndarray.NDArray; @@ -18,14 +31,24 @@ */ public class Coverage extends AbstractAccuracy { + /** + * Creates an evaluator that measures the percent of predictions greater than the actual target. + */ public Coverage() { this("Coverage", 1); } + /** + * Creates an evaluator that measures the percent of predictions greater than the actual target. + * + * @param name the name of the evaluator, default is "Coverage" + * @param axis the axis along which to count the correct prediction, default is 1 + */ public Coverage(String name, int axis) { super(name, axis); } + /** {@inheritDoc} */ @Override protected Pair accuracyHelper(NDList labels, NDList predictions) { NDArray labl = labels.head();