From b931f8ea47456a19e35e7584b4c2c3f876aa4ef3 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 23 Aug 2023 10:51:59 -0700 Subject: [PATCH] [api] Adds Safetensors support Fixes #2719 --- api/src/main/java/ai/djl/ndarray/NDList.java | 131 ++++++++++++++++-- .../java/ai/djl/ndarray/types/DataType.java | 68 +++++++++ .../NoopServingTranslatorFactory.java | 8 +- .../test/java/ai/djl/ndarray/NDListTest.java | 17 ++- .../ai/djl/ndarray/types/DataTypeTest.java | 31 +++++ .../djl/translate/ServingTranslatorTest.java | 18 ++- api/src/test/resources/list.safetensors | Bin 0 -> 146 bytes .../timeseries/AirPassengersDeepAR.java | 2 +- .../timeseries/M5ForecastingDeepAR.java | 2 +- .../binary/NpBinaryTranslator.scala | 2 +- 10 files changed, 261 insertions(+), 18 deletions(-) create mode 100644 api/src/test/resources/list.safetensors diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index 27f2d91b0110..b8a72834ab4e 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -13,7 +13,13 @@ package ai.djl.ndarray; import ai.djl.Device; +import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.util.JsonUtils; +import ai.djl.util.Pair; + +import com.google.gson.JsonObject; +import com.google.gson.annotations.SerializedName; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -25,10 +31,14 @@ import java.io.PushbackInputStream; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; @@ -84,7 +94,7 @@ public NDList(Collection other) { * @return {@code NDList} */ public static NDList decode(NDManager manager, byte[] byteArray) { - if (byteArray.length < 4) { + if (byteArray.length < 9) { throw new IllegalArgumentException("Invalid input length: " + byteArray.length); } try { @@ -96,6 +106,8 @@ public static NDList decode(NDManager manager, byte[] byteArray) { && byteArray[3] == 'M') { return new NDList( NDSerializer.decode(manager, new ByteArrayInputStream(byteArray))); + } else if (byteArray[8] == '{') { + return decodeSafetensors(manager, new ByteArrayInputStream(byteArray)); } ByteBuffer bb = ByteBuffer.wrap(byteArray); @@ -124,7 +136,7 @@ public static NDList decode(NDManager manager, byte[] byteArray) { public static NDList decode(NDManager manager, InputStream is) { try { DataInputStream dis = new DataInputStream(is); - byte[] magic = new byte[4]; + byte[] magic = new byte[9]; dis.readFully(magic); PushbackInputStream pis = new PushbackInputStream(is, 4); @@ -137,6 +149,8 @@ public static NDList decode(NDManager manager, InputStream is) { && magic[2] == 'U' && magic[3] == 'M') { return new NDList(NDSerializer.decode(manager, pis)); + } else if (magic[8] == '{') { + return decodeSafetensors(manager, pis); } dis = new DataInputStream(pis); @@ -154,6 +168,54 @@ public static NDList decode(NDManager manager, InputStream is) { } } + private static NDList decodeSafetensors(NDManager manager, InputStream is) throws IOException { + DataInputStream dis; + if (is instanceof DataInputStream) { + dis = (DataInputStream) is; + } else { + dis = new DataInputStream(is); + } + + byte[] buf = new byte[8]; + dis.readFully(buf); + int len = Math.toIntExact(ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).getLong()); + buf = new byte[len]; + dis.readFully(buf); + String json = new String(buf, StandardCharsets.UTF_8); + // rust implementation sort by name, our implementation preserve the order. + JsonObject jsonObject = JsonUtils.GSON.fromJson(json, JsonObject.class); + List> list = new ArrayList<>(); + int max = 0; + for (String key : jsonObject.keySet()) { + if ("__metadata__".equals(key)) { + continue; + } + SafeTensor value = JsonUtils.GSON.fromJson(jsonObject.get(key), SafeTensor.class); + if (value.offsets.length != 2) { + throw new IOException("Malformed safetensors metadata: " + json); + } + max = Math.max(max, value.offsets[1]); + list.add(new Pair<>(key, value)); + } + buf = new byte[max]; + dis.readFully(buf); + NDList ret = new NDList(list.size()); + for (Pair pair : list) { + if ("__metadata__".equals(pair.getKey())) { + continue; + } + SafeTensor st = pair.getValue(); + Shape shape = new Shape(st.shape); + ByteBuffer bb = ByteBuffer.wrap(buf, st.offsets[0], st.size()); + bb.order(ByteOrder.LITTLE_ENDIAN); + DataType dataType = DataType.fromSafetensors(st.dtype); + NDArray array = manager.create(bb, shape, dataType); + array.setName(pair.getKey()); + ret.add(array); + } + return ret; + } + private static NDList decodeNumpy(NDManager manager, InputStream is) throws IOException { NDList list = new NDList(); ZipInputStream zis = new ZipInputStream(is); @@ -340,18 +402,18 @@ public void detach() { * @return the byte array */ public byte[] encode() { - return encode(false); + return encode(Mode.ND_LIST); } /** * Encodes the NDList to byte array. * - * @param numpy encode in npz format if true + * @param mode encode mode, one of ndlist/npz/safetensor format * @return the byte array */ - public byte[] encode(boolean numpy) { + public byte[] encode(Mode mode) { try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - encode(baos, numpy); + encode(baos, mode); return baos.toByteArray(); } catch (IOException e) { throw new AssertionError("NDList is not writable", e); @@ -365,18 +427,18 @@ public byte[] encode(boolean numpy) { * @throws IOException if failed on IO operation */ public void encode(OutputStream os) throws IOException { - encode(os, false); + encode(os, Mode.ND_LIST); } /** * Writes the encoded NDList to {@code OutputStream}. * * @param os the {@code OutputStream} to be written to - * @param numpy encode in npz format if true + * @param mode encode mode, one of ndlist/npz/safetensor format * @throws IOException if failed on IO operation */ - public void encode(OutputStream os, boolean numpy) throws IOException { - if (numpy) { + public void encode(OutputStream os, Mode mode) throws IOException { + if (mode == Mode.NPZ) { ZipOutputStream zos = new ZipOutputStream(os); int i = 0; for (NDArray nd : this) { @@ -392,6 +454,36 @@ public void encode(OutputStream os, boolean numpy) throws IOException { zos.finish(); zos.flush(); return; + } else if (mode == Mode.SAFETENSORS) { + Map map = new ConcurrentHashMap<>(size()); + int i = 0; + int offset = 0; + for (NDArray nd : this) { + String name = nd.getName(); + if (name == null) { + name = "arr_" + i; + ++i; + } + SafeTensor st = new SafeTensor(); + st.dtype = nd.getDataType().asSafetensors(); + st.shape = nd.getShape().getShape(); + long size = nd.getDataType().getNumOfBytes() * nd.size(); + int limit = offset + Math.toIntExact(size); + st.offsets = new int[] {offset, limit}; + map.put(name, st); + offset = limit; + } + byte[] json = JsonUtils.GSON.toJson(map).getBytes(StandardCharsets.UTF_8); + + ByteBuffer buf = ByteBuffer.allocate(8); + buf.order(ByteOrder.LITTLE_ENDIAN); + buf.putLong(0, json.length); + os.write(buf.array()); + os.write(json); + for (NDArray nd : this) { + os.write(nd.toByteArray()); + } + return; } DataOutputStream dos = new DataOutputStream(os); @@ -450,4 +542,23 @@ public String toString() { } return builder.toString(); } + + /** An enum represents NDList serialization format. */ + public enum Mode { + ND_LIST, + NPZ, + SAFETENSORS + } + + private static final class SafeTensor { + String dtype; + long[] shape; + + @SerializedName("data_offsets") + int[] offsets; + + int size() { + return offsets[1] - offsets[0]; + } + } } diff --git a/api/src/main/java/ai/djl/ndarray/types/DataType.java b/api/src/main/java/ai/djl/ndarray/types/DataType.java index 2313569e1171..8bc5e8fa255d 100644 --- a/api/src/main/java/ai/djl/ndarray/types/DataType.java +++ b/api/src/main/java/ai/djl/ndarray/types/DataType.java @@ -188,6 +188,37 @@ public static DataType fromNumpy(String dtype) { } } + /** + * Returns the data type from Safetensors value. + * + * @param dtype the Safetensors datatype + * @return the data type + */ + public static DataType fromSafetensors(String dtype) { + switch (dtype) { + case "F64": + return FLOAT64; + case "F32": + return FLOAT32; + case "F16": + return FLOAT16; + case "BF16": + return BFLOAT16; + case "I64": + return INT64; + case "I32": + return INT32; + case "I8": + return INT8; + case "U8": + return UINT8; + case "BOOL": + return BOOLEAN; + default: + throw new IllegalArgumentException("Unsupported safetensors dataType: " + dtype); + } + } + /** * Converts a {@link ByteBuffer} to a buffer for this data type. * @@ -260,6 +291,43 @@ public String asNumpy() { } } + /** + * Returns a safetensors string value. + * + * @return a safetensors string value + */ + public String asSafetensors() { + switch (this) { + case FLOAT64: + return "F64"; + case FLOAT32: + return "F32"; + case FLOAT16: + return "F16"; + case BFLOAT16: + return "BF16"; + case INT64: + return "I64"; + case INT32: + return "I32"; + case INT8: + return "I8"; + case UINT8: + return "U8"; + case BOOLEAN: + return "BOOL"; + case INT16: + case UINT64: + case UINT32: + case UINT16: + case STRING: + case COMPLEX64: + case UNKNOWN: + default: + throw new IllegalArgumentException("Unsupported dataType: " + this); + } + } + /** {@inheritDoc} */ @Override public String toString() { diff --git a/api/src/main/java/ai/djl/translate/NoopServingTranslatorFactory.java b/api/src/main/java/ai/djl/translate/NoopServingTranslatorFactory.java index 0935abdd41d3..f3b41e0aef58 100644 --- a/api/src/main/java/ai/djl/translate/NoopServingTranslatorFactory.java +++ b/api/src/main/java/ai/djl/translate/NoopServingTranslatorFactory.java @@ -82,10 +82,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) { Output output = new Output(); if ("tensor/npz".equalsIgnoreCase(accept) || "tensor/npz".equalsIgnoreCase(contentType)) { - output.add(list.encode(true)); + output.add(list.encode(NDList.Mode.NPZ)); output.addProperty("Content-Type", "tensor/npz"); + } else if ("tensor/safetensors".equalsIgnoreCase(accept) + || "tensor/safetensors".equalsIgnoreCase(contentType)) { + output.add(list.encode(NDList.Mode.SAFETENSORS)); + output.addProperty("Content-Type", "tensor/safetensors"); } else { - output.add(list.encode(false)); + output.add(list.encode()); output.addProperty("Content-Type", "tensor/ndlist"); } return output; diff --git a/api/src/test/java/ai/djl/ndarray/NDListTest.java b/api/src/test/java/ai/djl/ndarray/NDListTest.java index 8ff0c044edc2..b450a04102d9 100644 --- a/api/src/test/java/ai/djl/ndarray/NDListTest.java +++ b/api/src/test/java/ai/djl/ndarray/NDListTest.java @@ -29,10 +29,25 @@ public void testNumpy() throws IOException { NDList decoded = NDList.decode(manager, data); ByteArrayOutputStream bos = new ByteArrayOutputStream(data.length + 1); - decoded.encode(bos, true); + decoded.encode(bos, NDList.Mode.NPZ); NDList list = NDList.decode(manager, bos.toByteArray()); Assert.assertEquals(list.size(), 2); Assert.assertEquals(list.get(0).getName(), "bool8"); } } + + @Test + public void testSafetensors() throws IOException { + try (NDManager manager = NDManager.newBaseManager(Device.cpu())) { + byte[] data = NDSerializerTest.readFile("list.safetensors"); + NDList decoded = NDList.decode(manager, data); + + ByteArrayOutputStream bos = new ByteArrayOutputStream(data.length + 1); + decoded.encode(bos, NDList.Mode.SAFETENSORS); + NDList list = NDList.decode(manager, bos.toByteArray()); + Assert.assertEquals(list.size(), 2); + Assert.assertEquals(list.get(0).getName(), "attention"); + Assert.assertEquals(list.get(0).toByteArray(), new byte[] {0, 1, 2, 3, 4, 5}); + } + } } diff --git a/api/src/test/java/ai/djl/ndarray/types/DataTypeTest.java b/api/src/test/java/ai/djl/ndarray/types/DataTypeTest.java index 23b9f2d687fb..74596bb6fe28 100644 --- a/api/src/test/java/ai/djl/ndarray/types/DataTypeTest.java +++ b/api/src/test/java/ai/djl/ndarray/types/DataTypeTest.java @@ -38,4 +38,35 @@ public void numpyTest() { Assert.expectThrows(IllegalArgumentException.class, DataType.UNKNOWN::asNumpy); Assert.expectThrows(IllegalArgumentException.class, () -> DataType.fromNumpy("|i8")); } + + @Test + public void safetensorsTest() { + Assert.assertEquals(DataType.FLOAT64.asSafetensors(), "F64"); + Assert.assertEquals(DataType.FLOAT32.asSafetensors(), "F32"); + Assert.assertEquals(DataType.FLOAT16.asSafetensors(), "F16"); + Assert.assertEquals(DataType.BFLOAT16.asSafetensors(), "BF16"); + Assert.assertEquals(DataType.INT64.asSafetensors(), "I64"); + Assert.assertEquals(DataType.INT32.asSafetensors(), "I32"); + Assert.assertEquals(DataType.INT8.asSafetensors(), "I8"); + Assert.assertEquals(DataType.UINT8.asSafetensors(), "U8"); + Assert.assertEquals(DataType.BOOLEAN.asSafetensors(), "BOOL"); + + Assert.assertEquals(DataType.fromSafetensors("F64"), DataType.FLOAT64); + Assert.assertEquals(DataType.fromSafetensors("F32"), DataType.FLOAT32); + Assert.assertEquals(DataType.fromSafetensors("F16"), DataType.FLOAT16); + Assert.assertEquals(DataType.fromSafetensors("BF16"), DataType.BFLOAT16); + Assert.assertEquals(DataType.fromSafetensors("I64"), DataType.INT64); + Assert.assertEquals(DataType.fromSafetensors("I32"), DataType.INT32); + Assert.assertEquals(DataType.fromSafetensors("I8"), DataType.INT8); + Assert.assertEquals(DataType.fromSafetensors("U8"), DataType.UINT8); + Assert.assertEquals(DataType.fromSafetensors("BOOL"), DataType.BOOLEAN); + + Assert.expectThrows(IllegalArgumentException.class, DataType.UINT64::asSafetensors); + Assert.expectThrows(IllegalArgumentException.class, DataType.UINT32::asSafetensors); + Assert.expectThrows(IllegalArgumentException.class, DataType.UINT16::asSafetensors); + Assert.expectThrows(IllegalArgumentException.class, DataType.COMPLEX64::asSafetensors); + Assert.expectThrows(IllegalArgumentException.class, DataType.STRING::asSafetensors); + Assert.expectThrows(IllegalArgumentException.class, DataType.UNKNOWN::asSafetensors); + Assert.expectThrows(IllegalArgumentException.class, () -> DataType.fromSafetensors("U16")); + } } diff --git a/api/src/test/java/ai/djl/translate/ServingTranslatorTest.java b/api/src/test/java/ai/djl/translate/ServingTranslatorTest.java index 7fed3025de5f..e95676227c13 100644 --- a/api/src/test/java/ai/djl/translate/ServingTranslatorTest.java +++ b/api/src/test/java/ai/djl/translate/ServingTranslatorTest.java @@ -45,6 +45,15 @@ public void tierDown() { @Test public void testNumpy() throws IOException, TranslateException, ModelException { + test("tensor/npz"); + } + + @Test + public void testSafetensors() throws IOException, TranslateException, ModelException { + test("tensor/safetensors"); + } + + private void test(String contentType) throws IOException, TranslateException, ModelException { Path path = Paths.get("build/model"); Files.createDirectories(path); Input input = new Input(); @@ -58,8 +67,13 @@ public void testNumpy() throws IOException, TranslateException, ModelException { model.close(); NDList list = new NDList(); list.add(manager.create(10f)); - input.add(list.encode(true)); - input.add("Content-Type", "tensor/npz"); + if ("tensor/safetensors".equalsIgnoreCase(contentType)) { + input.add(list.encode(NDList.Mode.SAFETENSORS)); + input.add("Content-Type", "tensor/safetensors"); + } else { + input.add(list.encode(NDList.Mode.NPZ)); + input.add("Content-Type", "tensor/npz"); + } } Criteria criteria = diff --git a/api/src/test/resources/list.safetensors b/api/src/test/resources/list.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..16148cf507ee54adc2272195e80a84410b2f41f3 GIT binary patch literal 146 zcmZo*fPiYH#FCQKypqiPJSD4YrIeD&f>b3dB~J?_9i`%oL=Y$1NXIx predict() private static void saveNDArray(NDArray array) throws IOException { Path path = Paths.get("build").resolve(array.getName() + ".npz"); try (OutputStream os = Files.newOutputStream(path)) { - new NDList(new NDList(array)).encode(os, true); + new NDList(new NDList(array)).encode(os, NDList.Mode.NPZ); } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/translator/binary/NpBinaryTranslator.scala b/extensions/spark/src/main/scala/ai/djl/spark/translator/binary/NpBinaryTranslator.scala index 95cecc8527d7..01db3944ed2c 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/translator/binary/NpBinaryTranslator.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/translator/binary/NpBinaryTranslator.scala @@ -27,7 +27,7 @@ class NpBinaryTranslator(val batchifier: Batchifier) extends Translator[Array[By /** @inheritdoc */ override def processOutput(ctx: TranslatorContext, list: NDList): Array[Byte] = { - list.encode(true) + list.encode(NDList.Mode.NPZ) } /** @inheritdoc */