Mutable module improvement #3394
base: main
Conversation
Do you have the persistent cache example?
I talked to Boris, and it seems like save and load is what he is looking for. I am adding an engine caching example to MutableTorchTensorRTModule as well.
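For reference, here is a minimal sketch of the save/load flow being discussed. The static `save`/`load` helpers, their signatures, and the constructor kwargs are assumptions based on this thread, not a confirmed API:

```python
import torch
import torch_tensorrt as torch_trt

# Sketch only: assumes MutableTorchTensorRTModule wraps an eager module and
# exposes static save/load helpers that take a file path. Saving is discussed
# below as requiring the C++ runtime (use_python_runtime=False).
model = torch.nn.Linear(8, 8).eval().cuda()
mutable_module = torch_trt.MutableTorchTensorRTModule(model, use_python_runtime=False)
mutable_module(torch.randn(2, 8).cuda())  # first call triggers compilation

torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reloaded = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
```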
Mostly LGTM. Added minor comments.
@@ -63,16 +65,14 @@
# Saving Mutable Torch TensorRT Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Currently, saving is only enabled for C++ runtime, not python runtime.
# Currently, saving is only when "use_python" = False in settings
Do you mean use_python_runtime=False?
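If the intended flag is indeed `use_python_runtime`, the doc comment and settings might read as follows (a sketch of the suggested wording, not the final diff):

```python
# Currently, saving is only supported when the module is compiled with
# use_python_runtime=False (i.e. the C++ runtime) in its settings.
settings = {
    "use_python_runtime": False,  # required for save/load
}
```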
pipe.to(device)

# The only extra line you need
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
BATCH = torch.export.Dim("BATCH", min=1 * 2, max=12 * 2)
Any reason why it is written as 1 * 2 and 12 * 2 instead of 2 and 24?
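For comparison, both spellings declare the same bounds; the reviewer suggests the direct form:

```python
import torch

BATCH = torch.export.Dim("BATCH", min=1 * 2, max=12 * 2)  # as written in the PR
BATCH = torch.export.Dim("BATCH", min=2, max=24)          # equivalent direct form
```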
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = True
Is this necessary?
    kwargs_dynamic_shape: dict[str, Any],
) -> None:
    """
    Set the dynamic shape range. The shape hint should EXACTLY follow arg_inputs and kwarg_inputs passed to the forward function.
Can you add this link to reference torch.export's convention: https://pytorch.org/docs/stable/export.html#expressing-dynamism ?
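As a hedged illustration of that convention (the method name and input names here are assumptions for the sketch): the shape hints mirror the nesting of the `forward()` inputs, with `{}` marking a static tensor and `{dim_index: Dim}` marking a dynamic dimension, per https://pytorch.org/docs/stable/export.html#expressing-dynamism.

```python
import torch

batch = torch.export.Dim("batch", min=1, max=8)

# For a hypothetical forward(x, mask=None): one positional tensor and one
# keyword tensor. {} = fully static, {0: batch} = dim 0 is dynamic.
args_dynamic_shape = ({0: batch},)
kwargs_dynamic_shape = {"mask": {0: batch}}

# Assumed method name, reusing mutable_module from the sketch above.
mutable_module.set_expected_dynamic_shape_range(args_dynamic_shape, kwargs_dynamic_shape)
```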
""" | ||
assert isinstance( | ||
args_dynamic_shape, tuple | ||
), "args dynamic shape has to be a tuple" |
Can you add ", but the provided type is {type(args_dynamic_shape)}"?
), "args dynamic shape has to be a tuple" | ||
assert isinstance( | ||
kwargs_dynamic_shape, dict | ||
), "args dynamic shape has to be a dictionary" |
Can you add ", but the provided type is {type(kwargs_dynamic_shape)}"?
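Applying both suggestions (and aligning the second message with the kwargs check), the assertions would read:

```python
assert isinstance(
    args_dynamic_shape, tuple
), f"args dynamic shape has to be a tuple, but the provided type is {type(args_dynamic_shape)}"
assert isinstance(
    kwargs_dynamic_shape, dict
), f"kwargs dynamic shape has to be a dictionary, but the provided type is {type(kwargs_dynamic_shape)}"
```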
dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}}
assertions.assertFalse(
    torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b),
    msg=f"test_check_output_equal is not correct.",
test_check_input_shape_dynamic
)
assertions.assertTrue(
    torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b, dynamic_shape),
    msg=f"test_check_output_equal is not correct.",
test_check_input_shape_dynamic
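With the naming fix applied to both messages (and the unnecessary f-prefix dropped), the assertions would read:

```python
assertions.assertFalse(
    torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b),
    msg="test_check_input_shape_dynamic is not correct.",
)
assertions.assertTrue(
    torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b, dynamic_shape),
    msg="test_check_input_shape_dynamic is not correct.",
)
```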
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: