diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index 711b8ec11f1..bada2b2a8c2 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -61,6 +61,7 @@ public abstract class BaseModel implements Model { protected String modelName; protected NDManager manager; protected DataType dataType; + protected boolean wasLoaded; protected PairList inputData; protected Map artifacts = new ConcurrentHashMap<>(); protected Map properties = new ConcurrentHashMap<>(); @@ -78,6 +79,7 @@ public Block getBlock() { /** {@inheritDoc} */ @Override public void setBlock(Block block) { + wasLoaded = false; this.block = block; } diff --git a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableWordEmbedding.java b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableWordEmbedding.java index 8d4d6f3d539..582150f2293 100644 --- a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableWordEmbedding.java +++ b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableWordEmbedding.java @@ -16,6 +16,7 @@ import ai.djl.modality.nlp.Vocabulary; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.types.SparseFormat; +import ai.djl.nn.Block; import ai.djl.nn.core.Embedding; import java.nio.charset.StandardCharsets; @@ -59,30 +60,47 @@ public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize) { .optUseDefault(false)); } + private TrainableWordEmbedding(NDArray embedding, List items) { + super(embedding); + this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN); + this.vocabulary = new DefaultVocabulary(items); + } + + private TrainableWordEmbedding( + NDArray embedding, List items, SparseFormat sparseFormat) { + super(embedding, sparseFormat); + this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN); + this.vocabulary = new DefaultVocabulary(items); + } + /** * Constructs a pretrained embedding. * + *

Because it is created with preTrained data, it is created as a frozen block. If you wish + * to update it, call {@link Block#freezeParameters(boolean)}. + * * @param embedding the embedding array * @param items the items in the embedding (in matching order to the embedding array) + * @return the created embedding */ - public TrainableWordEmbedding(NDArray embedding, List items) { - super(embedding); - this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN); - this.vocabulary = new DefaultVocabulary(items); + public static TrainableWordEmbedding fromPretrained(NDArray embedding, List items) { + return new TrainableWordEmbedding(embedding, items); } /** * Constructs a pretrained embedding. * + *

Because it is created with preTrained data, it is created as a frozen block. If you wish + * to update it, call {@link Block#freezeParameters(boolean)}. + * * @param embedding the embedding array * @param items the items in the embedding (in matching order to the embedding array) * @param sparseFormat whether to compute row sparse gradient in the backward calculation + * @return the created embedding */ - public TrainableWordEmbedding( + public static TrainableWordEmbedding fromPretrained( NDArray embedding, List items, SparseFormat sparseFormat) { - super(embedding, sparseFormat); - this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN); - this.vocabulary = new DefaultVocabulary(items); + return new TrainableWordEmbedding(embedding, items, sparseFormat); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index b81d795296e..3d58d501293 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -100,6 +100,10 @@ * the javadoc at {@link ai.djl.training.Trainer}. At the end of training, a block represents a * fully-trained model. * + *

It is also possible to freeze parameters and blocks to avoid them being trained. When loading + * models or building blocks with preTrained data, they default to being frozen. If you wish to + * further refine these elements, use {@link Block#freezeParameters(boolean)} to unfreeze them. + * * @see this * tutorial on creating your first network diff --git a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java index 957276571c7..bcfde720d3b 100644 --- a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java +++ b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java @@ -33,10 +33,14 @@ public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedE /** * Constructs a constant embedding with the given constant. * + *

The constant is assumed to be a fixed value, and starts out as frozen. To unfreeze, use + * {@link ai.djl.nn.Block#freezeParameters(boolean)}. + * * @param embedding the value to return for all embeddings */ public ConstantEmbedding(NDArray embedding) { this.embedding = embedding; + freezeParameters(true); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index 7c336ce8e82..d6a937fe9a0 100644 --- a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -78,17 +78,20 @@ protected Embedding(BaseBuilder baseBuilder) { * * @param embedding the embedding array */ - public Embedding(NDArray embedding) { + protected Embedding(NDArray embedding) { this(embedding, SparseFormat.DENSE); } /** * Constructs a pretrained embedding. * + *

Because it is created with preTrained data, it is created as a frozen block. If you wish + * to update it, call {@link Block#freezeParameters(boolean)}. + * * @param embedding the embedding array * @param format whether to compute row sparse gradient in the backward calculation */ - public Embedding(NDArray embedding, SparseFormat format) { + protected Embedding(NDArray embedding, SparseFormat format) { super(VERSION); numEmbeddings = Math.toIntExact(embedding.getShape().get(0)); embeddingSize = Math.toIntExact(embedding.getShape().get(1)); @@ -101,6 +104,7 @@ public Embedding(NDArray embedding, SparseFormat format) { .build()); this.embedding.setArray(embedding); inputShapes = new Shape[] {new Shape(-1)}; + freezeParameters(true); } /** {@inheritDoc} */ diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java index 04c6f4fb5d5..c3755365325 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmModel.java @@ -44,6 +44,7 @@ public class LgbmModel extends BaseModel { @Override public void load(Path modelPath, String prefix, Map options) throws IOException { setModelDir(modelPath); + wasLoaded = true; if (block != null) { throw new UnsupportedOperationException("LightGBM does not support dynamic blocks"); } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java index 3b19c1426e2..bf41acb9b6c 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java @@ -46,6 +46,7 @@ public class XgbModel extends BaseModel { @Override public void load(Path modelPath, String prefix, Map options) throws IOException { setModelDir(modelPath); + wasLoaded = true; if (block != null) { throw new UnsupportedOperationException("XGBoost
does not support dynamic blocks"); } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index 21951e6a8e0..130543b8528 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -90,6 +90,7 @@ public class MxModel extends BaseModel { public void load(Path modelPath, String prefix, Map options) throws IOException, MalformedModelException { setModelDir(modelPath); + wasLoaded = true; if (prefix == null) { prefix = modelName; } @@ -137,6 +138,13 @@ public void load(Path modelPath, String prefix, Map options) if (optimization != null) { ((MxSymbolBlock) block).optimizeFor(optimization); } + + // Freeze parameters to match Block spec for preTrained data + boolean trainParam = + options != null && Boolean.parseBoolean((String) options.get("trainParam")); + if (!trainParam) { + block.freezeParameters(true); + } } /** {@inheritDoc} */ @@ -147,6 +155,10 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { throw new IllegalStateException( "You must set a block for the model before creating a new trainer"); } + if (wasLoaded) { + // Unfreeze parameters if training directly + block.freezeParameters(false); + } for (Pair> pair : initializer) { if (pair.getKey() != null && pair.getValue() != null) { block.setInitializer(pair.getKey(), pair.getValue()); diff --git a/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/utils/BuildModelZooWordEmbedding.java b/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/utils/BuildModelZooWordEmbedding.java index efe5dbfee0e..d2873e87042 100644 --- a/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/utils/BuildModelZooWordEmbedding.java +++ 
b/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/utils/BuildModelZooWordEmbedding.java @@ -48,7 +48,8 @@ private static void buildEmbedding(String dir, String name) throws IOException { NDArray idxToVec = model.getNDManager().load(path.resolve("idx_to_vec.mx")).singletonOrThrow(); List idxToToken = Utils.readLines(path.resolve("idx_to_token.txt")); - TrainableWordEmbedding embedding = new TrainableWordEmbedding(idxToVec, idxToToken); + TrainableWordEmbedding embedding = + TrainableWordEmbedding.fromPretrained(idxToVec, idxToToken); model.setBlock(embedding); model.save(path, name); } diff --git a/engines/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java b/engines/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java index 212fd7992ac..8620d16ec1c 100644 --- a/engines/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java +++ b/engines/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java @@ -186,6 +186,7 @@ public void trainWithCustomLayer() throws IOException, ModelException { SymbolBlock mlp = (SymbolBlock) model.getBlock(); SequentialBlock newMlp = new SequentialBlock(); mlp.removeLastBlock(); + mlp.freezeParameters(false); newMlp.add(mlp); Linear linear = Linear.builder().setUnits(10).build(); diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index 81e07032a1c..9c57b70d677 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -65,6 +65,7 @@ public class OrtModel extends BaseModel { public void load(Path modelPath, String prefix, Map options) throws IOException, MalformedModelException { setModelDir(modelPath); + 
wasLoaded = true; if (block != null) { throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks"); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index d797dbd6480..01601ef1ad7 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -63,6 +63,7 @@ public class PtModel extends BaseModel { public void load(Path modelPath, String prefix, Map options) throws IOException, MalformedModelException { setModelDir(modelPath); + wasLoaded = true; if (prefix == null) { prefix = modelName; } @@ -186,6 +187,10 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { throw new IllegalStateException( "You must set a block for the model before creating a new trainer"); } + if (wasLoaded) { + // Unfreeze parameters if training directly + block.freezeParameters(false); + } for (Pair> pair : initializer) { if (pair.getKey() != null && pair.getValue() != null) { block.setInitializer(pair.getKey(), pair.getValue()); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java index a5e60cc138f..53f3981a8f6 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java @@ -60,6 +60,7 @@ public class TfModel extends BaseModel { public void load(Path modelPath, String prefix, Map options) throws FileNotFoundException, MalformedModelException { setModelDir(modelPath); + wasLoaded = true; if (prefix == null) { prefix = modelName; } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java 
b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java index e151cda5d3f..44047e0e614 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtModel.java @@ -51,6 +51,7 @@ public class TrtModel extends BaseModel { @Override public void load(Path modelPath, String prefix, Map options) throws IOException { setModelDir(modelPath); + wasLoaded = true; if (block != null) { throw new UnsupportedOperationException("TensorRT does not support dynamic blocks"); } diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java index 7acb2f3531f..8c88eadf79a 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java @@ -145,6 +145,7 @@ private static Model getModel(Arguments arguments) SequentialBlock newBlock = new SequentialBlock(); SymbolBlock block = (SymbolBlock) model.getBlock(); block.removeLastBlock(); + block.freezeParameters(false); newBlock.add(block); // the original model don't include the flatten // so apply the flatten here diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java index 230079d79cd..d047c121760 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java @@ -42,7 +42,7 @@ public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) { throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel"); } - this.embedding = 
(FtAbstractBlock) model.getBlock(); + this.embedding = ((FtModel) model).getBlock(); this.vocabulary = vocabulary; }