From 738a329bb27107fd247194ad3981292057cfaf68 Mon Sep 17 00:00:00 2001 From: Vanja Radulovic Date: Wed, 22 Jan 2025 14:11:05 +0100 Subject: [PATCH 1/3] Expose OnnxRuntime getMetadata() in DJL API --- api/src/main/java/ai/djl/BaseModel.java | 9 +++++++++ api/src/main/java/ai/djl/Model.java | 9 +++++++++ api/src/main/java/ai/djl/nn/AbstractBaseBlock.java | 10 +++++++++- api/src/main/java/ai/djl/nn/Block.java | 9 ++++++++- .../main/java/ai/djl/repository/zoo/ZooModel.java | 5 +++++ .../ai/djl/onnxruntime/engine/OrtSymbolBlock.java | 14 ++++++++++++++ 6 files changed, 54 insertions(+), 2 deletions(-) diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index 5705d480b38..a9ce3489ea7 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -63,6 +63,7 @@ public abstract class BaseModel implements Model { protected DataType dataType; protected boolean wasLoaded; protected PairList inputData; + protected PairList metadata; protected Map artifacts = new ConcurrentHashMap<>(); protected Map properties = new ConcurrentHashMap<>(); @@ -141,6 +142,14 @@ public PairList describeInput() { return inputData; } + @Override + public PairList getCustomMetadata() { + if (metadata == null) { + metadata = block.getCustomMetadata(); + } + return metadata; + } + /** {@inheritDoc} */ @Override public PairList describeOutput() { diff --git a/api/src/main/java/ai/djl/Model.java b/api/src/main/java/ai/djl/Model.java index dab2568949f..34fed685c48 100644 --- a/api/src/main/java/ai/djl/Model.java +++ b/api/src/main/java/ai/djl/Model.java @@ -311,6 +311,15 @@ default Predictor newPredictor(Translator translator) { */ PairList describeInput(); + /** + * Returns the custom metadata of the model. + * + *

It contains the custom metadata information that can be obtained from the model. + * + * @return a PairList of String and String + */ + PairList getCustomMetadata(); + /** * Returns the output descriptor of the model. * diff --git a/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java index 9fabd2b4df7..53258d7cf53 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java @@ -377,4 +377,12 @@ public Shape[] getInputShapes() { public DataType[] getOutputDataTypes() { return outputDataTypes; } -} + + /** + * {@inheritDoc} + */ + @Override + public PairList getCustomMetadata() { + return null; + } +} \ No newline at end of file diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 56c92a23e6b..52663abfde6 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -353,4 +353,11 @@ static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayou } } } -} + + /** + * Returns a map of all the custom metadata of the block. + * + * @return the map of {@link PairList} + */ + PairList getCustomMetadata(); +} \ No newline at end of file diff --git a/api/src/main/java/ai/djl/repository/zoo/ZooModel.java b/api/src/main/java/ai/djl/repository/zoo/ZooModel.java index bf1595445a8..ecbd9c4bd53 100644 --- a/api/src/main/java/ai/djl/repository/zoo/ZooModel.java +++ b/api/src/main/java/ai/djl/repository/zoo/ZooModel.java @@ -176,6 +176,11 @@ public PairList describeOutput() { return model.describeOutput(); } + @Override + public PairList getCustomMetadata() { + return model.getCustomMetadata(); + } + /** {@inheritDoc} */ @Override public String[] getArtifactNames() { diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index 4e8df210d40..9ac9809db72 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -26,6 +26,7 @@ import ai.djl.util.PairList; import ai.onnxruntime.OnnxJavaType; import ai.onnxruntime.OnnxMap; +import ai.onnxruntime.OnnxModelMetadata; import ai.onnxruntime.OnnxSequence; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; @@ -128,6 +129,19 @@ public PairList describeInput() { return result; } + /** + * {@inheritDoc} + */ + @Override + public PairList getCustomMetadata() { + try { + OnnxModelMetadata modelMetadata = session.getMetadata(); + return new PairList<>(modelMetadata.getCustomMetadata()); + } catch (OrtException e) { + throw new EngineException(e); + } + } + private NDList evaluateOutput(OrtSession.Result results) { NDList output = new NDList(); for (Map.Entry r : results) { From d92cc0fecfb4571ce7993c79288d529be82fc790 Mon Sep 17 00:00:00 2001 From: Vanja Radulovic Date: Fri, 24 Jan 2025 22:59:25 +0100 Subject: [PATCH 2/3] Revert "Expose OnnxRuntime getMetadata() in DJL API" This reverts commit 738a329bb27107fd247194ad3981292057cfaf68. --- api/src/main/java/ai/djl/BaseModel.java | 9 --------- api/src/main/java/ai/djl/Model.java | 9 --------- api/src/main/java/ai/djl/nn/AbstractBaseBlock.java | 10 +--------- api/src/main/java/ai/djl/nn/Block.java | 9 +-------- .../main/java/ai/djl/repository/zoo/ZooModel.java | 5 ----- .../ai/djl/onnxruntime/engine/OrtSymbolBlock.java | 14 -------------- 6 files changed, 2 insertions(+), 54 deletions(-) diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index a9ce3489ea7..5705d480b38 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -63,7 +63,6 @@ public abstract class BaseModel implements Model { protected DataType dataType; protected boolean wasLoaded; protected PairList inputData; - protected PairList metadata; protected Map artifacts = new ConcurrentHashMap<>(); protected Map properties = new ConcurrentHashMap<>(); @@ -142,14 +141,6 @@ public PairList describeInput() { return inputData; } - @Override - public PairList getCustomMetadata() { - if (metadata == null) { - metadata = block.getCustomMetadata(); - } - return metadata; - } - /** {@inheritDoc} */ @Override public PairList describeOutput() { diff --git a/api/src/main/java/ai/djl/Model.java b/api/src/main/java/ai/djl/Model.java index 34fed685c48..dab2568949f 100644 --- a/api/src/main/java/ai/djl/Model.java +++ b/api/src/main/java/ai/djl/Model.java @@ -311,15 +311,6 @@ default Predictor newPredictor(Translator translator) { */ PairList describeInput(); - /** - * Returns the custom metadata of the model. - * - *

It contains the custom metadata information that can be obtained from the model. - * - * @return a PairList of String and String - */ - PairList getCustomMetadata(); - /** * Returns the output descriptor of the model. * diff --git a/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java index 53258d7cf53..9fabd2b4df7 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java @@ -377,12 +377,4 @@ public Shape[] getInputShapes() { public DataType[] getOutputDataTypes() { return outputDataTypes; } - - /** - * {@inheritDoc} - */ - @Override - public PairList getCustomMetadata() { - return null; - } -} \ No newline at end of file +} diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 52663abfde6..56c92a23e6b 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -353,11 +353,4 @@ static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayou } } } - - /** - * Returns a map of all the custom metadata of the block. - * - * @return the map of {@link PairList} - */ - PairList getCustomMetadata(); -} \ No newline at end of file +} diff --git a/api/src/main/java/ai/djl/repository/zoo/ZooModel.java b/api/src/main/java/ai/djl/repository/zoo/ZooModel.java index ecbd9c4bd53..bf1595445a8 100644 --- a/api/src/main/java/ai/djl/repository/zoo/ZooModel.java +++ b/api/src/main/java/ai/djl/repository/zoo/ZooModel.java @@ -176,11 +176,6 @@ public PairList describeOutput() { return model.describeOutput(); } - @Override - public PairList getCustomMetadata() { - return model.getCustomMetadata(); - } - /** {@inheritDoc} */ @Override public String[] getArtifactNames() { diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index 9ac9809db72..4e8df210d40 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -26,7 +26,6 @@ import ai.djl.util.PairList; import ai.onnxruntime.OnnxJavaType; import ai.onnxruntime.OnnxMap; -import ai.onnxruntime.OnnxModelMetadata; import ai.onnxruntime.OnnxSequence; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; @@ -129,19 +128,6 @@ public PairList describeInput() { return result; } - /** - * {@inheritDoc} - */ - @Override - public PairList getCustomMetadata() { - try { - OnnxModelMetadata modelMetadata = session.getMetadata(); - return new PairList<>(modelMetadata.getCustomMetadata()); - } catch (OrtException e) { - throw new EngineException(e); - } - } - private NDList evaluateOutput(OrtSession.Result results) { NDList output = new NDList(); for (Map.Entry r : results) { From 97a2913f13ed8ff6074dbc4a4429b0884a42e65c Mon Sep 17 00:00:00 2001 From: Vanja Radulovic Date: Fri, 24 Jan 2025 23:26:05 +0100 Subject: [PATCH 3/3] Expose OnnxRuntime getMetadata() in DJL API --- api/src/main/java/ai/djl/nn/Block.java | 11 +++++++++++ .../ai/djl/onnxruntime/engine/OrtSymbolBlock.java | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 56c92a23e6b..14bd85735ab 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -26,6 +26,8 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.Collections; +import java.util.Map; import java.util.function.Predicate; /** @@ -353,4 +355,13 @@ static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayou } } } + + /** + * Returns a map of all the custom metadata of the block. + * + * @return the map of {@link PairList} + */ + default Map getCustomMetadata() { + return Collections.emptyMap(); + } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index 4e8df210d40..9eb0f875140 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -26,6 +26,7 @@ import ai.djl.util.PairList; import ai.onnxruntime.OnnxJavaType; import ai.onnxruntime.OnnxMap; +import ai.onnxruntime.OnnxModelMetadata; import ai.onnxruntime.OnnxSequence; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; @@ -128,6 +129,17 @@ public PairList describeInput() { return result; } + /** {@inheritDoc} */ + @Override + public Map getCustomMetadata() { + try { + OnnxModelMetadata modelMetadata = session.getMetadata(); + return modelMetadata.getCustomMetadata(); + } catch (OrtException e) { + throw new EngineException(e); + } + } + private NDList evaluateOutput(OrtSession.Result results) { NDList output = new NDList(); for (Map.Entry r : results) {