Skip to content

Commit

Permalink
[timeseries] add probability distribution support for timeseries (#2025)
Browse files Browse the repository at this point in the history
* feature: add distribution and loss function

* bug fix and add unit test

* add comments and unit test

* add copyright and package-info

* style fix

* add args array

* feature: add affinely transformed distribution

* feature: add sample for zero num

* bug fix: sum the loss

* bug fix: add builder for StudentTOutput and clamp the min value in neg_binomial

* Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java

* Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java

* Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java

* format

Co-authored-by: Carkham <1302112560@qq.com>
Co-authored-by: KexinFeng <fengx463@umn.edu>
  • Loading branch information
3 people authored Sep 20, 2022
1 parent 767ac22 commit a4c8a85
Show file tree
Hide file tree
Showing 13 changed files with 915 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;

/** Represents the distribution of an affinely transformed random variable. */
/** Represents the distribution of an affinely transformed random variable. */
public class AffineTransformed extends Distribution {

    // Immutable after construction; all three are resolved in the constructor.
    private final Distribution baseDistribution;
    private final NDArray loc;
    private final NDArray scale;

    /**
     * Constructs a new {@code AffineTransformed}.
     *
     * <p>This is the distribution of Y = scale * X + loc, where X is a random variable distributed
     * according to {@code baseDistribution}.
     *
     * @param baseDistribution original distribution
     * @param loc translation parameter of the affine transformation, treated as 0 when {@code
     *     null}
     * @param scale scaling parameter of the affine transformation, treated as 1 when {@code null}
     */
    public AffineTransformed(Distribution baseDistribution, NDArray loc, NDArray scale) {
        this.baseDistribution = baseDistribution;
        // null arguments default to the identity transformation (loc = 0, scale = 1),
        // shaped like the base distribution's mean
        this.loc = loc == null ? baseDistribution.mean().zerosLike() : loc;
        this.scale = scale == null ? baseDistribution.mean().onesLike() : scale;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray logProb(NDArray target) {
        // Change of variables: log p_Y(y) = log p_X(f^-1(y)) - log|det J_f(f^-1(y))|
        NDArray x = fInv(target);
        NDArray ladj = logAbsDetJac(x);
        return baseDistribution.logProb(x).sub(ladj);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sample(int numSamples) {
        // Sample X from the base distribution, then map through y = scale * x + loc
        NDArray sample = baseDistribution.sample(numSamples);
        return f(sample);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray mean() {
        // E[scale * X + loc] = scale * E[X] + loc
        return baseDistribution.mean().mul(scale).add(loc);
    }

    /** The forward transformation y = scale * x + loc. */
    private NDArray f(NDArray x) {
        return x.mul(scale).add(loc);
    }

    /** The inverse transformation x = (y - loc) / scale. */
    private NDArray fInv(NDArray y) {
        return y.sub(loc).div(scale);
    }

    /**
     * Log absolute determinant of the Jacobian of f, broadcast to the shape of {@code x}; for an
     * affine map this is log|scale|.
     */
    private NDArray logAbsDetJac(NDArray x) {
        return scale.broadcast(x.getShape()).abs().log();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** An abstract class representing probability distribution. */
/** An abstract class representing a probability distribution. */
public abstract class Distribution {

    /**
     * Computes the log of the probability density/mass function evaluated at {@code target}.
     *
     * @param target {@link NDArray} of shape (*batch_shape, *event_shape)
     * @return Tensor of shape (batch_shape) containing the probability log-density for each event
     *     in target
     */
    public abstract NDArray logProb(NDArray target);

    /**
     * Draws samples from the distribution.
     *
     * <p>This function expands the dimension of the arguments; the first dimension of the output
     * will be {@code numSamples}.
     *
     * @param numSamples number of samples to be drawn
     * @return a {@link NDArray} of shape (num_samples, *batch_shape, *target_shape)
     */
    public abstract NDArray sample(int numSamples);

    /**
     * Draws a single sample from the distribution.
     *
     * <p>This function does not expand the dimension: it delegates to {@link #sample(int)} with
     * {@code numSamples == 0}, which implementations treat as "no extra sample axis".
     *
     * @return a sampled {@link NDArray}
     */
    public NDArray sample() {
        return sample(0);
    }

    /**
     * Returns the mean of the distribution.
     *
     * @return the mean of the distribution
     */
    public abstract NDArray mean();

    /**
     * A builder to extend for all classes extending {@link Distribution}.
     *
     * <p>Concrete distributions read their parameters from {@code distrArgs} (arrays named after
     * the parameter) and may apply the optional affine {@code loc}/{@code scale}.
     *
     * @param <T> the concrete builder type
     */
    public abstract static class DistributionBuilder<T extends DistributionBuilder<T>> {
        protected NDList distrArgs;
        protected NDArray scale;
        protected NDArray loc;

        /**
         * Sets the appropriate arguments for the probability distribution.
         *
         * @param distrArgs a {@link NDList} containing distribution args named after the parameter
         *     name
         * @return this builder
         */
        public T setDistrArgs(NDList distrArgs) {
            this.distrArgs = distrArgs;
            return self();
        }

        /**
         * Sets the affine scale for the probability distribution.
         *
         * @param scale the affine scale
         * @return this builder
         */
        public T optScale(NDArray scale) {
            this.scale = scale;
            return self();
        }

        /**
         * Sets the affine location of the probability distribution.
         *
         * @param loc the affine location
         * @return this builder
         */
        public T optLoc(NDArray loc) {
            this.loc = loc;
            return self();
        }

        /**
         * Builds a {@code Distribution}.
         *
         * @return the {@code Distribution}
         */
        public abstract Distribution build();

        /**
         * Returns this builder typed as the concrete builder subclass.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import ai.djl.training.loss.Loss;

/**
* {@code DistributionLoss} calculates loss for a given distribution.
*
* <p>Distribution Loss is calculated by {@link Distribution#logProb(NDArray)} at label point
*/
public class DistributionLoss extends Loss {

    private DistributionOutput distrOutput;

    /**
     * Calculates Distribution Loss between the label and distribution.
     *
     * @param name the name of the loss
     * @param distrOutput the {@link DistributionOutput} to construct the target distribution
     */
    public DistributionLoss(String name, DistributionOutput distrOutput) {
        super(name);
        this.distrOutput = distrOutput;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        // Assemble the predicted distribution from the named arrays in predictions
        Distribution.DistributionBuilder<?> builder = distrOutput.distributionBuilder();
        builder.setDistrArgs(predictions);
        if (predictions.contains("scale")) {
            builder.optScale(predictions.get("scale"));
        }
        if (predictions.contains("loc")) {
            builder.optLoc(predictions.get("loc"));
        }

        // Negative log-likelihood of the labels under the predicted distribution
        NDArray negLogProb = builder.build().logProb(labels.singletonOrThrow()).neg();
        if (!predictions.contains("loss_weights")) {
            return negLogProb;
        }

        // Weighted average: zero-weight entries are masked out, and the denominator is
        // clamped to at least 1 to avoid division by zero when all weights are zero.
        NDArray weights = predictions.get("loss_weights");
        NDArray masked =
                NDArrays.where(weights.neq(0), negLogProb.mul(weights), negLogProb.zerosLike());
        return masked.sum().div(weights.sum().maximum(1.));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.util.Preconditions;

/**
* Negative binomial distribution.
*
* <p>The distribution of the number of successes in a sequence of independent Bernoulli trials.
*
* <p>Two arguments for this distribution. {@code mu} mean of the distribution, {@code alpha} the
* inverse number of negative Bernoulli trials to stop
*/
/**
 * Negative binomial distribution.
 *
 * <p>The distribution of the number of successes in a sequence of independent Bernoulli trials.
 *
 * <p>Two arguments for this distribution. {@code mu} mean of the distribution, {@code alpha} the
 * inverse number of negative Bernoulli trials to stop
 */
public final class NegativeBinomial extends Distribution {

    // mu: mean of the distribution; alpha: dispersion (r = 1/alpha, p = alpha*mu/(1+alpha*mu))
    private NDArray mu;
    private NDArray alpha;

    NegativeBinomial(Builder builder) {
        mu = builder.distrArgs.get("mu");
        alpha = builder.distrArgs.get("alpha");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray logProb(NDArray target) {
        // log P(X = x) for the NB pmf parameterized by mean mu and dispersion alpha:
        //   log P = x*log(p) + r*log(1-p) + lgamma(x+r) - lgamma(x+1) - lgamma(r)
        // with r = 1/alpha and p = alpha*mu / (1 + alpha*mu); the terms below map 1:1.
        NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1);
        NDArray alphaTimesMu = alpha.mul(mu);

        return target.mul(alphaTimesMu.div(alphaTimesMu.add(1)).log())
                .sub(alphaInv.mul(alphaTimesMu.add(1).log()))
                .add(target.add(alphaInv).gammaln())
                .sub(target.add(1.).gammaln())
                .sub(alphaInv.gammaln());
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sample(int numSamples) {
        NDManager manager = mu.getManager();
        // numSamples == 0 means "no extra sample axis" (see Distribution.sample());
        // otherwise prepend a dimension of size numSamples to the parameter arrays.
        NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu;
        NDArray expandedAlpha = numSamples > 0 ? alpha.expandDims(0).repeat(0, numSamples) : alpha;

        // Gamma-Poisson mixture: lambda ~ Gamma(r, theta), X ~ Poisson(lambda) yields NB.
        // With r = 1/alpha and theta = alpha*mu the mean is r*theta = mu.
        // NOTE(review): assumes sampleGamma takes (shape, scale) — confirm against NDManager docs.
        NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f);
        NDArray theta = expandedAlpha.mul(expandedMu);
        return manager.samplePoisson(manager.sampleGamma(r, theta));
    }

    /** {@inheritDoc} */
    @Override
    public NDArray mean() {
        return mu;
    }

    /**
     * Creates a builder to build a {@code NegativeBinomial}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** The builder to construct a {@code NegativeBinomial}. */
    public static final class Builder extends DistributionBuilder<Builder> {

        /**
         * {@inheritDoc}
         *
         * <p>Requires {@code distrArgs} to contain "mu" and "alpha". An affine {@code loc} is not
         * supported and is silently ignored here; {@code scale}, when set, is folded into the mean
         * parameter instead of applied as a transformation.
         */
        @Override
        public Distribution build() {
            Preconditions.checkArgument(
                    distrArgs.contains("mu"), "NegativeBinomial's args must contain mu.");
            Preconditions.checkArgument(
                    distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha.");
            // We cannot scale using the affine transformation since negative binomial should return
            // integers. Instead we scale the parameters: replace mu with mu * scale in distrArgs.
            if (scale != null) {
                NDArray mu = distrArgs.get("mu");
                mu = mu.mul(scale);
                mu.setName("mu"); // keep the name so get("mu") still resolves after re-adding
                distrArgs.remove("mu");
                distrArgs.add(mu);
            }
            return new NegativeBinomial(this);
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }
    }
}
Loading

0 comments on commit a4c8a85

Please sign in to comment.