[timeseries] add probability distribution support for timeseries (#2025)

* feature: add distribution and loss function
* bug fix and add unit test
* add comments and unit test
* add copyright and package-info
* style fix
* add args array
* feature: add affinely transformed distribution
* feature: add sample for zero num
* bug fix: sum the loss
* bug fix: add builder for StudentTOutput and clamp the min value in neg_binomial
* Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
* Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
* Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java
* format

Co-authored-by: Carkham <1302112560@qq.com>
Co-authored-by: KexinFeng <fengx463@umn.edu>
1 parent 767ac22 · commit a4c8a85

Showing 13 changed files with 915 additions and 0 deletions.
extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java (74 additions, 0 deletions)
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;

/** Represents the distribution of an affinely transformed random variable. */
public class AffineTransformed extends Distribution {

    private Distribution baseDistribution;
    private NDArray loc;
    private NDArray scale;

    /**
     * Constructs a new {@code AffineTransformed}.
     *
     * <p>This is the distribution of Y = scale * X + loc, where X is a random variable distributed
     * according to {@code baseDistribution}.
     *
     * @param baseDistribution original distribution
     * @param loc translation parameter of the affine transformation
     * @param scale scaling parameter of the affine transformation
     */
    public AffineTransformed(Distribution baseDistribution, NDArray loc, NDArray scale) {
        this.baseDistribution = baseDistribution;
        this.loc = loc == null ? baseDistribution.mean().zerosLike() : loc;
        this.scale = scale == null ? baseDistribution.mean().onesLike() : scale;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray logProb(NDArray target) {
        // change of variables: logProb_Y(y) = logProb_X(fInv(y)) - log|det Jacobian of f|
        NDArray x = fInv(target);
        NDArray ladj = logAbsDetJac(x);
        NDArray lp = ladj.mul(-1);
        return baseDistribution.logProb(x).add(lp);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sample(int numSamples) {
        NDArray sample = baseDistribution.sample(numSamples);
        return f(sample);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray mean() {
        return baseDistribution.mean().mul(scale).add(loc);
    }

    /** The forward transformation f(x) = scale * x + loc. */
    private NDArray f(NDArray x) {
        return x.mul(scale).add(loc);
    }

    /** The inverse transformation fInv(y) = (y - loc) / scale. */
    private NDArray fInv(NDArray y) {
        return y.sub(loc).div(scale);
    }

    /** The log of the absolute determinant of the Jacobian of f, i.e. log|scale|. */
    private NDArray logAbsDetJac(NDArray x) {
        return scale.broadcast(x.getShape()).abs().log();
    }
}
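The sketch below is not part of the commit; it is a minimal illustration of how AffineTransformed composes with a base distribution built through the builder API added in this diff. The class name, the concrete values, and the use of NegativeBinomial as the base are assumptions made only for the example, chosen because that builder appears later on this page.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.timeseries.distribution.AffineTransformed;
import ai.djl.timeseries.distribution.Distribution;
import ai.djl.timeseries.distribution.NegativeBinomial;

public class AffineTransformedExample {

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Distribution arguments must be named after the parameter they carry.
            NDArray mu = manager.create(new float[] {5f, 8f});
            mu.setName("mu");
            NDArray alpha = manager.create(new float[] {0.5f, 0.2f});
            alpha.setName("alpha");
            Distribution base =
                    NegativeBinomial.builder().setDistrArgs(new NDList(mu, alpha)).build();

            // Y = scale * X + loc
            NDArray loc = manager.create(new float[] {10f, 10f});
            NDArray scale = manager.create(new float[] {2f, 2f});
            Distribution y = new AffineTransformed(base, loc, scale);

            System.out.println(y.mean());       // base.mean() * scale + loc = {20, 26}
            System.out.println(y.sample(100));  // shape (100, 2): f applied to base samples
        }
    }
}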
extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java (113 additions, 0 deletions)
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** An abstract class representing a probability distribution. */
public abstract class Distribution {

    /**
     * Computes the log of the probability density/mass function evaluated at target.
     *
     * @param target {@link NDArray} of shape (*batch_shape, *event_shape)
     * @return an {@link NDArray} of shape (batch_shape) containing the log-probability density for
     *     each event in target
     */
    public abstract NDArray logProb(NDArray target);

    /**
     * Draws samples from the distribution.
     *
     * <p>This function expands the dimensions of the arguments; the first dimension of the output
     * is numSamples.
     *
     * @param numSamples the number of samples to be drawn
     * @return an {@link NDArray} of shape (num_samples, *batch_shape, *target_shape)
     */
    public abstract NDArray sample(int numSamples);

    /**
     * Draws a sample from the distribution.
     *
     * <p>This function does not expand the dimensions.
     *
     * @return a sampled {@link NDArray}
     */
    public NDArray sample() {
        return sample(0);
    }

    /**
     * Returns the mean of the distribution.
     *
     * @return the mean of the distribution
     */
    public abstract NDArray mean();

    /**
     * A builder to extend for all classes that extend {@link Distribution}.
     *
     * @param <T> the concrete builder type
     */
    public abstract static class DistributionBuilder<T extends DistributionBuilder<T>> {
        protected NDList distrArgs;
        protected NDArray scale;
        protected NDArray loc;

        /**
         * Sets the arguments of the probability distribution.
         *
         * @param distrArgs an {@link NDList} containing the distribution arguments, each named
         *     after the parameter it carries
         * @return this builder
         */
        public T setDistrArgs(NDList distrArgs) {
            this.distrArgs = distrArgs;
            return self();
        }

        /**
         * Sets the affine scale of the probability distribution.
         *
         * @param scale the affine scale
         * @return this builder
         */
        public T optScale(NDArray scale) {
            this.scale = scale;
            return self();
        }

        /**
         * Sets the affine location of the probability distribution.
         *
         * @param loc the affine location
         * @return this builder
         */
        public T optLoc(NDArray loc) {
            this.loc = loc;
            return self();
        }

        /**
         * Builds a {@code Distribution}.
         *
         * @return the {@code Distribution}
         */
        public abstract Distribution build();

        /**
         * Returns this builder as its concrete type.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}
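To make the contract of the abstract class and its nested builder concrete, here is a toy subclass that is not part of the commit: a hypothetical point-mass distribution whose single argument is named "value". It follows the same builder pattern that NegativeBinomial uses further down on this page; only the distribution logic is invented for illustration.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.timeseries.distribution.Distribution;

/** A toy point-mass distribution that puts all probability on a single array of values. */
public final class Deterministic extends Distribution {

    private NDArray value;

    Deterministic(NDArray value) {
        this.value = value;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray logProb(NDArray target) {
        // log(1) = 0 where the target hits the point mass, log(0) = -inf elsewhere
        return target.eq(value).toType(DataType.FLOAT32, false).log();
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sample(int numSamples) {
        // Mirror the convention above: numSamples == 0 means "do not expand the dimensions"
        return numSamples > 0 ? value.expandDims(0).repeat(0, numSamples) : value;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray mean() {
        return value;
    }

    public static Builder builder() {
        return new Builder();
    }

    /** The builder to construct a {@code Deterministic}. */
    public static final class Builder extends DistributionBuilder<Builder> {

        /** {@inheritDoc} */
        @Override
        public Distribution build() {
            // distrArgs is the protected NDList of named arguments declared on the parent builder
            return new Deterministic(distrArgs.get("value"));
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }
    }
}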
extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java (65 additions, 0 deletions)
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import ai.djl.training.loss.Loss;

/**
 * {@code DistributionLoss} calculates the loss for a given distribution.
 *
 * <p>The loss is the negative {@link Distribution#logProb(NDArray)} evaluated at the label points.
 */
public class DistributionLoss extends Loss {

    private DistributionOutput distrOutput;

    /**
     * Constructs a {@code DistributionLoss} between the label and the predicted distribution.
     *
     * @param name the name of the loss
     * @param distrOutput the {@link DistributionOutput} to construct the target distribution
     */
    public DistributionLoss(String name, DistributionOutput distrOutput) {
        super(name);
        this.distrOutput = distrOutput;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        Distribution.DistributionBuilder<?> builder = distrOutput.distributionBuilder();
        builder.setDistrArgs(predictions);
        if (predictions.contains("scale")) {
            builder.optScale(predictions.get("scale"));
        }
        if (predictions.contains("loc")) {
            builder.optLoc(predictions.get("loc"));
        }

        // Negative log-likelihood of the labels under the predicted distribution
        NDArray loss = builder.build().logProb(labels.singletonOrThrow()).mul(-1);

        // If loss weights are provided, return the weighted average over the non-zero weights
        if (predictions.contains("loss_weights")) {
            NDArray lossWeights = predictions.get("loss_weights");
            NDArray weightedValue =
                    NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike());
            NDArray sumWeights = lossWeights.sum().maximum(1.);
            loss = weightedValue.sum().div(sumWeights);
        }
        return loss;
    }
}
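A sketch of how this loss could be exercised directly, outside a full training loop. It assumes that NegativeBinomialOutput, one of the DistributionOutput implementations added by this commit but not shown on this page, has a no-argument constructor; that assumption is mine. In a real setup the loss would simply be passed to a training configuration such as DefaultTrainingConfig.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.timeseries.distribution.DistributionLoss;
import ai.djl.timeseries.distribution.output.NegativeBinomialOutput;
import ai.djl.training.loss.Loss;

public class DistributionLossExample {

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Network outputs: distribution arguments named after their parameters.
            NDArray mu = manager.create(new float[] {3f, 6f});
            mu.setName("mu");
            NDArray alpha = manager.create(new float[] {0.4f, 0.4f});
            alpha.setName("alpha");
            NDList predictions = new NDList(mu, alpha);

            // Observed targets.
            NDList labels = new NDList(manager.create(new float[] {2f, 7f}));

            // Assumes a no-arg constructor on NegativeBinomialOutput (file not shown above).
            Loss loss = new DistributionLoss("NegBinomialLoss", new NegativeBinomialOutput());
            NDArray value = loss.evaluate(labels, predictions); // negative log-likelihood
            System.out.println(value);
        }
    }
}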
extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java (107 additions, 0 deletions)
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.util.Preconditions;

/**
 * Negative binomial distribution.
 *
 * <p>The distribution of the number of successes in a sequence of independent Bernoulli trials
 * before a specified number of failures occurs.
 *
 * <p>The distribution takes two arguments: {@code mu}, the mean of the distribution, and {@code
 * alpha}, the inverse of the number of failures at which the trials stop. With this
 * parameterization the variance is mu + alpha * mu^2.
 */
public final class NegativeBinomial extends Distribution {

    private NDArray mu;
    private NDArray alpha;

    NegativeBinomial(Builder builder) {
        mu = builder.distrArgs.get("mu");
        alpha = builder.distrArgs.get("alpha");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray logProb(NDArray target) {
        // log pmf with r = 1 / alpha failures and success probability
        // p = alpha * mu / (1 + alpha * mu)
        NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1);
        NDArray alphaTimesMu = alpha.mul(mu);

        return target.mul(alphaTimesMu.div(alphaTimesMu.add(1)).log())
                .sub(alphaInv.mul(alphaTimesMu.add(1).log()))
                .add(target.add(alphaInv).gammaln())
                .sub(target.add(1.).gammaln())
                .sub(alphaInv.gammaln());
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sample(int numSamples) {
        NDManager manager = mu.getManager();
        NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu;
        NDArray expandedAlpha = numSamples > 0 ? alpha.expandDims(0).repeat(0, numSamples) : alpha;

        // Gamma-Poisson mixture: draw a Gamma(1 / alpha, alpha * mu) rate, then a Poisson count
        NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f);
        NDArray theta = expandedAlpha.mul(expandedMu);
        return manager.samplePoisson(manager.sampleGamma(r, theta));
    }

    /** {@inheritDoc} */
    @Override
    public NDArray mean() {
        return mu;
    }

    /**
     * Creates a builder to build a {@code NegativeBinomial}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** The builder to construct a {@code NegativeBinomial}. */
    public static final class Builder extends DistributionBuilder<Builder> {

        /** {@inheritDoc} */
        @Override
        public Distribution build() {
            Preconditions.checkArgument(
                    distrArgs.contains("mu"), "NegativeBinomial's args must contain mu.");
            Preconditions.checkArgument(
                    distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha.");
            // We cannot scale using the affine transformation since the negative binomial is
            // integer-valued. Instead we scale the mu parameter.
            if (scale != null) {
                NDArray mu = distrArgs.get("mu");
                mu = mu.mul(scale);
                mu.setName("mu");
                distrArgs.remove("mu");
                distrArgs.add(mu);
            }
            return new NegativeBinomial(this);
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }
    }
}
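A final sketch, not part of the commit, that checks the parameterization empirically: since sample() draws from a Poisson whose rate is Gamma(1/alpha, alpha*mu) distributed, the mean of the draws should be close to mu and their variance close to mu + alpha * mu^2. The class name and values are arbitrary, and the snippet assumes the active engine implements sampleGamma and samplePoisson.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.timeseries.distribution.Distribution;
import ai.djl.timeseries.distribution.NegativeBinomial;

public class NegativeBinomialExample {

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray mu = manager.create(new float[] {4f});
            mu.setName("mu");
            NDArray alpha = manager.create(new float[] {0.5f});
            alpha.setName("alpha");

            Distribution nb =
                    NegativeBinomial.builder().setDistrArgs(new NDList(mu, alpha)).build();

            NDArray draws = nb.sample(100000);           // shape (100000, 1)
            NDArray empMean = draws.mean(new int[] {0}); // approximately mu = 4
            NDArray empVar =
                    draws.sub(empMean).pow(2).mean(new int[] {0}); // approximately 4 + 0.5 * 16 = 12
            System.out.println(empMean);
            System.out.println(empVar);
        }
    }
}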