Skip to content

Commit

Permalink
[api] Adds Safetensors support
Browse files Browse the repository at this point in the history
Fixes #2719
  • Loading branch information
frankfliu committed Aug 23, 2023
1 parent 59af05f commit d8b8b71
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 19 deletions.
133 changes: 122 additions & 11 deletions api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;

import com.google.gson.JsonObject;
import com.google.gson.annotations.SerializedName;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -25,10 +31,14 @@
import java.io.PushbackInputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
Expand Down Expand Up @@ -84,7 +94,7 @@ public NDList(Collection<NDArray> other) {
* @return {@code NDList}
*/
public static NDList decode(NDManager manager, byte[] byteArray) {
if (byteArray.length < 4) {
if (byteArray.length < 9) {
throw new IllegalArgumentException("Invalid input length: " + byteArray.length);
}
try {
Expand All @@ -96,6 +106,8 @@ public static NDList decode(NDManager manager, byte[] byteArray) {
&& byteArray[3] == 'M') {
return new NDList(
NDSerializer.decode(manager, new ByteArrayInputStream(byteArray)));
} else if (byteArray[8] == '{') {
return decodeSafetensors(manager, new ByteArrayInputStream(byteArray));
}

ByteBuffer bb = ByteBuffer.wrap(byteArray);
Expand Down Expand Up @@ -124,10 +136,10 @@ public static NDList decode(NDManager manager, byte[] byteArray) {
public static NDList decode(NDManager manager, InputStream is) {
try {
DataInputStream dis = new DataInputStream(is);
byte[] magic = new byte[4];
byte[] magic = new byte[9];
dis.readFully(magic);

PushbackInputStream pis = new PushbackInputStream(is, 4);
PushbackInputStream pis = new PushbackInputStream(is, 9);
pis.unread(magic);
if (magic[0] == 'P' && magic[1] == 'K') {
// assume this is npz file
Expand All @@ -137,6 +149,8 @@ public static NDList decode(NDManager manager, InputStream is) {
&& magic[2] == 'U'
&& magic[3] == 'M') {
return new NDList(NDSerializer.decode(manager, pis));
} else if (magic[8] == '{') {
return decodeSafetensors(manager, pis);
}

dis = new DataInputStream(pis);
Expand All @@ -154,6 +168,54 @@ public static NDList decode(NDManager manager, InputStream is) {
}
}

private static NDList decodeSafetensors(NDManager manager, InputStream is) throws IOException {
DataInputStream dis;
if (is instanceof DataInputStream) {
dis = (DataInputStream) is;
} else {
dis = new DataInputStream(is);
}

byte[] buf = new byte[8];
dis.readFully(buf);
int len = Math.toIntExact(ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).getLong());
buf = new byte[len];
dis.readFully(buf);
String json = new String(buf, StandardCharsets.UTF_8);
// rust implementation sort by name, our implementation preserve the order.
JsonObject jsonObject = JsonUtils.GSON.fromJson(json, JsonObject.class);
List<Pair<String, SafeTensor>> list = new ArrayList<>();
int max = 0;
for (String key : jsonObject.keySet()) {
if ("__metadata__".equals(key)) {
continue;
}
SafeTensor value = JsonUtils.GSON.fromJson(jsonObject.get(key), SafeTensor.class);
if (value.offsets.length != 2) {
throw new IOException("Malformed safetensors metadata: " + json);
}
max = Math.max(max, value.offsets[1]);
list.add(new Pair<>(key, value));
}
buf = new byte[max];
dis.readFully(buf);
NDList ret = new NDList(list.size());
for (Pair<String, SafeTensor> pair : list) {
if ("__metadata__".equals(pair.getKey())) {
continue;
}
SafeTensor st = pair.getValue();
Shape shape = new Shape(st.shape);
ByteBuffer bb = ByteBuffer.wrap(buf, st.offsets[0], st.size());
bb.order(ByteOrder.LITTLE_ENDIAN);
DataType dataType = DataType.fromSafetensors(st.dtype);
NDArray array = manager.create(bb, shape, dataType);
array.setName(pair.getKey());
ret.add(array);
}
return ret;
}

private static NDList decodeNumpy(NDManager manager, InputStream is) throws IOException {
NDList list = new NDList();
ZipInputStream zis = new ZipInputStream(is);
Expand Down Expand Up @@ -340,18 +402,18 @@ public void detach() {
* @return the byte array
*/
public byte[] encode() {
return encode(false);
return encode(Encoding.ND_LIST);
}

/**
* Encodes the NDList to byte array.
*
* @param numpy encode in npz format if true
* @param encoding encode mode, one of ndlist/npz/safetensor format
* @return the byte array
*/
public byte[] encode(boolean numpy) {
public byte[] encode(Encoding encoding) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
encode(baos, numpy);
encode(baos, encoding);
return baos.toByteArray();
} catch (IOException e) {
throw new AssertionError("NDList is not writable", e);
Expand All @@ -365,18 +427,18 @@ public byte[] encode(boolean numpy) {
* @throws IOException if failed on IO operation
*/
public void encode(OutputStream os) throws IOException {
encode(os, false);
encode(os, Encoding.ND_LIST);
}

/**
* Writes the encoded NDList to {@code OutputStream}.
*
* @param os the {@code OutputStream} to be written to
* @param numpy encode in npz format if true
* @param encoding encode mode, one of ndlist/npz/safetensor format
* @throws IOException if failed on IO operation
*/
public void encode(OutputStream os, boolean numpy) throws IOException {
if (numpy) {
public void encode(OutputStream os, Encoding encoding) throws IOException {
if (encoding == Encoding.NPZ) {
ZipOutputStream zos = new ZipOutputStream(os);
int i = 0;
for (NDArray nd : this) {
Expand All @@ -392,6 +454,36 @@ public void encode(OutputStream os, boolean numpy) throws IOException {
zos.finish();
zos.flush();
return;
} else if (encoding == Encoding.SAFETENSORS) {
Map<String, SafeTensor> map = new ConcurrentHashMap<>(size());
int i = 0;
int offset = 0;
for (NDArray nd : this) {
String name = nd.getName();
if (name == null) {
name = "arr_" + i;
++i;
}
SafeTensor st = new SafeTensor();
st.dtype = nd.getDataType().asSafetensors();
st.shape = nd.getShape().getShape();
long size = nd.getDataType().getNumOfBytes() * nd.size();
int limit = offset + Math.toIntExact(size);
st.offsets = new int[] {offset, limit};
map.put(name, st);
offset = limit;
}
byte[] json = JsonUtils.GSON.toJson(map).getBytes(StandardCharsets.UTF_8);

ByteBuffer buf = ByteBuffer.allocate(8);
buf.order(ByteOrder.LITTLE_ENDIAN);
buf.putLong(0, json.length);
os.write(buf.array());
os.write(json);
for (NDArray nd : this) {
os.write(nd.toByteArray());
}
return;
}

DataOutputStream dos = new DataOutputStream(os);
Expand Down Expand Up @@ -450,4 +542,23 @@ public String toString() {
}
return builder.toString();
}

/** An enum represents NDList serialization format. */
public enum Encoding {
ND_LIST,
NPZ,
SAFETENSORS
}

private static final class SafeTensor {
String dtype;
long[] shape;

@SerializedName("data_offsets")
int[] offsets;

int size() {
return offsets[1] - offsets[0];
}
}
}
68 changes: 68 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/DataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ public static DataType fromNumpy(String dtype) {
}
}

/**
* Returns the data type from Safetensors value.
*
* @param dtype the Safetensors datatype
* @return the data type
*/
public static DataType fromSafetensors(String dtype) {
switch (dtype) {
case "F64":
return FLOAT64;
case "F32":
return FLOAT32;
case "F16":
return FLOAT16;
case "BF16":
return BFLOAT16;
case "I64":
return INT64;
case "I32":
return INT32;
case "I8":
return INT8;
case "U8":
return UINT8;
case "BOOL":
return BOOLEAN;
default:
throw new IllegalArgumentException("Unsupported safetensors dataType: " + dtype);
}
}

/**
* Converts a {@link ByteBuffer} to a buffer for this data type.
*
Expand Down Expand Up @@ -260,6 +291,43 @@ public String asNumpy() {
}
}

/**
* Returns a safetensors string value.
*
* @return a safetensors string value
*/
public String asSafetensors() {
switch (this) {
case FLOAT64:
return "F64";
case FLOAT32:
return "F32";
case FLOAT16:
return "F16";
case BFLOAT16:
return "BF16";
case INT64:
return "I64";
case INT32:
return "I32";
case INT8:
return "I8";
case UINT8:
return "U8";
case BOOLEAN:
return "BOOL";
case INT16:
case UINT64:
case UINT32:
case UINT16:
case STRING:
case COMPLEX64:
case UNKNOWN:
default:
throw new IllegalArgumentException("Unsupported dataType: " + this);
}
}

/** {@inheritDoc} */
@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
Output output = new Output();
if ("tensor/npz".equalsIgnoreCase(accept)
|| "tensor/npz".equalsIgnoreCase(contentType)) {
output.add(list.encode(true));
output.add(list.encode(NDList.Encoding.NPZ));
output.addProperty("Content-Type", "tensor/npz");
} else if ("tensor/safetensors".equalsIgnoreCase(accept)
|| "tensor/safetensors".equalsIgnoreCase(contentType)) {
output.add(list.encode(NDList.Encoding.SAFETENSORS));
output.addProperty("Content-Type", "tensor/safetensors");
} else {
output.add(list.encode(false));
output.add(list.encode());
output.addProperty("Content-Type", "tensor/ndlist");
}
return output;
Expand Down
17 changes: 16 additions & 1 deletion api/src/test/java/ai/djl/ndarray/NDListTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,25 @@ public void testNumpy() throws IOException {
NDList decoded = NDList.decode(manager, data);

ByteArrayOutputStream bos = new ByteArrayOutputStream(data.length + 1);
decoded.encode(bos, true);
decoded.encode(bos, NDList.Encoding.NPZ);
NDList list = NDList.decode(manager, bos.toByteArray());
Assert.assertEquals(list.size(), 2);
Assert.assertEquals(list.get(0).getName(), "bool8");
}
}

@Test
public void testSafetensors() throws IOException {
try (NDManager manager = NDManager.newBaseManager(Device.cpu())) {
byte[] data = NDSerializerTest.readFile("list.safetensors");
NDList decoded = NDList.decode(manager, data);

ByteArrayOutputStream bos = new ByteArrayOutputStream(data.length + 1);
decoded.encode(bos, NDList.Encoding.SAFETENSORS);
NDList list = NDList.decode(manager, bos.toByteArray());
Assert.assertEquals(list.size(), 2);
Assert.assertEquals(list.get(0).getName(), "attention");
Assert.assertEquals(list.get(0).toByteArray(), new byte[] {0, 1, 2, 3, 4, 5});
}
}
}
31 changes: 31 additions & 0 deletions api/src/test/java/ai/djl/ndarray/types/DataTypeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,35 @@ public void numpyTest() {
Assert.expectThrows(IllegalArgumentException.class, DataType.UNKNOWN::asNumpy);
Assert.expectThrows(IllegalArgumentException.class, () -> DataType.fromNumpy("|i8"));
}

@Test
public void safetensorsTest() {
Assert.assertEquals(DataType.FLOAT64.asSafetensors(), "F64");
Assert.assertEquals(DataType.FLOAT32.asSafetensors(), "F32");
Assert.assertEquals(DataType.FLOAT16.asSafetensors(), "F16");
Assert.assertEquals(DataType.BFLOAT16.asSafetensors(), "BF16");
Assert.assertEquals(DataType.INT64.asSafetensors(), "I64");
Assert.assertEquals(DataType.INT32.asSafetensors(), "I32");
Assert.assertEquals(DataType.INT8.asSafetensors(), "I8");
Assert.assertEquals(DataType.UINT8.asSafetensors(), "U8");
Assert.assertEquals(DataType.BOOLEAN.asSafetensors(), "BOOL");

Assert.assertEquals(DataType.fromSafetensors("F64"), DataType.FLOAT64);
Assert.assertEquals(DataType.fromSafetensors("F32"), DataType.FLOAT32);
Assert.assertEquals(DataType.fromSafetensors("F16"), DataType.FLOAT16);
Assert.assertEquals(DataType.fromSafetensors("BF16"), DataType.BFLOAT16);
Assert.assertEquals(DataType.fromSafetensors("I64"), DataType.INT64);
Assert.assertEquals(DataType.fromSafetensors("I32"), DataType.INT32);
Assert.assertEquals(DataType.fromSafetensors("I8"), DataType.INT8);
Assert.assertEquals(DataType.fromSafetensors("U8"), DataType.UINT8);
Assert.assertEquals(DataType.fromSafetensors("BOOL"), DataType.BOOLEAN);

Assert.expectThrows(IllegalArgumentException.class, DataType.UINT64::asSafetensors);
Assert.expectThrows(IllegalArgumentException.class, DataType.UINT32::asSafetensors);
Assert.expectThrows(IllegalArgumentException.class, DataType.UINT16::asSafetensors);
Assert.expectThrows(IllegalArgumentException.class, DataType.COMPLEX64::asSafetensors);
Assert.expectThrows(IllegalArgumentException.class, DataType.STRING::asSafetensors);
Assert.expectThrows(IllegalArgumentException.class, DataType.UNKNOWN::asSafetensors);
Assert.expectThrows(IllegalArgumentException.class, () -> DataType.fromSafetensors("U16"));
}
}
Loading

0 comments on commit d8b8b71

Please sign in to comment.