Creates standard for PreTrained behavior #2360

Merged · 1 commit · Feb 3, 2023
2 changes: 2 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
@@ -61,6 +61,7 @@ public abstract class BaseModel implements Model {
protected String modelName;
protected NDManager manager;
protected DataType dataType;
protected boolean wasLoaded;
protected PairList<String, Shape> inputData;
protected Map<String, Object> artifacts = new ConcurrentHashMap<>();
protected Map<String, String> properties = new ConcurrentHashMap<>();
@@ -78,6 +79,7 @@ public Block getBlock() {
/** {@inheritDoc} */
@Override
public void setBlock(Block block) {
wasLoaded = false;
this.block = block;
}

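A minimal sketch of the lifecycle the new wasLoaded flag is meant to track (assuming the usual ai.djl imports; modelDir and config are placeholders):

    Model model = Model.newInstance("mlp");
    model.load(modelDir);                       // engine Model#load implementations set wasLoaded = true
    Trainer trainer = model.newTrainer(config); // engines check wasLoaded and unfreeze for direct training
    model.setBlock(new SequentialBlock());      // replacing the block resets wasLoaded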
@@ -16,6 +16,7 @@
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.Block;
import ai.djl.nn.core.Embedding;

import java.nio.charset.StandardCharsets;
@@ -59,30 +60,47 @@ public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize) {
.optUseDefault(false));
}

private TrainableWordEmbedding(NDArray embedding, List<String> items) {
super(embedding);
this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
this.vocabulary = new DefaultVocabulary(items);
}

private TrainableWordEmbedding(
NDArray embedding, List<String> items, SparseFormat sparseFormat) {
super(embedding, sparseFormat);
this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
this.vocabulary = new DefaultVocabulary(items);
}

/**
* Constructs a pretrained embedding.
*
* <p>Because it is created with preTrained data, it is created as a frozen block. If you wish
* to update it, call {@link Block#freezeParameters(boolean)} with {@code false}.
*
* @param embedding the embedding array
* @param items the items in the embedding (in matching order to the embedding array)
* @return the created embedding
*/
-    public TrainableWordEmbedding(NDArray embedding, List<String> items) {
-        super(embedding);
-        this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
-        this.vocabulary = new DefaultVocabulary(items);
+    public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items) {
+        return new TrainableWordEmbedding(embedding, items);
     }

/**
* Constructs a pretrained embedding.
*
* <p>Because it is created with preTrained data, it is created as a frozen block. If you wish
* to update it, call {@link Block#freezeParameters(boolean)} with {@code false}.
*
* @param embedding the embedding array
* @param items the items in the embedding (in matching order to the embedding array)
* @param sparseFormat whether to compute row sparse gradient in the backward calculation
* @return the created embedding
*/
-    public TrainableWordEmbedding(
+    public static TrainableWordEmbedding fromPretrained(
             NDArray embedding, List<String> items, SparseFormat sparseFormat) {
-        super(embedding, sparseFormat);
-        this.fallthroughEmbedding = new DefaultItem(DEFAULT_UNKNOWN_TOKEN);
-        this.vocabulary = new DefaultVocabulary(items);
+        return new TrainableWordEmbedding(embedding, items, sparseFormat);
     }

/** {@inheritDoc} */
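A usage sketch for the new factory, mirroring the buildEmbedding change later in this diff (the paths and manager are placeholders):

    NDArray idxToVec = manager.load(Paths.get("idx_to_vec.mx")).singletonOrThrow();
    List<String> idxToToken = Utils.readLines(Paths.get("idx_to_token.txt"));
    TrainableWordEmbedding embedding = TrainableWordEmbedding.fromPretrained(idxToVec, idxToToken);
    embedding.freezeParameters(false); // the factory returns a frozen block; unfreeze to fine-tune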
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/nn/Block.java
@@ -100,6 +100,10 @@
* the javadoc at {@link ai.djl.training.Trainer}. At the end of training, a block represents a
* fully-trained model.
*
* <p>It is also possible to freeze parameters and blocks so that they are not trained. When
* loading models or building blocks with preTrained data, they default to being frozen. If you
* wish to further refine these elements, use {@link Block#freezeParameters(boolean)} to unfreeze
* them.
*
* @see <a
* href="https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/01_create_your_first_network.ipynb">this
* tutorial on creating your first network</a>
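A sketch of the convention this paragraph describes, modeled on the transfer-learning changes later in this diff (pretrainedBlock is a placeholder):

    SequentialBlock net = new SequentialBlock();
    net.add(pretrainedBlock);                       // frozen, since it came from preTrained data
    net.add(Linear.builder().setUnits(10).build()); // fresh head, trainable by default
    pretrainedBlock.freezeParameters(false);        // unfreeze the base for full fine-tuning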
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java
@@ -33,10 +33,14 @@ public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedE
/**
* Constructs a constant embedding with the given constant.
*
* <p>The constant is assumed to be a fixed value, and starts out as frozen. To unfreeze, use
* {@link ai.djl.nn.Block#freezeParameters(boolean)}.
*
* @param embedding the value to return for all embeddings
*/
public ConstantEmbedding(NDArray embedding) {
this.embedding = embedding;
freezeParameters(true);
}

/** {@inheritDoc} */
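For illustration, a minimal sketch (the manager and embedding size are assumptions):

    NDManager manager = NDManager.newBaseManager();
    ConstantEmbedding constant = new ConstantEmbedding(manager.ones(new Shape(32)));
    // frozen on construction; every input maps to the same constant vector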
8 changes: 6 additions & 2 deletions api/src/main/java/ai/djl/nn/core/Embedding.java
@@ -78,17 +78,20 @@ protected Embedding(BaseBuilder<T, ?> baseBuilder) {
*
* @param embedding the embedding array
*/
-    public Embedding(NDArray embedding) {
+    protected Embedding(NDArray embedding) {
         this(embedding, SparseFormat.DENSE);
     }

/**
* Constructs a pretrained embedding.
*
* <p>Because it is created with preTrained data, it is created as a frozen block. If you wish
* to update it, call {@link Block#freezeParameters(boolean)} with {@code false}.
*
* @param embedding the embedding array
* @param format whether to compute row sparse gradient in the backward calculation
*/
-    public Embedding(NDArray embedding, SparseFormat format) {
+    protected Embedding(NDArray embedding, SparseFormat format) {
super(VERSION);
numEmbeddings = Math.toIntExact(embedding.getShape().get(0));
embeddingSize = Math.toIntExact(embedding.getShape().get(1));
@@ -101,6 +104,7 @@ public Embedding(NDArray embedding, SparseFormat format) {
.build());
this.embedding.setArray(embedding);
inputShapes = new Shape[] {new Shape(-1)};
freezeParameters(true);
}

/** {@inheritDoc} */
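Since both constructors are now protected, application code is expected to go through subclasses or factories such as TrainableWordEmbedding.fromPretrained above; subclasses still reach these constructors via super(...).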
@@ -44,6 +44,7 @@ public class LgbmModel extends BaseModel {
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
setModelDir(modelPath);
wasLoaded = true;
if (block != null) {
throw new UnsupportedOperationException("LightGBM does not support dynamic blocks");
}
@@ -46,6 +46,7 @@ public class XgbModel extends BaseModel {
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
setModelDir(modelPath);
wasLoaded = true;
if (block != null) {
throw new UnsupportedOperationException("XGBoost does not support dynamic blocks");
}
@@ -90,6 +90,7 @@ public class MxModel extends BaseModel {
public void load(Path modelPath, String prefix, Map<String, ?> options)
throws IOException, MalformedModelException {
setModelDir(modelPath);
wasLoaded = true;
if (prefix == null) {
prefix = modelName;
}
@@ -137,6 +138,13 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
if (optimization != null) {
((MxSymbolBlock) block).optimizeFor(optimization);
}

// Freeze parameters to match Block spec for preTrained data
boolean trainParam =
options != null && Boolean.parseBoolean((String) options.get("trainParam"));
if (!trainParam) {
block.freezeParameters(true);
}
}

/** {@inheritDoc} */
@@ -147,6 +155,10 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
throw new IllegalStateException(
"You must set a block for the model before creating a new trainer");
}
if (wasLoaded) {
// Unfreeze parameters if training directly
block.freezeParameters(false);
}
for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
if (pair.getKey() != null && pair.getValue() != null) {
block.setInitializer(pair.getKey(), pair.getValue());
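A sketch of how the new option could be supplied when loading through the criteria API (the input/output types and model path are placeholders):

    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelPath(modelDir)
                    .optOption("trainParam", "true") // keep loaded parameters trainable
                    .build();
    ZooModel<Image, Classifications> model = criteria.loadModel();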
@@ -48,7 +48,8 @@ private static void buildEmbedding(String dir, String name) throws IOException {
NDArray idxToVec =
model.getNDManager().load(path.resolve("idx_to_vec.mx")).singletonOrThrow();
List<String> idxToToken = Utils.readLines(path.resolve("idx_to_token.txt"));
-        TrainableWordEmbedding embedding = new TrainableWordEmbedding(idxToVec, idxToToken);
+        TrainableWordEmbedding embedding =
+                TrainableWordEmbedding.fromPretrained(idxToVec, idxToToken);
model.setBlock(embedding);
model.save(path, name);
}
@@ -186,6 +186,7 @@ public void trainWithCustomLayer() throws IOException, ModelException {
SymbolBlock mlp = (SymbolBlock) model.getBlock();
SequentialBlock newMlp = new SequentialBlock();
mlp.removeLastBlock();
mlp.freezeParameters(false);
newMlp.add(mlp);
Linear linear = Linear.builder().setUnits(10).build();

@@ -65,6 +65,7 @@ public class OrtModel extends BaseModel {
public void load(Path modelPath, String prefix, Map<String, ?> options)
throws IOException, MalformedModelException {
setModelDir(modelPath);
wasLoaded = true;
if (block != null) {
throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
}
@@ -63,6 +63,7 @@ public class PtModel extends BaseModel {
public void load(Path modelPath, String prefix, Map<String, ?> options)
throws IOException, MalformedModelException {
setModelDir(modelPath);
wasLoaded = true;
if (prefix == null) {
prefix = modelName;
}
@@ -186,6 +187,10 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
throw new IllegalStateException(
"You must set a block for the model before creating a new trainer");
}
if (wasLoaded) {
// Unfreeze parameters if training directly
block.freezeParameters(false);
}
for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
if (pair.getKey() != null && pair.getValue() != null) {
block.setInitializer(pair.getKey(), pair.getValue());
@@ -60,6 +60,7 @@ public class TfModel extends BaseModel {
public void load(Path modelPath, String prefix, Map<String, ?> options)
throws FileNotFoundException, MalformedModelException {
setModelDir(modelPath);
wasLoaded = true;
if (prefix == null) {
prefix = modelName;
}
@@ -51,6 +51,7 @@ public class TrtModel extends BaseModel {
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
setModelDir(modelPath);
wasLoaded = true;
if (block != null) {
throw new UnsupportedOperationException("TensorRT does not support dynamic blocks");
}
@@ -145,6 +145,7 @@ private static Model getModel(Arguments arguments)
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
block.freezeParameters(false);
newBlock.add(block);
// the original model doesn't include the flatten
// so apply the flatten here
@@ -42,7 +42,7 @@ public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) {
throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel");
}

-        this.embedding = (FtAbstractBlock) model.getBlock();
+        this.embedding = ((FtModel) model).getBlock();
this.vocabulary = vocabulary;
}
