Commit: tunable pretrained layer

KexinFeng committed Oct 8, 2022
1 parent 4178741 commit 80f0e49
Showing 11 changed files with 146 additions and 27 deletions.
18 changes: 9 additions & 9 deletions api/src/main/java/ai/djl/training/LocalParameterServer.java
@@ -48,17 +48,17 @@ public void update(String parameterId, NDArray[] grads, NDArray[] params) {
         // use duplicate because after the first optimizer.update
         // PyTorch optimizer will zero grads[0]
         // the second copy is to move the grads[0] to the device the weight is on
-        try (NDArray aggregatedGrad = grads[0].duplicate()) {
-            for (NDArray param : params) {
-                if (param.getDevice().equals(firstDevice)) {
-                    optimizer.update(parameterId, param, aggregatedGrad);
-                } else {
-                    try (NDArray gradSumCopy = aggregatedGrad.toDevice(param.getDevice(), true)) {
-                        optimizer.update(parameterId, param, gradSumCopy);
-                    }
-                }
+        NDArray aggregatedGrad = grads[0].duplicate();
+        for (NDArray param : params) {
+            if (param.getDevice().equals(firstDevice)) {
+                optimizer.update(parameterId, param, aggregatedGrad);
+            } else {
+                NDArray gradSumCopy = aggregatedGrad.toDevice(param.getDevice(), true);
+                optimizer.update(parameterId, param, gradSumCopy);
+                gradSumCopy.close();
             }
         }
+        aggregatedGrad.close();
     }

/** {@inheritDoc} */
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/training/evaluator/Accuracy.java
@@ -61,7 +61,8 @@ protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions)
             predictionReduced = predictionReduced.reshape(label.getShape());
         } else {
             // Multi-class, one-hot label
-            predictionReduced = prediction;
+            predictionReduced = prediction.argMax(axis);
+            label = label.argMax(axis);
         }
         // result of sum is int64 now
         long total = label.size();
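With this change, a one-hot label and the raw class scores are both reduced to class indices along `axis` before comparison. A minimal illustration of the reduction (values and the `manager` NDManager are assumed, not part of this commit):

    NDArray label = manager.create(new float[][] {{0f, 1f}, {1f, 0f}}); // one-hot, shape (2, 2)
    NDArray prediction = manager.create(new float[][] {{0.2f, 0.8f}, {0.9f, 0.1f}});
    NDArray labelIdx = label.argMax(1); // [1, 0]
    NDArray predictionIdx = prediction.argMax(1); // [1, 0]
    long numCorrect = predictionIdx.eq(labelIdx).countNonzero().getLong(); // 2
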
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/training/optimizer/Adagrad.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.Tracker;

import java.util.Map;
@@ -60,6 +61,9 @@ protected Adagrad(Builder builder) {
/** {@inheritDoc} */
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
if (learningRateTracker instanceof FixedPerVarTracker) {
((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
}
int t = updateCount(parameterId);
float newLearningRate = learningRateTracker.getNewValue(t);
float weightDecay = getWeightDecay();
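The same three-line guard is added to each optimizer below (Adam, Nag, RmsProp, Sgd) and to the test optimizer near the end of this commit. The reason: Tracker.getNewValue(int) only receives an update count, so a per-variable tracker has to be told out of band which parameter is about to be updated. A minimal sketch of the pattern in a custom optimizer (field names as in the built-in optimizers; the plain-SGD step is illustrative only):

    /** {@inheritDoc} */
    @Override
    public void update(String parameterId, NDArray weight, NDArray grad) {
        // Inform a per-variable tracker which parameter is being updated;
        // other tracker types ignore the parameter id.
        if (learningRateTracker instanceof FixedPerVarTracker) {
            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
        }
        float learningRate = learningRateTracker.getNewValue(updateCount(parameterId));
        weight.subi(grad.mul(learningRate)); // illustrative plain-SGD step
    }
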
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/training/optimizer/Adam.java
@@ -16,6 +16,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.util.Preconditions;

@@ -64,6 +65,9 @@ protected Adam(Builder builder) {
/** {@inheritDoc} */
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
if (learningRateTracker instanceof FixedPerVarTracker) {
((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
}
int t = updateCount(parameterId);
double coef1 = 1.0 - Math.pow(beta1, t);
double coef2 = 1.0 - Math.pow(beta2, t);
@@ -118,6 +122,8 @@ public static Builder builder() {
public static final class Builder extends OptimizerBuilder<Builder> {

private Tracker learningRateTracker = Tracker.fixed(0.001f);
// Alternatively, a per-parameter learning rate can be supplied via a
// FixedPerVarTracker built through FixedPerVarTracker.builder().
private float beta1 = 0.9f;
private float beta2 = 0.999f;
private float epsilon = 1e-8f;
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/training/optimizer/Nag.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.util.Preconditions;

@@ -52,6 +53,9 @@ protected Nag(Builder builder) {
/** {@inheritDoc} */
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
if (learningRateTracker instanceof FixedPerVarTracker) {
((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
}
// TODO: Support Mixed precision Sparse
float newLearningRate = learningRateTracker.getNewValue(updateCount(parameterId));
float weightDecay = getWeightDecay();
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/training/optimizer/RmsProp.java
@@ -16,6 +16,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.Tracker;

import java.util.Map;
@@ -85,6 +86,9 @@ protected RmsProp(Builder builder) {
/** {@inheritDoc} */
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
if (learningRateTracker instanceof FixedPerVarTracker) {
((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
}
float newLearningRate = learningRateTracker.getNewValue(updateCount(parameterId));
float weightDecay = getWeightDecay();

4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/training/optimizer/Sgd.java
@@ -16,6 +16,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.Tracker;

import java.util.Map;
@@ -56,6 +57,9 @@ protected Sgd(Builder builder) {
/** {@inheritDoc} */
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
if (learningRateTracker instanceof FixedPerVarTracker) {
((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
}
// TODO: Support Mixed precision Sparse
float weightDecay = getWeightDecay();
float learningRate = learningRateTracker.getNewValue(updateCount(parameterId));
90 changes: 90 additions & 0 deletions api/src/main/java/ai/djl/training/tracker/FixedPerVarTracker.java
@@ -0,0 +1,90 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.training.tracker;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* {@link FixedPerVarTracker} is an implementation of {@link Tracker} which returns a fixed value
* per parameter, keyed by parameter id.
*
* @see Tracker
*/
public class FixedPerVarTracker implements Tracker {

private float value;
private Map<String, Float> valueMap;
private String parameterId;

/**
* Creates a new instance of {@link FixedPerVarTracker}.
*
* @param builder the builder used to build this object
*/
public FixedPerVarTracker(Builder builder) {
this.value = builder.value;
this.valueMap = builder.valueMap;
}

/** {@inheritDoc} */
@Override
public float getNewValue(int numUpdate) {
return valueMap.getOrDefault(this.parameterId, this.value);
}

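/**
 * Sets the id of the parameter whose learning rate will be returned by the next call to
 * {@link #getNewValue(int)}.
 *
 * @param parameterId the parameter id
 */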
public void setParameterId(String parameterId) {
this.parameterId = parameterId;
}

/**
* Creates a builder to build a {@link FixedPerVarTracker}.
*
* @return a new builder
*/
public static Builder builder() {
return new Builder();
}

/** The Builder to construct a {@link FixedPerVarTracker} object. */
public static final class Builder {

private float value;
private Map<String, Float> valueMap = new ConcurrentHashMap<>();

private Builder() {}

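/**
 * Sets the default value to use for parameters without an explicit entry.
 *
 * @param value the default value
 * @return this {@code Builder}
 */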
public Builder setDefaultValue(float value) {
this.value = value;
return this;
}

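/**
 * Sets a fixed value for a single parameter.
 *
 * @param parameterId the parameter id
 * @param value the fixed value for this parameter
 * @return this {@code Builder}
 */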
public Builder put(String parameterId, float value) {
this.valueMap.put(parameterId, value);
return this;
}

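/**
 * Adds fixed values for multiple parameters.
 *
 * @param valueMap a map from parameter id to fixed value
 * @return this {@code Builder}
 */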
public Builder putAll(Map<String, Float> valueMap) {
this.valueMap.putAll(valueMap);
return this;
}

/**
* Builds a {@link FixedPerVarTracker} instance.
*
* @return the new {@link FixedPerVarTracker}
*/
public FixedPerVarTracker build() {
return new FixedPerVarTracker(this);
}
}
}
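
For reference, a minimal usage sketch of the new tracker, mirroring the example change further down in this commit (imports elided; `baseBlock` is assumed to be the pretrained block and `config` a DefaultTrainingConfig):

    // Small fixed rate for pretrained parameters, default rate for everything else.
    FixedPerVarTracker.Builder builder = FixedPerVarTracker.builder().setDefaultValue(0.001f);
    for (Pair<String, Parameter> paramPair : baseBlock.getParameters()) {
        builder.put(paramPair.getValue().getId(), 0.0001f);
    }
    Optimizer optimizer = Adam.builder().optLearningRateTracker(builder.build()).build();
    config.optOptimizer(optimizer);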
@@ -23,6 +23,7 @@
import ai.djl.testing.TestRequirements;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.Tracker;

import org.testng.Assert;
@@ -110,6 +111,9 @@ protected TestOptimizer(TestOptimizer.Builder builder) {
/** {@inheritDoc} */
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
if (learningRateTracker instanceof FixedPerVarTracker) {
((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
}
weight.addi(
grad.mul(learningRateTracker.getNewValue(0))
.toDevice(weight.getDevice(), false));
1 change: 1 addition & 0 deletions examples/build.gradle
@@ -8,6 +8,7 @@ dependencies {
implementation project(":basicdataset")
implementation project(":model-zoo")
implementation project(":extensions:timeseries")
implementation 'org.testng:testng:7.1.0'

runtimeOnly project(":engines:pytorch:pytorch-model-zoo")
runtimeOnly project(":engines:tensorflow:tensorflow-model-zoo")
@@ -23,6 +23,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.Repository;
@@ -36,8 +37,12 @@
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;

import java.io.IOException;
import java.net.URISyntaxException;
@@ -67,7 +72,7 @@ public static TrainingResult runExample(String[] args)
.optModelUrls(modelUrls)
.optEngine(Engine.getDefaultEngineName())
.optProgress(new ProgressBar())
.optOption("retrain", "0")
.optOption("retrain", "1")
.build();

ZooModel<NDList, NDList> embedding = criteria.loadModel();
@@ -86,16 +91,15 @@ public static TrainingResult runExample(String[] args)
// Config trainer
DefaultTrainingConfig config = setupTrainingConfig(arguments);

-        // /// Customized learning rate
-        // FixedPerVarTracker.Builder learningRateTrackerBuilder =
-        //         FixedPerVarTracker.builder().setDefaultValue(0.001f);
-        // for (Pair<String, Parameter> paramPair : baseBlock.getParameters()) {
-        //     learningRateTrackerBuilder.put(paramPair.getValue().getId(), 0.0001f);
-        // }
-        // Optimizer optimizer =
-        //         Adam.builder().optLearningRateTracker(learningRateTrackerBuilder.build()).build();
-        // config.optOptimizer(optimizer);
+        // Customized learning rate
+        FixedPerVarTracker.Builder learningRateTrackerBuilder =
+                FixedPerVarTracker.builder().setDefaultValue(0.001f);
+        for (Pair<String, Parameter> paramPair : baseBlock.getParameters()) {
+            learningRateTrackerBuilder.put(paramPair.getValue().getId(), 0.0001f);
+        }
+        Optimizer optimizer =
+                Adam.builder().optLearningRateTracker(learningRateTrackerBuilder.build()).build();
+        config.optOptimizer(optimizer);

Trainer trainer = model.newTrainer(config);
trainer.setMetrics(new Metrics());
@@ -109,10 +113,7 @@ public static TrainingResult runExample(String[] args)
// Data
String folderUrl = "/Users/fenkexin/Desktop/transferDJL/code/data/banana";
String subfolder = "/test/";

// set the image folder path
Repository repository = Repository.newInstance("banana", Paths.get(folderUrl + subfolder));

ImageFolder datasetTrain =
ImageFolder.builder()
.setRepository(repository)
@@ -121,13 +122,13 @@
// .addTargetTransform(new ToOneHot(2))
.setSampling(batchSize, true)
.build();

// call prepare before using
datasetTrain.prepare();

// train
EasyTrain.fit(trainer, 50, datasetTrain, null);

// model.save("your-model-path"); // save the model
// Save model
// model.save("your-model-path");

return null;
}
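A natural variation of the customized-learning-rate block above, not part of this commit: since every optimizer scales its step by the tracked rate, giving the pretrained parameters a rate of 0f freezes them outright while the newly added layers still train.

    // Hypothetical tweak: freeze the pretrained base instead of fine-tuning it.
    learningRateTrackerBuilder.put(paramPair.getValue().getId(), 0f);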
