Commit 56d4e13: suggested editions

KexinFeng committed Oct 26, 2022
1 parent 77bec10 commit 56d4e13
Showing 9 changed files with 21 additions and 34 deletions.
5 changes: 1 addition & 4 deletions api/src/main/java/ai/djl/training/optimizer/Adagrad.java
@@ -17,7 +17,6 @@
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.internal.NDArrayEx;
 import ai.djl.ndarray.types.SparseFormat;
-import ai.djl.training.tracker.FixedPerVarTracker;
 import ai.djl.training.tracker.Tracker;
 
 import java.util.Map;
@@ -61,9 +60,7 @@ protected Adagrad(Builder builder) {
     /** {@inheritDoc} */
     @Override
     public void update(String parameterId, NDArray weight, NDArray grad) {
-        if (learningRateTracker instanceof FixedPerVarTracker) {
-            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
-        }
+        learningRateTracker.setParameterId(parameterId);
         int t = updateCount(parameterId);
         float newLearningRate = learningRateTracker.getNewValue(t);
         float weightDecay = getWeightDecay();
5 changes: 1 addition & 4 deletions api/src/main/java/ai/djl/training/optimizer/Adam.java
@@ -16,7 +16,6 @@
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.internal.NDArrayEx;
-import ai.djl.training.tracker.FixedPerVarTracker;
 import ai.djl.training.tracker.Tracker;
 import ai.djl.util.Preconditions;
 
@@ -65,9 +64,7 @@ protected Adam(Builder builder) {
     /** {@inheritDoc} */
     @Override
     public void update(String parameterId, NDArray weight, NDArray grad) {
-        if (learningRateTracker instanceof FixedPerVarTracker) {
-            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
-        }
+        learningRateTracker.setParameterId(parameterId);
         int t = updateCount(parameterId);
         double coef1 = 1.0 - Math.pow(beta1, t);
         double coef2 = 1.0 - Math.pow(beta2, t);
5 changes: 1 addition & 4 deletions api/src/main/java/ai/djl/training/optimizer/Nag.java
@@ -17,7 +17,6 @@
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.internal.NDArrayEx;
-import ai.djl.training.tracker.FixedPerVarTracker;
 import ai.djl.training.tracker.Tracker;
 import ai.djl.util.Preconditions;
 
@@ -53,9 +52,7 @@ protected Nag(Builder builder) {
     /** {@inheritDoc} */
     @Override
     public void update(String parameterId, NDArray weight, NDArray grad) {
-        if (learningRateTracker instanceof FixedPerVarTracker) {
-            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
-        }
+        learningRateTracker.setParameterId(parameterId);
         // TODO: Support Mixed precision Sparse
         float newLearningRate = learningRateTracker.getNewValue(updateCount(parameterId));
         float weightDecay = getWeightDecay();
5 changes: 1 addition & 4 deletions api/src/main/java/ai/djl/training/optimizer/RmsProp.java
@@ -16,7 +16,6 @@
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.internal.NDArrayEx;
-import ai.djl.training.tracker.FixedPerVarTracker;
 import ai.djl.training.tracker.Tracker;
 
 import java.util.Map;
@@ -86,9 +85,7 @@ protected RmsProp(Builder builder) {
     /** {@inheritDoc} */
     @Override
     public void update(String parameterId, NDArray weight, NDArray grad) {
-        if (learningRateTracker instanceof FixedPerVarTracker) {
-            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
-        }
+        learningRateTracker.setParameterId(parameterId);
         float newLearningRate = learningRateTracker.getNewValue(updateCount(parameterId));
         float weightDecay = getWeightDecay();
 
5 changes: 1 addition & 4 deletions api/src/main/java/ai/djl/training/optimizer/Sgd.java
@@ -16,7 +16,6 @@
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.internal.NDArrayEx;
-import ai.djl.training.tracker.FixedPerVarTracker;
 import ai.djl.training.tracker.Tracker;
 
 import java.util.Map;
@@ -57,9 +56,7 @@ protected Sgd(Builder builder) {
     /** {@inheritDoc} */
     @Override
     public void update(String parameterId, NDArray weight, NDArray grad) {
-        if (learningRateTracker instanceof FixedPerVarTracker) {
-            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
-        }
+        learningRateTracker.setParameterId(parameterId);
         // TODO: Support Mixed precision Sparse
         float weightDecay = getWeightDecay();
         float learningRate = learningRateTracker.getNewValue(updateCount(parameterId));
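All five optimizers above now make the same unconditional call, so any per-parameter behavior is decided entirely by the tracker that is plugged in; the instanceof casts become unnecessary. Below is a minimal sketch of the consumer side. It assumes FixedPerVarTracker exposes a builder with setDefaultValue and put, and that Adam.Builder accepts the tracker via optLearningRateTracker; these names follow DJL's transfer-learning example rather than this diff, and baseBlock stands for an already-loaded pretrained Block.

import ai.djl.nn.Parameter;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.util.Pair;

// baseBlock: an already-loaded pretrained Block (assumed for this sketch).
float baseLr = 1e-3f;
FixedPerVarTracker.Builder trackerBuilder =
        FixedPerVarTracker.builder().setDefaultValue(baseLr);
// Train the pretrained parameters ten times slower than the new head.
for (Pair<String, Parameter> pair : baseBlock.getParameters()) {
    trackerBuilder.put(pair.getValue().getId(), 0.1f * baseLr);
}
Tracker learningRateTracker = trackerBuilder.build();

// update(parameterId, ...) now forwards the id to the tracker, so each
// parameter gets its own learning rate without any instanceof checks.
Optimizer optimizer = Adam.builder().optLearningRateTracker(learningRateTracker).build();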
8 changes: 8 additions & 0 deletions api/src/main/java/ai/djl/training/tracker/Tracker.java
@@ -29,6 +29,14 @@ public interface Tracker {
      */
     float getNewValue(int numUpdate);
 
+    /**
+     * Set parameterId for this Tracker.
+     *
+     * @param parameterId the parameter Id
+     */
+    default void setParameterId(String parameterId) {}
+    ;
+
     /**
      * Returns a new instance of {@link ai.djl.training.tracker.FactorTracker.Builder} that can
      * build an {@link FactorTracker}.
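The default implementation is a no-op, so existing trackers are unaffected; only a tracker that keys its value on a parameter, such as FixedPerVarTracker, needs to override it. The class below is a hypothetical tracker invented for this sketch (it is not part of DJL) that shows the pattern the new default method enables.

import ai.djl.training.tracker.Tracker;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical tracker: returns a fixed value chosen per parameter id. */
public class PerParameterTracker implements Tracker {

    private final Map<String, Float> values = new ConcurrentHashMap<>();
    private final float defaultValue;
    private String currentParameterId;

    public PerParameterTracker(float defaultValue) {
        this.defaultValue = defaultValue;
    }

    public void put(String parameterId, float value) {
        values.put(parameterId, value);
    }

    /** Remembers which parameter the optimizer is about to update. */
    @Override
    public void setParameterId(String parameterId) {
        this.currentParameterId = parameterId;
    }

    /** Looks up the value for the current parameter, falling back to the default. */
    @Override
    public float getNewValue(int numUpdate) {
        return values.getOrDefault(currentParameterId, defaultValue);
    }
}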
@@ -23,7 +23,6 @@
 import ai.djl.testing.TestRequirements;
 import ai.djl.training.ParameterServer;
 import ai.djl.training.optimizer.Optimizer;
-import ai.djl.training.tracker.FixedPerVarTracker;
 import ai.djl.training.tracker.Tracker;
 
 import org.testng.Assert;
@@ -111,9 +110,7 @@ protected TestOptimizer(TestOptimizer.Builder builder) {
         /** {@inheritDoc} */
         @Override
         public void update(String parameterId, NDArray weight, NDArray grad) {
-            if (learningRateTracker instanceof FixedPerVarTracker) {
-                ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
-            }
+            learningRateTracker.setParameterId(parameterId);
             weight.addi(
                     grad.mul(learningRateTracker.getNewValue(0))
                             .toDevice(weight.getDevice(), false));
@@ -86,11 +86,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
                 extraFileValues = new String[extraFileKeys.length];
             }
             if (options.containsKey("retrain")) {
-                String value = (String) options.get("retrain");
-                retrain =
-                        "1".equalsIgnoreCase(value)
-                                || "true".equalsIgnoreCase(value)
-                                || "on".equalsIgnoreCase(value);
+                retrain = Boolean.parseBoolean((String) options.get("retrain"));
             }
             mapLocation = Boolean.parseBoolean((String) options.get("mapLocation"));
         }
@@ -105,10 +101,11 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
             for (int i = 0; i < extraFileKeys.length; i++) {
                 properties.put(extraFileKeys[i], extraFileValues[i]);
             }
-            // Freeze the parameters if not retrain
-            for (Pair<String, Parameter> paramPair : block.getParameters()) {
-                paramPair.getValue().freeze(!retrain);
-            }
+            // By default the parameters are frozen: before the `retrain` option existed they
+            // were effectively frozen anyway by `JITCallGuard guard`, which disables autograd,
+            // and pretrained parameters usually should not be updated too much. Users may
+            // unfreeze them and use a small learning rate.
+            block.freezeParameters(!retrain);
         } else {
             boolean hasParameter = true;
             if (options != null) {
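On the user side, retraining means unfreezing parameters again, either all of them or only selected ones. A minimal sketch using the Block.freezeParameters and Parameter.freeze calls that appear in this diff; the name-based selection and the "fc" filter are assumptions for illustration only.

import ai.djl.nn.Parameter;
import ai.djl.util.Pair;

// Unfreeze everything for full fine-tuning ...
baseBlock.freezeParameters(false);

// ... or keep the backbone frozen and unfreeze only selected parameters
// (selecting by name is an assumption for this sketch).
for (Pair<String, Parameter> pair : baseBlock.getParameters()) {
    if (pair.getKey().contains("fc")) { // e.g. a final fully connected layer
        pair.getValue().freeze(false);
    }
}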
@@ -81,7 +81,7 @@ public static TrainingResult runExample(String[] args)
                         .optProgress(new ProgressBar())
                         // Here the argument "pretrained" is borrowed.
                         // Pretrained means no need to retrain, and vice versa.
-                        .optOption("retrain", arguments.isPreTrained() ? "0" : "1")
+                        .optOption("retrain", arguments.isPreTrained() ? "false" : "true")
                         .build();
 
         ZooModel<NDList, NDList> embedding = criteria.loadModel();
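The switch from "0"/"1" to "false"/"true" matches the loader change above: Boolean.parseBoolean returns true only for the string "true" (case-insensitive), so the old "1" and "on" values would now silently disable retraining. For example:

Boolean.parseBoolean("true");  // true
Boolean.parseBoolean("TRUE");  // true (case-insensitive)
Boolean.parseBoolean("1");     // false
Boolean.parseBoolean("on");    // false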
