Spawn TF saving/serializing in a new process to avoid a locked GPU
ashao committed Dec 1, 2023
1 parent c68b2af commit 762db80
Showing 2 changed files with 96 additions and 70 deletions.
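
The core of the change is a standard multiprocessing hand-off: the TensorFlow work (building, freezing, or serializing a model) runs in a child process, the result comes back over a Pipe, and the parent joins the child, so the parent process never initializes TensorFlow and therefore never locks the GPU. A minimal sketch of that pattern, using a placeholder worker and payload rather than the functions added in this commit:

import multiprocessing as mp
from multiprocessing.connection import Connection


def _worker(conn: Connection) -> None:
    # Hypothetical worker: anything that would initialize TensorFlow
    # (and thus allocate GPU memory) belongs here, in the child process.
    payload = b"frozen-graph-bytes"  # placeholder result
    conn.send(payload)
    conn.close()


if __name__ == "__main__":
    parent_conn, child_conn = mp.Pipe()
    proc = mp.Process(target=_worker, args=(child_conn,))
    proc.start()
    result = parent_conn.recv()  # blocks until the child sends its result
    proc.join()
    print(result)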
63 changes: 25 additions & 38 deletions smartsim/_core/_cli/validate.py
@@ -52,10 +52,6 @@


if t.TYPE_CHECKING:
# Pylint disables needed for old version of pylint w/ TF 2.6.2
# pylint: disable-next=unused-import
from multiprocessing.connection import Connection

# pylint: disable-next=unsubscriptable-object
_TemporaryDirectory = tempfile.TemporaryDirectory[str]
else:
@@ -86,16 +82,23 @@ def execute(args: argparse.Namespace, /) -> int:
"""Validate the SmartSim installation works as expected given a
simple experiment
"""
from importlib.util import find_spec

torch_available = find_spec("torch")
tensorflow_available = find_spec("tensorflow")
onnx_available = find_spec("skl2onnx") and find_spec("sklearn")

backends = installed_redisai_backends()
has_tf = False
try:
with _VerificationTempDir(dir=os.getcwd()) as temp_dir:
test_install(
location=temp_dir,
port=args.port,
device=args.device.upper(),
with_tf="tensorflow" in backends,
with_pt="torch" in backends,
with_onnx="onnxruntime" in backends,
with_tf="tensorflow" in backends and torch_available,
with_pt="torch" in backends and tensorflow_available,
with_onnx="onnxruntime" in backends and onnx_available,
)
except Exception as e:
logger.error(
@@ -146,12 +149,18 @@ def test_install(
if with_tf:
logger.info("Verifying TensorFlow Backend")
_test_tf_install(client, location, device)
else:
logger.warning("Tensorflow not available. Skipping test")
if with_pt:
logger.info("Verifying Torch Backend")
_test_torch_install(client, device)
else:
logger.warning("Torch not available. Skipping test")
if with_onnx:
logger.info("Verifying ONNX Backend")
_test_onnx_install(client, device)
else:
logger.warning("ONNX not available. Skipping test")


@contextmanager
@@ -178,39 +187,10 @@ def _find_free_port() -> int:


def _test_tf_install(client: Client, tmp_dir: str, device: _TCapitalDeviceStr) -> None:
recv_conn, send_conn = mp.Pipe(duplex=False)
# Build the model in a subproc so that keras does not hog the gpu
proc = mp.Process(target=_build_tf_frozen_model, args=(send_conn, tmp_dir))
proc.start()

# do not need the sending connection in this proc anymore
send_conn.close()

proc.join(timeout=120)
if proc.is_alive():
proc.terminate()
raise Exception("Failed to build a simple keras model within 2 minutes")
try:
model_path, inputs, outputs = recv_conn.recv()
except EOFError as e:
raise Exception(
"Failed to receive serialized model from subprocess. "
"Is the `tensorflow` python package installed?"
) from e

client.set_model_from_file(
"keras-fcn", model_path, "TF", device=device, inputs=inputs, outputs=outputs
)
client.put_tensor("keras-input", np.random.rand(1, 28, 28).astype(np.float32))
client.run_model("keras-fcn", inputs=["keras-input"], outputs=["keras-output"])
client.get_tensor("keras-output")


def _build_tf_frozen_model(conn: "Connection", tmp_dir: str) -> None:
from tensorflow import keras

from smartsim.ml.tf import freeze_model

# Build a small TF model and freeze it
fcn = keras.Sequential(
layers=[
keras.layers.InputLayer(input_shape=(28, 28), name="input"),
@@ -224,7 +204,14 @@ def _build_tf_frozen_model(conn: "Connection", tmp_dir: str) -> None:
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
model_path, inputs, outputs = freeze_model(fcn, tmp_dir, "keras_model.pb")
conn.send((model_path, inputs, outputs))

# Try to set the model and use it
client.set_model_from_file(
"keras-fcn", model_path, "TF", device=device, inputs=inputs, outputs=outputs
)
client.put_tensor("keras-input", np.random.rand(1, 28, 28).astype(np.float32))
client.run_model("keras-fcn", inputs=["keras-input"], outputs=["keras-output"])
client.get_tensor("keras-output")


def _test_torch_install(client: Client, device: _TCapitalDeviceStr) -> None:
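
The availability flags introduced in execute() rely on importlib.util.find_spec, which returns a ModuleSpec when a package can be imported and None when it cannot, so each with_* flag should pair with its own framework. A small illustrative check, assuming an explicit "is not None" coercion (not necessarily the exact form the commit ships):

from importlib.util import find_spec

# find_spec returns None for packages that are not installed,
# so "is not None" yields a clean boolean availability flag.
torch_available = find_spec("torch") is not None
tensorflow_available = find_spec("tensorflow") is not None
onnx_available = find_spec("skl2onnx") is not None and find_spec("sklearn") is not None

print(f"torch={torch_available} tensorflow={tensorflow_available} onnx={onnx_available}")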
103 changes: 71 additions & 32 deletions smartsim/ml/tf/utils.py
@@ -24,6 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import annotations
from pathlib import Path

import tensorflow as tf
@@ -32,35 +33,43 @@
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2,
)
import multiprocessing as mp

if t.TYPE_CHECKING:
from multiprocessing.connection import Connection

def freeze_model(
model: keras.Model, output_dir: str, file_name: str
) -> t.Tuple[str, t.List[str], t.List[str]]:
"""Freeze a Keras or TensorFlow Graph
def _serialize_internals(connection: "Connection", model: keras.Model) -> None:

to use a Keras or TensorFlow model in SmartSim, the model
must be frozen and the inputs and outputs provided to the
smartredis.client.set_model_from_file() method.
full_model = tf.function(model)
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)

This utility function provides everything users need to take
a trained model and put it inside an ``orchestrator`` instance
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

:param model: TensorFlow or Keras model
:type model: tf.Module
:param output_dir: output dir to save model file to
:type output_dir: str
:param file_name: name of model file to create
:type file_name: str
:return: path to model file, model input layer names, model output layer names
:rtype: str, list[str], list[str]
input_names = [x.name.split(":")[0] for x in frozen_func.inputs]
output_names = [x.name.split(":")[0] for x in frozen_func.outputs]

model_serialized = frozen_func.graph.as_graph_def().SerializeToString(
deterministic=True
)

connection.send((model_serialized, input_names, output_names))
connection.close()

def _freeze_internals(
connection: "Connection", model: keras.Model, output_dir: str, file_name: str
) -> None:
"""
Needed to run the freezing in a separate process
to avoid locking up the GPU
"""
# TODO figure out why layer names don't match up to
# specified name in Model init.

if not file_name.endswith(".pb"):
file_name = file_name + ".pb"


full_model = tf.function(model)
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
@@ -79,6 +88,42 @@ def freeze_model(
as_text=False,
)
model_file_path = str(Path(output_dir, file_name).resolve())
connection.send((model_file_path, input_names, output_names))
connection.close()

def freeze_model(
model: keras.Model, output_dir: str, file_name: str
) -> t.Tuple[str, t.List[str], t.List[str]]:
"""Freeze a Keras or TensorFlow Graph
to use a Keras or TensorFlow model in SmartSim, the model
must be frozen and the inputs and outputs provided to the
smartredis.client.set_model_from_file() method.
This utiliy function provides everything users need to take
a trained model and put it inside an ``orchestrator`` instance
:param model: TensorFlow or Keras model
:type model: tf.Module
:param output_dir: output dir to save model file to
:type output_dir: str
:param file_name: name of model file to create
:type file_name: str
:return: path to model file, model input layer names, model output layer names
:rtype: str, list[str], list[str]
"""
# TODO figure out why layer names don't match up to
# specified name in Model init.


parent_connection, child_connection = mp.Pipe()
graph_freeze_process = mp.Process(
target=_freeze_internals,
args=(child_connection, model, output_dir, file_name)
)
graph_freeze_process.start()
model_file_path, input_names, output_names = parent_connection.recv()
graph_freeze_process.join()
return model_file_path, input_names, output_names


Expand All @@ -98,19 +143,13 @@ def serialize_model(model: keras.Model) -> t.Tuple[str, t.List[str], t.List[str]
:rtype: str, list[str], list[str]
"""

full_model = tf.function(model)
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)

frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

input_names = [x.name.split(":")[0] for x in frozen_func.inputs]
output_names = [x.name.split(":")[0] for x in frozen_func.outputs]

model_serialized = frozen_func.graph.as_graph_def().SerializeToString(
deterministic=True
parent_connection, child_connection = mp.Pipe()
graph_freeze_process = mp.Process(
target=_serialize_internals,
args=(child_connection, model)
)

graph_freeze_process.start()
model_serialized, input_names, output_names = parent_connection.recv()
graph_freeze_process.join()
return model_serialized, input_names, output_names
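
With the process spawning folded into freeze_model and serialize_model, callers keep the same API as before this commit: freeze the model, then hand the returned path and layer names to SmartRedis. A hedged usage sketch, in which the model architecture, output directory, and Client address are placeholders rather than values from this commit:

from smartredis import Client
from tensorflow import keras

from smartsim.ml.tf import freeze_model

# Placeholder model; any compiled keras.Model works the same way.
model = keras.Sequential(
    layers=[
        keras.layers.InputLayer(input_shape=(28, 28), name="input"),
        keras.layers.Flatten(name="flatten"),
        keras.layers.Dense(10, activation="softmax", name="output"),
    ]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# freeze_model now does its TensorFlow work in a child process and returns
# the frozen-graph path plus the input/output layer names expected by
# set_model_from_file.
model_path, inputs, outputs = freeze_model(model, "/tmp", "keras_model.pb")

client = Client(address="127.0.0.1:6780", cluster=False)  # placeholder address
client.set_model_from_file(
    "keras-fcn", model_path, "TF", device="CPU", inputs=inputs, outputs=outputs
)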
