Creates standard for PreTrained behavior (#2360)
This changes the standard for DJL behavior with preTrained blocks: from now on, they should start out with frozen parameters. This has been applied to the embeddings.

It was previously applied only to PyTorch, but now applies to all engines. I did leave a carveout, however: a boolean "wasLoaded" is added so that if you load a model and then create a Trainer directly from it, the model will not be frozen. If you instead load a model and then append new layers to it (as several of our examples do), the loaded portion must be unfrozen explicitly.
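In practice this creates two paths after loading. Below is a minimal sketch of both; Model, Block, SequentialBlock, Linear, and freezeParameters(boolean) are existing DJL APIs, while the model name, directory, layer size, and training config are placeholder assumptions:

    import ai.djl.Model;
    import ai.djl.nn.Block;
    import ai.djl.nn.SequentialBlock;
    import ai.djl.nn.core.Linear;
    import ai.djl.training.DefaultTrainingConfig;
    import ai.djl.training.Trainer;
    import ai.djl.training.loss.Loss;

    import java.nio.file.Paths;

    Model model = Model.newInstance("mlp");
    model.load(Paths.get("build/mlp")); // hypothetical model directory

    // Path 1: train the loaded model directly. Because wasLoaded is true,
    // newTrainer unfreezes the parameters automatically.
    Trainer trainer =
            model.newTrainer(new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()));

    // Path 2: append new layers for fine-tuning. The loaded block stays
    // frozen unless it is unfrozen explicitly.
    Block loaded = model.getBlock();
    loaded.freezeParameters(false); // opt back in to updating the loaded weights
    SequentialBlock fineTuned =
            new SequentialBlock().add(loaded).add(Linear.builder().setUnits(10).build());
    model.setBlock(fineTuned);

Note that setBlock clears wasLoaded, so a Trainer created after appending layers will not silently unfreeze the loaded weights.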
zachgk authored Feb 3, 2023
1 parent 28a02b0 commit cba412d
Showing 16 changed files with 69 additions and 12 deletions.
2 changes: 2 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
@@ -61,6 +61,7 @@ public abstract class BaseModel implements Model {
     protected String modelName;
     protected NDManager manager;
     protected DataType dataType;
+    protected boolean wasLoaded;
     protected PairList<String, Shape> inputData;
     protected Map<String, Object> artifacts = new ConcurrentHashMap<>();
     protected Map<String, String> properties = new ConcurrentHashMap<>();
@@ -78,6 +79,7 @@ public Block getBlock() {
     /** {@inheritDoc} */
     @Override
     public void setBlock(Block block) {
+        wasLoaded = false;
         this.block = block;
     }
api/src/main/java/ai/djl/modality/nlp/embedding/TrainableWordEmbedding.java
@@ -16,6 +16,7 @@
 import ai.djl.modality.nlp.Vocabulary;
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.types.SparseFormat;
+import ai.djl.nn.Block;
 import ai.djl.nn.core.Embedding;

 import java.nio.charset.StandardCharsets;
@@ -59,30 +60,47 @@ public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize) {
                 .optUseDefault(false));
     }

+    private TrainableWordEmbedding(NDArray embedding, List<String> items) {
+        super(embedding);
+        this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
+        this.vocabulary = new DefaultVocabulary(items);
+    }
+
+    private TrainableWordEmbedding(
+            NDArray embedding, List<String> items, SparseFormat sparseFormat) {
+        super(embedding, sparseFormat);
+        this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
+        this.vocabulary = new DefaultVocabulary(items);
+    }
+
     /**
      * Constructs a pretrained embedding.
      *
+     * <p>Because it is created with preTrained data, it is created as a frozen block. If you wish
+     * to update it, call {@link Block#freezeParameters(boolean)}.
+     *
      * @param embedding the embedding array
      * @param items the items in the embedding (in matching order to the embedding array)
+     * @return the created embedding
      */
-    public TrainableWordEmbedding(NDArray embedding, List<String> items) {
-        super(embedding);
-        this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
-        this.vocabulary = new DefaultVocabulary(items);
+    public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items) {
+        return new TrainableWordEmbedding(embedding, items);
     }

     /**
      * Constructs a pretrained embedding.
      *
+     * <p>Because it is created with preTrained data, it is created as a frozen block. If you wish
+     * to update it, call {@link Block#freezeParameters(boolean)}.
+     *
      * @param embedding the embedding array
      * @param items the items in the embedding (in matching order to the embedding array)
      * @param sparseFormat whether to compute row sparse gradient in the backward calculation
+     * @return the created embedding
      */
-    public TrainableWordEmbedding(
+    public static TrainableWordEmbedding fromPretrained(
             NDArray embedding, List<String> items, SparseFormat sparseFormat) {
-        super(embedding, sparseFormat);
-        this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
-        this.vocabulary = new DefaultVocabulary(items);
+        return new TrainableWordEmbedding(embedding, items, sparseFormat);
     }

     /** {@inheritDoc} */
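For callers, migrating to the factory is mechanical; a sketch, where idxToVec and idxToToken stand in for whatever embedding array and token list you already have:

    // Before: public constructor, which left the parameters trainable
    // TrainableWordEmbedding embedding = new TrainableWordEmbedding(idxToVec, idxToToken);

    // After: static factory, which returns a frozen block
    TrainableWordEmbedding embedding = TrainableWordEmbedding.fromPretrained(idxToVec, idxToToken);
    embedding.freezeParameters(false); // only if you intend to fine-tune the vectors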
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/nn/Block.java
@@ -100,6 +100,10 @@
  * the javadoc at {@link ai.djl.training.Trainer}. At the end of training, a block represents a
  * fully-trained model.
  *
+ * <p>It is also possible to freeze parameters and blocks to avoid them being trained. When loading
+ * models or building blocks with preTrained data, they default to being frozen. If you wish to
+ * further refine these elements, use {@link Block#freezeParameters(boolean)} to unfreeze them.
+ *
  * @see <a
  *     href="https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/01_create_your_first_network.ipynb">this
  *     tutorial on creating your first network</a>
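Because freezeParameters applies to all of a block's parameters, including those of its children, the refinement the javadoc mentions can be selective. A sketch, assuming a loaded SequentialBlock named net whose last child is a classification head:

    net.freezeParameters(true); // freeze everything that was loaded
    // unfreeze only the last child so just the head is trained
    int last = net.getChildren().size() - 1;
    net.getChildren().values().get(last).freezeParameters(false);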
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java
@@ -33,10 +33,14 @@ public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedE
     /**
      * Constructs a constant embedding with the given constant.
      *
+     * <p>The constant is assumed to be a fixed value, and starts out as frozen. To unfreeze, use
+     * {@link ai.djl.nn.Block#freezeParameters(boolean)}.
+     *
      * @param embedding the value to return for all embeddings
      */
     public ConstantEmbedding(NDArray embedding) {
         this.embedding = embedding;
+        freezeParameters(true);
     }

     /** {@inheritDoc} */
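A minimal use, assuming an NDManager named manager is in scope:

    // Every lookup returns this same vector, e.g. for padding tokens;
    // the block is frozen on construction.
    ConstantEmbedding padding = new ConstantEmbedding(manager.zeros(new Shape(50)));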
8 changes: 6 additions & 2 deletions api/src/main/java/ai/djl/nn/core/Embedding.java
@@ -78,17 +78,20 @@ protected Embedding(BaseBuilder<T, ?> baseBuilder) {
      *
      * @param embedding the embedding array
      */
-    public Embedding(NDArray embedding) {
+    protected Embedding(NDArray embedding) {
         this(embedding, SparseFormat.DENSE);
     }

     /**
      * Constructs a pretrained embedding.
      *
+     * <p>Because it is created with preTrained data, it is created as a frozen block. If you wish
+     * to update it, call {@link Block#freezeParameters(boolean)}.
+     *
      * @param embedding the embedding array
      * @param format whether to compute row sparse gradient in the backward calculation
      */
-    public Embedding(NDArray embedding, SparseFormat format) {
+    protected Embedding(NDArray embedding, SparseFormat format) {
         super(VERSION);
         numEmbeddings = Math.toIntExact(embedding.getShape().get(0));
         embeddingSize = Math.toIntExact(embedding.getShape().get(1));
@@ -101,6 +104,7 @@ public Embedding(NDArray embedding, SparseFormat format) {
                 .build());
         this.embedding.setArray(embedding);
         inputShapes = new Shape[] {new Shape(-1)};
+        freezeParameters(true);
     }

     /** {@inheritDoc} */
@@ -44,6 +44,7 @@ public class LgbmModel extends BaseModel {
     @Override
     public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
         setModelDir(modelPath);
+        wasLoaded = true;
         if (block != null) {
             throw new UnsupportedOperationException("LightGBM does not support dynamic blocks");
         }
@@ -46,6 +46,7 @@ public class XgbModel extends BaseModel {
     @Override
     public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
         setModelDir(modelPath);
+        wasLoaded = true;
         if (block != null) {
             throw new UnsupportedOperationException("XGBoost does not support dynamic blocks");
         }
@@ -90,6 +90,7 @@ public class MxModel extends BaseModel {
     public void load(Path modelPath, String prefix, Map<String, ?> options)
             throws IOException, MalformedModelException {
         setModelDir(modelPath);
+        wasLoaded = true;
         if (prefix == null) {
             prefix = modelName;
         }
@@ -137,6 +138,13 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
         if (optimization != null) {
             ((MxSymbolBlock) block).optimizeFor(optimization);
         }
+
+        // Freeze parameters to match Block spec for preTrained data
+        boolean trainParam =
+                options != null && Boolean.parseBoolean((String) options.get("trainParam"));
+        if (!trainParam) {
+            block.freezeParameters(true);
+        }
     }

     /** {@inheritDoc} */
@@ -147,6 +155,10 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
             throw new IllegalStateException(
                     "You must set a block for the model before creating a new trainer");
         }
+        if (wasLoaded) {
+            // Unfreeze parameters if training directly
+            block.freezeParameters(false);
+        }
         for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
             if (pair.getKey() != null && pair.getValue() != null) {
                 block.setInitializer(pair.getKey(), pair.getValue());
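The trainParam check above means the freeze can be bypassed at load time. A sketch of passing the option directly (the model directory is a placeholder):

    Map<String, String> options = new HashMap<>();
    options.put("trainParam", "true"); // skip freezing the loaded MXNet parameters
    model.load(Paths.get("build/resnet"), null, options);

When loading through the model zoo, the equivalent should be Criteria.builder()....optOption("trainParam", "true").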
@@ -48,7 +48,8 @@ private static void buildEmbedding(String dir, String name) throws IOException {
         NDArray idxToVec =
                 model.getNDManager().load(path.resolve("idx_to_vec.mx")).singletonOrThrow();
         List<String> idxToToken = Utils.readLines(path.resolve("idx_to_token.txt"));
-        TrainableWordEmbedding embedding = new TrainableWordEmbedding(idxToVec, idxToToken);
+        TrainableWordEmbedding embedding =
+                TrainableWordEmbedding.fromPretrained(idxToVec, idxToToken);
         model.setBlock(embedding);
         model.save(path, name);
     }
@@ -186,6 +186,7 @@ public void trainWithCustomLayer() throws IOException, ModelException {
         SymbolBlock mlp = (SymbolBlock) model.getBlock();
         SequentialBlock newMlp = new SequentialBlock();
         mlp.removeLastBlock();
+        mlp.freezeParameters(false);
         newMlp.add(mlp);
         Linear linear = Linear.builder().setUnits(10).build();
@@ -65,6 +65,7 @@ public class OrtModel extends BaseModel {
     public void load(Path modelPath, String prefix, Map<String, ?> options)
             throws IOException, MalformedModelException {
         setModelDir(modelPath);
+        wasLoaded = true;
         if (block != null) {
             throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
         }
@@ -63,6 +63,7 @@ public class PtModel extends BaseModel {
     public void load(Path modelPath, String prefix, Map<String, ?> options)
             throws IOException, MalformedModelException {
         setModelDir(modelPath);
+        wasLoaded = true;
         if (prefix == null) {
             prefix = modelName;
         }
@@ -186,6 +187,10 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
             throw new IllegalStateException(
                     "You must set a block for the model before creating a new trainer");
         }
+        if (wasLoaded) {
+            // Unfreeze parameters if training directly
+            block.freezeParameters(false);
+        }
         for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
             if (pair.getKey() != null && pair.getValue() != null) {
                 block.setInitializer(pair.getKey(), pair.getValue());
@@ -60,6 +60,7 @@ public class TfModel extends BaseModel {
     public void load(Path modelPath, String prefix, Map<String, ?> options)
             throws FileNotFoundException, MalformedModelException {
         setModelDir(modelPath);
+        wasLoaded = true;
         if (prefix == null) {
             prefix = modelName;
         }
@@ -51,6 +51,7 @@ public class TrtModel extends BaseModel {
     @Override
     public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
         setModelDir(modelPath);
+        wasLoaded = true;
         if (block != null) {
             throw new UnsupportedOperationException("TensorRT does not support dynamic blocks");
         }
@@ -145,6 +145,7 @@ private static Model getModel(Arguments arguments)
         SequentialBlock newBlock = new SequentialBlock();
         SymbolBlock block = (SymbolBlock) model.getBlock();
         block.removeLastBlock();
+        block.freezeParameters(false);
         newBlock.add(block);
         // the original model doesn't include the flatten
         // so apply the flatten here
@@ -42,7 +42,7 @@ public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) {
             throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel");
         }

-        this.embedding = (FtAbstractBlock) model.getBlock();
+        this.embedding = ((FtModel) model).getBlock();
         this.vocabulary = vocabulary;
     }
