From a4c8a8504cf4e6b78a9d155fa5e286fc2de079a5 Mon Sep 17 00:00:00 2001 From: Carkham <60054018+Carkham@users.noreply.github.com> Date: Tue, 20 Sep 2022 10:38:29 +0800 Subject: [PATCH] [timeseries] add probability distribution support for timeseries (#2025) * feature: add distribution and loss function * bug fix and add unit test * add comments and unit test * add copyright and package-info * style fix * add args array * feature: add affinely distribution * feature: add sample for zero num * bug fix: sum the loss * bug fix: add builder for StudentTOutput and clamp the min value in neg_binomial * Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java * Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java * Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java * format Co-authored-by: Carkham <1302112560@qq.com> Co-authored-by: KexinFeng --- .../distribution/AffineTransformed.java | 74 +++++++++ .../timeseries/distribution/Distribution.java | 113 ++++++++++++++ .../distribution/DistributionLoss.java | 65 ++++++++ .../distribution/NegativeBinomial.java | 107 +++++++++++++ .../djl/timeseries/distribution/StudentT.java | 109 +++++++++++++ .../distribution/output/ArgProj.java | 146 ++++++++++++++++++ .../output/DistributionOutput.java | 91 +++++++++++ .../output/NegativeBinomialOutput.java | 53 +++++++ .../distribution/output/StudentTOutput.java | 54 +++++++ .../distribution/output/package-info.java | 15 ++ .../timeseries/distribution/package-info.java | 15 ++ .../distribution/DistributionTest.java | 58 +++++++ .../timeseries/distribution/package-info.java | 15 ++ 13 files changed, 915 insertions(+) create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java new file mode 100644 index 00000000000..251aea51689 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java @@ -0,0 +1,74 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; + +/** Represents the distribution of an affinely transformed random variable. */ +public class AffineTransformed extends Distribution { + + private Distribution baseDistribution; + private NDArray loc; + private NDArray scale; + + /** + * Construct a new {@code AffineTransformed} + * + *

This is the distribution of Y = scale * X + loc, where X is a random variable distributed + * according to {@code baseDistribution}. + * + * @param baseDistribution original distribution + * @param loc translation parameter of the affine transformation + * @param scale scaling parameter of the affine transformation + */ + public AffineTransformed(Distribution baseDistribution, NDArray loc, NDArray scale) { + this.baseDistribution = baseDistribution; + this.loc = loc == null ? baseDistribution.mean().zerosLike() : loc; + this.scale = scale == null ? baseDistribution.mean().onesLike() : scale; + } + + /** {@inheritDoc} */ + @Override + public NDArray logProb(NDArray target) { + NDArray x = fInv(target); + NDArray ladj = logAbsDetJac(x); + NDArray lp = ladj.mul(-1); + return baseDistribution.logProb(x).add(lp); + } + + /** {@inheritDoc} */ + @Override + public NDArray sample(int numSamples) { + NDArray sample = baseDistribution.sample(numSamples); + return f(sample); + } + + /** {@inheritDoc} */ + @Override + public NDArray mean() { + return baseDistribution.mean().mul(scale).add(loc); + } + + private NDArray f(NDArray x) { + return x.mul(scale).add(loc); + } + + private NDArray fInv(NDArray y) { + return y.sub(loc).div(scale); + } + + private NDArray logAbsDetJac(NDArray x) { + return scale.broadcast(x.getShape()).abs().log(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java new file mode 100644 index 00000000000..739694cb9a7 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java @@ -0,0 +1,113 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +/** An abstract class representing probability distribution. */ +public abstract class Distribution { + + /** + * Compute the log of the probability density/mass function evaluated at target. + * + * @param target {@link NDArray} of shape (*batch_shape, *event_shape) + * @return Tensor of shape (batch_shape) containing the probability log-density for each event + * in target + */ + public abstract NDArray logProb(NDArray target); + + /** + * Draw samples from the distribution. + * + *

This function would expand the dimension of arguments, the first dimension of the output + * will be numSamples. + * + * @param numSamples Number of samples to be drawn + * @return a {@link NDArray} has shape (num_samples, *batch_shape, *target_shape) + */ + public abstract NDArray sample(int numSamples); + + /** + * Draw samples from the distribution. + * + *

This function would not expand the dimension + * + * @return a sampled {@link NDArray} + */ + public NDArray sample() { + return sample(0); + } + + /** + * Return the mean of the distribution. + * + * @return the mean of the distribution + */ + public abstract NDArray mean(); + + /** + * A builder to extend for all classes extend the {@link Distribution}. + * + * @param the concrete builder type + */ + public abstract static class DistributionBuilder> { + protected NDList distrArgs; + protected NDArray scale; + protected NDArray loc; + + /** + * Set the appropriate arguments for the probability distribution. + * + * @param distrArgs a {@link NDList} containing distribution args named after the parameter + * name + * @return this builder + */ + public T setDistrArgs(NDList distrArgs) { + this.distrArgs = distrArgs; + return self(); + } + + /** + * Set the affine scale for the probability distribution. + * + * @param scale the affine scale + * @return this builder + */ + public T optScale(NDArray scale) { + this.scale = scale; + return self(); + } + + /** + * Set the affine location of the probability. + * + * @param loc the affine location + * @return this builder + */ + public T optLoc(NDArray loc) { + this.loc = loc; + return self(); + } + + /** + * Build a {@code Distribution}. + * + * @return the {@code Distribution} + */ + public abstract Distribution build(); + + protected abstract T self(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java new file mode 100644 index 00000000000..30251e71728 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java @@ -0,0 +1,65 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.output.DistributionOutput; +import ai.djl.training.loss.Loss; + +/** + * {@code DistributionLoss} calculates loss for a given distribution. + * + *

Distribution Loss is calculated by {@link Distribution#logProb(NDArray)} at label point + */ +public class DistributionLoss extends Loss { + + private DistributionOutput distrOutput; + + /** + * Calculates Distribution Loss between the label and distribution. + * + * @param name the name of the loss + * @param distrOutput the {@link DistributionOutput} to construct the target distribution + */ + public DistributionLoss(String name, DistributionOutput distrOutput) { + super(name); + this.distrOutput = distrOutput; + } + + /** {@inheritDoc} */ + @Override + public NDArray evaluate(NDList labels, NDList predictions) { + Distribution.DistributionBuilder builder = distrOutput.distributionBuilder(); + builder.setDistrArgs(predictions); + if (predictions.contains("scale")) { + builder.optScale(predictions.get("scale")); + } + if (predictions.contains("loc")) { + builder.optLoc(predictions.get("loc")); + } + + NDArray loss = builder.build().logProb(labels.singletonOrThrow()).mul(-1); + + if (predictions.contains("loss_weights")) { + NDArray lossWeights = predictions.get("loss_weights"); + NDArray weightedValue = + NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike()); + NDArray sumWeights = lossWeights.sum().maximum(1.); + loss = weightedValue.sum().div(sumWeights); + } + return loss; + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java new file mode 100644 index 00000000000..8565babc244 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java @@ -0,0 +1,107 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.util.Preconditions; + +/** + * Negative binomial distribution. + * + *

The distribution of the number of successes in a sequence of independent Bernoulli trials. + * + *

Two arguments for this distribution. {@code mu} mean of the distribution, {@code alpha} the + * inverse number of negative Bernoulli trials to stop + */ +public final class NegativeBinomial extends Distribution { + + private NDArray mu; + private NDArray alpha; + + NegativeBinomial(Builder builder) { + mu = builder.distrArgs.get("mu"); + alpha = builder.distrArgs.get("alpha"); + } + + /** {@inheritDoc} */ + @Override + public NDArray logProb(NDArray target) { + + NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1); + NDArray alphaTimesMu = alpha.mul(mu); + + return target.mul(alphaTimesMu.div(alphaTimesMu.add(1)).log()) + .sub(alphaInv.mul(alphaTimesMu.add(1).log())) + .add(target.add(alphaInv).gammaln()) + .sub(target.add(1.).gammaln()) + .sub(alphaInv.gammaln()); + } + + /** {@inheritDoc} */ + @Override + public NDArray sample(int numSamples) { + NDManager manager = mu.getManager(); + NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu; + NDArray expandedAlpha = numSamples > 0 ? alpha.expandDims(0).repeat(0, numSamples) : alpha; + + NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f); + NDArray theta = expandedAlpha.mul(expandedMu); + return manager.samplePoisson(manager.sampleGamma(r, theta)); + } + + /** {@inheritDoc} */ + @Override + public NDArray mean() { + return mu; + } + + /** + * Creates a builder to build a {@code NegativeBinomial}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** The builder to construct a {@code NegativeBinomial}. */ + public static final class Builder extends DistributionBuilder { + + /** {@inheritDoc} */ + @Override + public Distribution build() { + Preconditions.checkArgument( + distrArgs.contains("mu"), "NegativeBinomial's args must contain mu."); + Preconditions.checkArgument( + distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha."); + // We cannot scale using the affine transformation since negative binomial should return + // integers. Instead we scale the parameters. + if (scale != null) { + NDArray mu = distrArgs.get("mu"); + mu = mu.mul(scale); + mu.setName("mu"); + distrArgs.remove("mu"); + distrArgs.add(mu); + } + return new NegativeBinomial(this); + } + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java new file mode 100644 index 00000000000..34168acd601 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java @@ -0,0 +1,109 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDManager; +import ai.djl.util.Preconditions; + +/** + * Student's t-test distribution. + * + *

Three arguments for this distribution. {@code mu} mean of the distribution, {@code sigma} the + * standard deviations (scale), {@code nu} degrees of freedom. + */ +public class StudentT extends Distribution { + + private NDArray mu; + private NDArray sigma; + private NDArray nu; + + StudentT(Builder builder) { + mu = builder.distrArgs.get("mu"); + sigma = builder.distrArgs.get("sigma"); + nu = builder.distrArgs.get("nu"); + } + + /** {@inheritDoc} */ + @Override + public NDArray logProb(NDArray target) { + NDArray nup1Half = nu.add(1.).div(2.); + NDArray part1 = nu.getNDArrayInternal().rdiv(1.).mul(target.sub(mu).div(sigma).square()); + + NDArray z = + nup1Half.gammaln() + .sub(nu.div(2.).gammaln()) + .sub(nu.mul(Math.PI).log().mul(0.5)) + .sub(sigma.log()); + + return z.sub(nup1Half.mul(part1.add(1.).log())); + } + + /** {@inheritDoc} */ + @Override + public NDArray sample(int numSamples) { + NDManager manager = mu.getManager(); + NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu; + NDArray expandedSigma = numSamples > 0 ? sigma.expandDims(0).repeat(0, numSamples) : sigma; + NDArray expandedNu = numSamples > 0 ? nu.expandDims(0).repeat(0, numSamples) : nu; + + NDArray gammas = + manager.sampleGamma( + expandedNu.div(2.), + expandedNu.mul(expandedSigma.square()).getNDArrayInternal().rdiv(2.)); + return manager.sampleNormal(expandedMu, gammas.sqrt().getNDArrayInternal().rdiv(1.)); + } + + /** {@inheritDoc} */ + @Override + public NDArray mean() { + return NDArrays.where(nu.gt(1.0), mu, mu.getManager().full(mu.getShape(), Float.NaN)); + } + + /** + * Creates a builder to build a {@code NegativeBinomial}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** The builder to construct a {@code NegativeBinomial}. */ + public static final class Builder extends DistributionBuilder { + + /** {@inheritDoc} */ + @Override + public Distribution build() { + Preconditions.checkArgument( + distrArgs.contains("mu"), "StudentTl's args must contain mu."); + Preconditions.checkArgument( + distrArgs.contains("sigma"), "StudentTl's args must contain sigma."); + Preconditions.checkArgument( + distrArgs.contains("nu"), "StudentTl's args must contain nu."); + StudentT baseDistr = new StudentT(this); + if (scale == null && loc == null) { + return baseDistr; + } + return new AffineTransformed(baseDistr, loc, scale); + } + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java new file mode 100644 index 00000000000..6969666e881 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java @@ -0,0 +1,146 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractBlock; +import ai.djl.nn.Block; +import ai.djl.nn.core.Linear; +import ai.djl.training.ParameterStore; +import ai.djl.util.Pair; +import ai.djl.util.PairList; +import ai.djl.util.Preconditions; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * A Block used to map the output of a dense layer to statistical parameters, like mean and standard + * deviation. It will be used in both training and inference. + */ +public final class ArgProj extends AbstractBlock { + + private Block domainMap; + private List proj; + + ArgProj(Builder builder) { + proj = new ArrayList<>(); + for (Pair entry : builder.argsDim) { + proj.add( + addChildBlock( + String.format("%s_distr_%s", builder.prefix, entry.getKey()), + Linear.builder().setUnits(entry.getValue()).build())); + } + domainMap = + addChildBlock(String.format("%s_domain_map", builder.prefix), builder.domainMap); + } + + /** {@inheritDoc} */ + @Override + protected void initializeChildBlocks( + NDManager manager, DataType dataType, Shape... inputShapes) { + for (Block block : proj) { + block.initialize(manager, dataType, inputShapes); + } + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDList paramsUnbounded = new NDList(); + for (Block block : proj) { + paramsUnbounded.add( + block.forward(parameterStore, inputs, training, params).singletonOrThrow()); + } + return domainMap.forward(parameterStore, paramsUnbounded, training, params); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + Shape[] projOutShapes = new Shape[proj.size()]; + for (int i = 0; i < proj.size(); i++) { + projOutShapes[i] = proj.get(i).getOutputShapes(inputShapes)[0]; + } + return domainMap.getOutputShapes(projOutShapes); + } + + /** + * Creates a builder to build a {@code ArgProj}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** The Builder to construct a {@code ArgProj} type of {@link Block}. */ + public static final class Builder { + private PairList argsDim; + private Function domainMap; + private String prefix = ""; + + /** + * Set the arguments dimensions of distribution. + * + * @param argsDim the arguments dimension + * @return this builder + */ + public Builder setArgsDim(PairList argsDim) { + this.argsDim = argsDim; + return this; + } + + /** + * Set the domain map function. + * + * @param domainMap the domain map function + * @return this builder + */ + public Builder setDomainMap(Function domainMap) { + this.domainMap = domainMap; + return this; + } + + /** + * Set the block name prefix. + * + * @param prefix the prefix + * @return this builder + */ + public Builder optPrefix(String prefix) { + this.prefix = prefix; + return this; + } + + /** + * Build a {@link ArgProj} block. + * + * @return the {@link ArgProj} block. + */ + public ArgProj build() { + Preconditions.checkArgument(argsDim != null, "must specify dim args"); + Preconditions.checkArgument(domainMap != null, "must specify domain PairList function"); + return new ArgProj(this); + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java new file mode 100644 index 00000000000..de3538df489 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java @@ -0,0 +1,91 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.Distribution; +import ai.djl.util.PairList; + +/** A class to construct a distribution given the output of a network. */ +public abstract class DistributionOutput { + + protected PairList argsDim; + private float valueInSupport; + + /** + * A float that will have a valid numeric value when computing the log-loss of the corresponding + * distribution. + * + *

By default {@code 0f}. This value will be used when padding data series. + * + * @return the valueInSupport + */ + public float getValueInSupport() { + return valueInSupport; + } + + /** + * Return the corresponding projection block based on the arguments dimension of different + * distributions. + * + * @return the corresponding projection block + */ + public ArgProj getArgsProj() { + return ArgProj.builder().setArgsDim(argsDim).setDomainMap(this::domainMap).build(); + } + + /** + * Return the corresponding projection block based on the arguments dimension of different + * ditributions. + * + * @param prefix the prefix string of projection layer block + * @return the corresponding projection block + */ + public ArgProj getArgsProj(String prefix) { + return ArgProj.builder() + .setArgsDim(argsDim) + .setDomainMap(this::domainMap) + .optPrefix(prefix) + .build(); + } + + /** + * Return an array containing all the argument names. + * + * @return an array containing all the argument names + */ + public String[] getArgsArray() { + return argsDim.keyArray(new String[argsDim.size()]); + } + + /** + * Convert arguments to the right shape and domain. The domain depends on the type of + * distribution, while the correct shape is obtained by reshaping the trailing axis in such a + * way that the returned tensors define a distribution of the right event_shape. + * + *

This function is usually used as the lambda of the Lambda block. + * + * @param arrays the arguments + * @return converted arguments + */ + public abstract NDList domainMap(NDList arrays); + + /** + * Return the associated {@code DistributionBuilder}, given the collection of constructor + * arguments and, optionally, a scale tensor. + * + * @return the associated {@code DistributionBuilder} + */ + public abstract Distribution.DistributionBuilder distributionBuilder(); +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java new file mode 100644 index 00000000000..84222fe9741 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java @@ -0,0 +1,53 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.Distribution; +import ai.djl.timeseries.distribution.NegativeBinomial; +import ai.djl.util.PairList; + +/** + * {@code NegativeBinomialOutput} is a {@link DistributionOutput} for the negative binomial + * distribution. + */ +public final class NegativeBinomialOutput extends DistributionOutput { + + /** Construct a negative binomial output with two arguments, {@code mu} and {@code alpha}. */ + public NegativeBinomialOutput() { + argsDim = new PairList<>(2); + argsDim.add("mu", 1); + argsDim.add("alpha", 1); + } + + /** {@inheritDoc} */ + @Override + public NDList domainMap(NDList arrays) { + NDArray mu = arrays.get(0); + NDArray alpha = arrays.get(1); + mu = mu.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1); + alpha = alpha.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1); + // TODO: make setName() must be implemented + mu.setName("mu"); + alpha.setName("alpha"); + return new NDList(mu, alpha); + } + + /** {@inheritDoc} */ + @Override + public Distribution.DistributionBuilder distributionBuilder() { + return NegativeBinomial.builder(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java new file mode 100644 index 00000000000..a38e7cab5c9 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java @@ -0,0 +1,54 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.Distribution; +import ai.djl.timeseries.distribution.StudentT; +import ai.djl.util.PairList; + +/** {@code StudentTOutput} is a {@link DistributionOutput} for the Student's t-test distribution. */ +public class StudentTOutput extends DistributionOutput { + + /** Construct a negative binomial output with two arguments, {@code mu} and {@code sigma}. */ + public StudentTOutput() { + argsDim = new PairList<>(3); + argsDim.add("mu", 1); + argsDim.add("sigma", 1); + argsDim.add("nu", 1); + } + + /** {@inheritDoc} */ + @Override + public NDList domainMap(NDList arrays) { + NDArray mu = arrays.get(0); + NDArray sigma = arrays.get(1); + NDArray nu = arrays.get(2); + mu = mu.squeeze(-1); + sigma = sigma.getNDArrayInternal().softPlus().squeeze(-1); + nu = nu.getNDArrayInternal().softPlus().add(2.).squeeze(-1); + // TODO: make setName() must be implemented + mu.setName("mu"); + sigma.setName("sigma"); + nu.setName("nu"); + return new NDList(mu, sigma, nu); + } + + /** {@inheritDoc} */ + @Override + public Distribution.DistributionBuilder distributionBuilder() { + return StudentT.builder(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java new file mode 100644 index 00000000000..8003778cd6a --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to support construct distribution and project arguments. */ +package ai.djl.timeseries.distribution.output; diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java new file mode 100644 index 00000000000..5f18cc5f60e --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to support distribution in djl. */ +package ai.djl.timeseries.distribution; diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java new file mode 100644 index 00000000000..051f7f4743b --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java @@ -0,0 +1,58 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.testing.Assertions; + +import org.testng.annotations.Test; + +public class DistributionTest { + + @Test + public void testNegativeBinomial() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray mu = manager.create(new float[] {1000f, 1f}); + NDArray alpha = manager.create(new float[] {1f, 2f}); + mu.setName("mu"); + alpha.setName("alpha"); + Distribution negativeBinomial = + NegativeBinomial.builder().setDistrArgs(new NDList(mu, alpha)).build(); + + NDArray expected = manager.create(new float[] {-6.9098f, -1.6479f}); + NDArray real = negativeBinomial.logProb(manager.create(new float[] {1f, 1f})); + Assertions.assertAlmostEquals(real, expected); + } + } + + @Test + public void testStudentT() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray mu = manager.create(new float[] {1000f, -1000f}); + NDArray sigma = manager.create(new float[] {1f, 2f}); + NDArray nu = manager.create(new float[] {4.2f, 3f}); + mu.setName("mu"); + sigma.setName("sigma"); + nu.setName("nu"); + Distribution studentT = + StudentT.builder().setDistrArgs(new NDList(mu, sigma, nu)).build(); + + NDArray expected = manager.create(new float[] {-0.9779f, -1.6940f}); + NDArray real = studentT.logProb(manager.create(new float[] {1000f, -1000f})); + Assertions.assertAlmostEquals(real, expected); + } + } +} diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java new file mode 100644 index 00000000000..d6d06167d43 --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests for the distribution module. */ +package ai.djl.timeseries.distribution; \ No newline at end of file