Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Adds Safetensors support #2763

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 122 additions & 11 deletions api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;

import com.google.gson.JsonObject;
import com.google.gson.annotations.SerializedName;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -25,10 +31,14 @@
import java.io.PushbackInputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
Expand Down Expand Up @@ -84,7 +94,7 @@ public NDList(Collection<NDArray> other) {
* @return {@code NDList}
*/
public static NDList decode(NDManager manager, byte[] byteArray) {
if (byteArray.length < 4) {
if (byteArray.length < 9) {
throw new IllegalArgumentException("Invalid input length: " + byteArray.length);
}
try {
Expand All @@ -96,6 +106,8 @@ public static NDList decode(NDManager manager, byte[] byteArray) {
&& byteArray[3] == 'M') {
return new NDList(
NDSerializer.decode(manager, new ByteArrayInputStream(byteArray)));
} else if (byteArray[8] == '{') {
return decodeSafetensors(manager, new ByteArrayInputStream(byteArray));
}

ByteBuffer bb = ByteBuffer.wrap(byteArray);
Expand Down Expand Up @@ -124,10 +136,10 @@ public static NDList decode(NDManager manager, byte[] byteArray) {
public static NDList decode(NDManager manager, InputStream is) {
try {
DataInputStream dis = new DataInputStream(is);
byte[] magic = new byte[4];
byte[] magic = new byte[9];
dis.readFully(magic);

PushbackInputStream pis = new PushbackInputStream(is, 4);
PushbackInputStream pis = new PushbackInputStream(is, 9);
pis.unread(magic);
if (magic[0] == 'P' && magic[1] == 'K') {
// assume this is npz file
Expand All @@ -137,6 +149,8 @@ public static NDList decode(NDManager manager, InputStream is) {
&& magic[2] == 'U'
&& magic[3] == 'M') {
return new NDList(NDSerializer.decode(manager, pis));
} else if (magic[8] == '{') {
return decodeSafetensors(manager, pis);
}

dis = new DataInputStream(pis);
Expand All @@ -154,6 +168,54 @@ public static NDList decode(NDManager manager, InputStream is) {
}
}

/**
 * Decodes a safetensors payload into an {@code NDList}.
 *
 * <p>The stream is read as an 8-byte little-endian header length, followed by a UTF-8 JSON
 * header of that length (tensor name mapped to {@code dtype}, {@code shape} and {@code
 * data_offsets}), followed by the raw tensor bytes. The {@code __metadata__} header entry is
 * ignored. Arrays are added in the order their names appear in the JSON header.
 *
 * @param manager the {@link NDManager} used to create the arrays
 * @param is the stream positioned at the start of the safetensors data
 * @return the decoded {@code NDList}; each array is named after its header key
 * @throws IOException if the stream ends prematurely or the header is malformed
 */
private static NDList decodeSafetensors(NDManager manager, InputStream is) throws IOException {
    DataInputStream dis =
            is instanceof DataInputStream ? (DataInputStream) is : new DataInputStream(is);

    byte[] buf = new byte[8];
    dis.readFully(buf);
    long headerLength = ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).getLong();
    // Reject lengths that cannot be used as a Java array size instead of failing with an
    // unrelated runtime exception.
    if (headerLength < 0 || headerLength > Integer.MAX_VALUE) {
        throw new IOException("Invalid safetensors header length: " + headerLength);
    }
    buf = new byte[(int) headerLength];
    dis.readFully(buf);
    String json = new String(buf, StandardCharsets.UTF_8);
    // The Rust implementation sorts tensors by name; this implementation preserves the order
    // in which they appear in the header.
    JsonObject jsonObject = JsonUtils.GSON.fromJson(json, JsonObject.class);
    if (jsonObject == null) {
        throw new IOException("Malformed safetensors metadata: " + json);
    }

    List<Pair<String, SafeTensor>> list = new ArrayList<>();
    int max = 0;
    for (String key : jsonObject.keySet()) {
        if ("__metadata__".equals(key)) {
            // free-form metadata, not a tensor description
            continue;
        }
        SafeTensor value = JsonUtils.GSON.fromJson(jsonObject.get(key), SafeTensor.class);
        if (value == null
                || value.offsets == null
                || value.offsets.length != 2
                || value.offsets[0] < 0
                || value.offsets[1] < value.offsets[0]) {
            throw new IOException("Malformed safetensors metadata: " + json);
        }
        // The data section must be at least as long as the largest end offset.
        max = Math.max(max, value.offsets[1]);
        list.add(new Pair<>(key, value));
    }

    buf = new byte[max];
    dis.readFully(buf);
    NDList ret = new NDList(list.size());
    for (Pair<String, SafeTensor> pair : list) {
        SafeTensor st = pair.getValue();
        Shape shape = new Shape(st.shape);
        // View over this tensor's slice of the data section; offsets were validated above.
        ByteBuffer bb = ByteBuffer.wrap(buf, st.offsets[0], st.size());
        bb.order(ByteOrder.LITTLE_ENDIAN);
        DataType dataType = DataType.fromSafetensors(st.dtype);
        NDArray array = manager.create(bb, shape, dataType);
        array.setName(pair.getKey());
        ret.add(array);
    }
    return ret;
}

private static NDList decodeNumpy(NDManager manager, InputStream is) throws IOException {
NDList list = new NDList();
ZipInputStream zis = new ZipInputStream(is);
Expand Down Expand Up @@ -340,18 +402,18 @@ public void detach() {
* @return the byte array
*/
public byte[] encode() {
return encode(false);
return encode(Encoding.ND_LIST);
}

/**
* Encodes the NDList to byte array.
*
* @param numpy encode in npz format if true
* @param encoding the encoding format, one of ND_LIST, NPZ or SAFETENSORS
* @return the byte array
*/
public byte[] encode(boolean numpy) {
public byte[] encode(Encoding encoding) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
encode(baos, numpy);
encode(baos, encoding);
return baos.toByteArray();
} catch (IOException e) {
throw new AssertionError("NDList is not writable", e);
Expand All @@ -365,18 +427,18 @@ public byte[] encode(boolean numpy) {
* @throws IOException if failed on IO operation
*/
public void encode(OutputStream os) throws IOException {
encode(os, false);
encode(os, Encoding.ND_LIST);
}

/**
* Writes the encoded NDList to {@code OutputStream}.
*
* @param os the {@code OutputStream} to be written to
* @param numpy encode in npz format if true
* @param encoding the encoding format, one of ND_LIST, NPZ or SAFETENSORS
* @throws IOException if failed on IO operation
*/
public void encode(OutputStream os, boolean numpy) throws IOException {
if (numpy) {
public void encode(OutputStream os, Encoding encoding) throws IOException {
if (encoding == Encoding.NPZ) {
ZipOutputStream zos = new ZipOutputStream(os);
int i = 0;
for (NDArray nd : this) {
Expand All @@ -392,6 +454,36 @@ public void encode(OutputStream os, boolean numpy) throws IOException {
zos.finish();
zos.flush();
return;
} else if (encoding == Encoding.SAFETENSORS) {
Map<String, SafeTensor> map = new ConcurrentHashMap<>(size());
int i = 0;
int offset = 0;
for (NDArray nd : this) {
String name = nd.getName();
if (name == null) {
name = "arr_" + i;
++i;
}
SafeTensor st = new SafeTensor();
st.dtype = nd.getDataType().asSafetensors();
st.shape = nd.getShape().getShape();
long size = nd.getDataType().getNumOfBytes() * nd.size();
int limit = offset + Math.toIntExact(size);
st.offsets = new int[] {offset, limit};
map.put(name, st);
offset = limit;
}
byte[] json = JsonUtils.GSON.toJson(map).getBytes(StandardCharsets.UTF_8);

ByteBuffer buf = ByteBuffer.allocate(8);
buf.order(ByteOrder.LITTLE_ENDIAN);
buf.putLong(0, json.length);
os.write(buf.array());
os.write(json);
for (NDArray nd : this) {
os.write(nd.toByteArray());
}
return;
}

DataOutputStream dos = new DataOutputStream(os);
Expand Down Expand Up @@ -450,4 +542,23 @@ public String toString() {
}
return builder.toString();
}

/** An enum that represents the serialization formats an NDList can be encoded to. */
public enum Encoding {
/** The default format, used by {@code encode()} and {@code encode(OutputStream)}. */
ND_LIST,
/** The npz format; the arrays are written through a {@code ZipOutputStream}. */
NPZ,
/**
 * The safetensors format: an 8-byte little-endian JSON header length, the JSON header, then
 * the bytes of each array.
 */
SAFETENSORS
}

/** JSON-mapped description of a single tensor entry in a safetensors header. */
private static final class SafeTensor {

    /** The safetensors dtype name, for example {@code "F32"}. */
    String dtype;

    /** The tensor dimensions. */
    long[] shape;

    /** The {@code data_offsets} pair: begin and end offsets into the data section. */
    @SerializedName("data_offsets")
    int[] offsets;

    /**
     * Returns the number of bytes covered by this entry.
     *
     * @return the end offset minus the begin offset
     */
    int size() {
        int begin = offsets[0];
        int end = offsets[1];
        return end - begin;
    }
}
}
68 changes: 68 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/DataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ public static DataType fromNumpy(String dtype) {
}
}

/**
 * Maps a safetensors dtype name to the corresponding {@code DataType}.
 *
 * @param dtype the safetensors dtype name, for example {@code "F32"}
 * @return the matching data type
 * @throws IllegalArgumentException if the dtype name is not one of the supported values
 */
public static DataType fromSafetensors(String dtype) {
    DataType type;
    switch (dtype) {
        case "F64":
            type = FLOAT64;
            break;
        case "F32":
            type = FLOAT32;
            break;
        case "F16":
            type = FLOAT16;
            break;
        case "BF16":
            type = BFLOAT16;
            break;
        case "I64":
            type = INT64;
            break;
        case "I32":
            type = INT32;
            break;
        case "I8":
            type = INT8;
            break;
        case "U8":
            type = UINT8;
            break;
        case "BOOL":
            type = BOOLEAN;
            break;
        default:
            throw new IllegalArgumentException("Unsupported safetensors dataType: " + dtype);
    }
    return type;
}

/**
* Converts a {@link ByteBuffer} to a buffer for this data type.
*
Expand Down Expand Up @@ -260,6 +291,43 @@ public String asNumpy() {
}
}

/**
 * Converts this data type to its safetensors dtype name.
 *
 * @return the safetensors dtype name, for example {@code "F32"} for {@code FLOAT32}
 * @throws IllegalArgumentException if this data type has no safetensors mapping
 */
public String asSafetensors() {
    String name;
    switch (this) {
        case FLOAT64:
            name = "F64";
            break;
        case FLOAT32:
            name = "F32";
            break;
        case FLOAT16:
            name = "F16";
            break;
        case BFLOAT16:
            name = "BF16";
            break;
        case INT64:
            name = "I64";
            break;
        case INT32:
            name = "I32";
            break;
        case INT8:
            name = "I8";
            break;
        case UINT8:
            name = "U8";
            break;
        case BOOLEAN:
            name = "BOOL";
            break;
        // The remaining types are listed explicitly; none of them is mapped.
        case INT16:
        case UINT64:
        case UINT32:
        case UINT16:
        case STRING:
        case COMPLEX64:
        case UNKNOWN:
        default:
            throw new IllegalArgumentException("Unsupported dataType: " + this);
    }
    return name;
}

/** {@inheritDoc} */
@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
Output output = new Output();
if ("tensor/npz".equalsIgnoreCase(accept)
|| "tensor/npz".equalsIgnoreCase(contentType)) {
output.add(list.encode(true));
output.add(list.encode(NDList.Encoding.NPZ));
output.addProperty("Content-Type", "tensor/npz");
} else if ("tensor/safetensors".equalsIgnoreCase(accept)
|| "tensor/safetensors".equalsIgnoreCase(contentType)) {
output.add(list.encode(NDList.Encoding.SAFETENSORS));
output.addProperty("Content-Type", "tensor/safetensors");
} else {
output.add(list.encode(false));
output.add(list.encode());
output.addProperty("Content-Type", "tensor/ndlist");
}
return output;
Expand Down
17 changes: 16 additions & 1 deletion api/src/test/java/ai/djl/ndarray/NDListTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,25 @@ public void testNumpy() throws IOException {
NDList decoded = NDList.decode(manager, data);

ByteArrayOutputStream bos = new ByteArrayOutputStream(data.length + 1);
decoded.encode(bos, true);
decoded.encode(bos, NDList.Encoding.NPZ);
NDList list = NDList.decode(manager, bos.toByteArray());
Assert.assertEquals(list.size(), 2);
Assert.assertEquals(list.get(0).getName(), "bool8");
}
}

/** Decodes a safetensors fixture, re-encodes it and verifies the round-tripped list. */
@Test
public void testSafetensors() throws IOException {
    try (NDManager manager = NDManager.newBaseManager(Device.cpu())) {
        byte[] fixture = NDSerializerTest.readFile("list.safetensors");
        NDList firstPass = NDList.decode(manager, fixture);

        // Write the decoded list back out as safetensors and decode it a second time.
        ByteArrayOutputStream out = new ByteArrayOutputStream(fixture.length + 1);
        firstPass.encode(out, NDList.Encoding.SAFETENSORS);
        byte[] reEncoded = out.toByteArray();
        NDList roundTripped = NDList.decode(manager, reEncoded);

        Assert.assertEquals(roundTripped.size(), 2);
        Assert.assertEquals(roundTripped.get(0).getName(), "attention");
        byte[] expected = {0, 1, 2, 3, 4, 5};
        Assert.assertEquals(roundTripped.get(0).toByteArray(), expected);
    }
}
}
31 changes: 31 additions & 0 deletions api/src/test/java/ai/djl/ndarray/types/DataTypeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,35 @@ public void numpyTest() {
Assert.expectThrows(IllegalArgumentException.class, DataType.UNKNOWN::asNumpy);
Assert.expectThrows(IllegalArgumentException.class, () -> DataType.fromNumpy("|i8"));
}

/** Verifies the DataType to safetensors dtype mapping in both directions. */
@Test
public void safetensorsTest() {
    // Parallel tables: supported[i] maps to names[i] and back.
    DataType[] supported = {
        DataType.FLOAT64,
        DataType.FLOAT32,
        DataType.FLOAT16,
        DataType.BFLOAT16,
        DataType.INT64,
        DataType.INT32,
        DataType.INT8,
        DataType.UINT8,
        DataType.BOOLEAN
    };
    String[] names = {"F64", "F32", "F16", "BF16", "I64", "I32", "I8", "U8", "BOOL"};

    for (int i = 0; i < supported.length; ++i) {
        Assert.assertEquals(supported[i].asSafetensors(), names[i]);
    }

    for (int i = 0; i < names.length; ++i) {
        Assert.assertEquals(DataType.fromSafetensors(names[i]), supported[i]);
    }

    // Types without a safetensors mapping must be rejected.
    DataType[] unsupported = {
        DataType.UINT64,
        DataType.UINT32,
        DataType.UINT16,
        DataType.COMPLEX64,
        DataType.STRING,
        DataType.UNKNOWN
    };
    for (DataType type : unsupported) {
        Assert.expectThrows(IllegalArgumentException.class, type::asSafetensors);
    }
    Assert.expectThrows(IllegalArgumentException.class, () -> DataType.fromSafetensors("U16"));
}
}
Loading