Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pytorch] Adds IValue Dict(str, IValue) support #1765

Merged
merged 1 commit into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,24 @@ public static IValue stringMapFrom(Map<String, PtNDArray> map) {
return new IValue(PyTorchLibrary.LIB.iValueFromStringMap(keys, handles));
}

/**
* Creates a new {@code IValue} of type {@code Map[String, IValue]}.
*
* @param map the Map[String, IValue] value
* @return a new {@code IValue} of type {@code Map[String, IValue]}
*/
public static IValue stringIValueMapFrom(Map<String, IValue> map) {
String[] keys = new String[map.size()];
long[] handles = new long[map.size()];
int i = 0;
for (Map.Entry<String, IValue> entry : map.entrySet()) {
keys[i] = entry.getKey();
handles[i] = entry.getValue().getHandle();
++i;
}
return new IValue(PyTorchLibrary.LIB.iValueFromStringIValueMap(keys, handles));
}

/**
* Returns the {@code boolean} value of this IValue.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,8 @@ native long moduleLoad(

native long iValueFromStringMap(String[] keys, long[] tensorHandles);

native long iValueFromStringIValueMap(String[] keys, long[] tensorHandles);

native long iValueToTensor(long iValueHandle);

native boolean iValueToBool(long iValueHandle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ public void testIValue() {
Assert.assertEquals(list.get("data1"), array1);
}

// (Dict(str, Tensor[])
Map<String, IValue> iValueMap = new ConcurrentHashMap<>();
try (IValue v1 = IValue.listFrom(array1);
IValue v2 = IValue.listFrom(array2)) {
iValueMap.put("data1", v1);
iValueMap.put("data2", v2);
try (IValue ivalue = IValue.stringIValueMapFrom(iValueMap)) {
Assert.assertTrue(ivalue.isMap());
Assert.assertEquals(ivalue.getType(), "Dict(str, Tensor[])");
}
}

try (IValue iv1 = IValue.from(1);
IValue iv2 = IValue.from(2);
IValue ivalue = IValue.listFrom(iv1, iv2)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,30 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromStringM
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromStringIValueMap(
JNIEnv* env, jobject jthis, jobjectArray jkeys, jlongArray jvalues) {
API_BEGIN()
auto len = static_cast<size_t>(env->GetArrayLength(jvalues));
jlong* jptrs = env->GetLongArrayElements(jvalues, JNI_FALSE);
if (len == 0) {
const auto* ivalue_ptr = new torch::IValue{c10::impl::GenericDict(c10::StringType::get(), c10::TensorType::get())};
return reinterpret_cast<uintptr_t>(ivalue_ptr);
}

auto* firstEntryValue = reinterpret_cast<torch::IValue*>(jptrs[0]);
c10::impl::GenericDict dict(c10::StringType::get(), c10::unshapedType(firstEntryValue->type()));
for (size_t i = 0; i < len; ++i) {
auto jname = (jstring) env->GetObjectArrayElement(jkeys, i);
std::string name = djl::utils::jni::GetStringFromJString(env, jname);
dict.insert(name, *reinterpret_cast<torch::IValue*>(jptrs[i]));
}
env->ReleaseLongArrayElements(jvalues, jptrs, JNI_ABORT);
env->DeleteLocalRef(jkeys);
const auto* ivalue_ptr = new torch::IValue{dict};
return reinterpret_cast<uintptr_t>(ivalue_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensor(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
Expand Down