Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[tensorflow] Fixes tensorflow session always on gpu(0) bug #1558

Merged
merged 1 commit into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
} catch (InvalidProtocolBufferException e) {
throw new MalformedModelException("Invalid ConfigProto: " + config, e);
}
} else {
// default one
configProto = JavacppUtils.getSessionConfig();
}
Object run = options.get("RunOptions");
if (run instanceof RunOptions) {
Expand All @@ -113,6 +110,10 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
if (tags == null) {
tags = new String[] {"serve"};
}
if (configProto == null) {
// default one
configProto = JavacppUtils.getSessionConfig();
}

SavedModelBundle bundle =
JavacppUtils.loadSavedModelBundle(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
package ai.djl.tensorflow.engine;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
import ai.djl.util.Preconditions;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
Expand Down Expand Up @@ -50,6 +50,7 @@ final class TfOpExecutor implements AutoCloseable {
// keep the native pointer alive outside of the scope
opHandle.retainReference();
}
setDevice(manager.getDevice());
}

public NDArray[] build(int numOutputs) {
Expand Down Expand Up @@ -133,15 +134,8 @@ public TfOpExecutor addInputList(NDArray[] inputs) {

@SuppressWarnings({"unchecked", "try"})
public TfOpExecutor setDevice(Device device) {
String deviceStr;
try (PointerScope ignore = new PointerScope()) {
if (device.getDeviceType().equals(Device.Type.CPU)) {
deviceStr = "/device:CPU:0";
} else if (device.getDeviceType().equals(Device.Type.GPU)) {
deviceStr = "/device:GPU:" + device.getDeviceId();
} else {
throw new EngineException("Unknown device type to TensorFlow Engine: " + device);
}
String deviceStr = JavacppUtils.toTfDevice(device);
TF_Status status = TF_Status.newStatus();
tensorflow.TFE_OpSetDevice(opHandle, deviceStr, status);
status.throwExceptionIfNotOK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.djl.tensorflow.engine.SavedModelBundle;
import ai.djl.tensorflow.engine.TfDataType;
import ai.djl.util.Pair;
import ai.djl.util.cuda.CudaUtils;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
Expand Down Expand Up @@ -438,8 +439,20 @@ public static ConfigProto getSessionConfig() {
if (intraop != null) {
configBuilder.setIntraOpParallelismThreads(intraop);
}
GPUOptions gpuOptions = GPUOptions.newBuilder().setVisibleDeviceList("0").build();
configBuilder.setGpuOptions(gpuOptions);
int gpuCount = CudaUtils.getGpuCount();
if (gpuCount > 0) {
StringBuilder sb = new StringBuilder("0");
for (int i = 1; i < gpuCount; ++i) {
sb.append(',').append(i);
}
GPUOptions gpuOptions =
GPUOptions.newBuilder().setVisibleDeviceList(sb.toString()).build();
configBuilder.setGpuOptions(gpuOptions);
configBuilder.setAllowSoftPlacement(true);
if (Boolean.getBoolean("ai.djl.tensorflow.debug")) {
configBuilder.setLogDevicePlacement(true);
}
}
return configBuilder.build();
}

Expand Down