feat: Automatically generate QDP plugins #3370
Conversation
py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
# Use the helper function to get the required signatures
args_input, kwargs_input, plugin_signature, plugin_impl_signature, register_func_annotation, impl_func_annotation = generate_signature(torch_op)
print(args_input)
Make this debug info
@narendasan when I use:
_LOGGER.debug(f"Plugin registration function: \n{codegen_plugin}")
It won't print anything. How to resolve this?
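A minimal sketch, assuming `_LOGGER` is a standard Python `logging.Logger`, of one way to make debug-level messages visible (by default only WARNING and above are emitted, so `_LOGGER.debug(...)` appears to print nothing):

```python
import logging

# Raise the root logger level to DEBUG so debug-level messages are emitted.
logging.basicConfig(level=logging.DEBUG)

_LOGGER = logging.getLogger(__name__)
codegen_plugin = "..."  # placeholder for the generated plugin source
_LOGGER.debug(f"Plugin registration function: \n{codegen_plugin}")
```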
- Figure out if QDP will hold the reference to the plugin for us
- Add test cases for 1. (Tensor, Tensor) -> (Tensor) 2. (Tensor, int, float) -> (Tensor) 3. (Tensor, Tensor) -> (Tensor, Tensor) (a minimal sketch of the first case follows below)
- Rebase as well
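A minimal sketch of the first case, assuming a hypothetical custom op registered via torch.library (requires a recent PyTorch); the real tests would additionally compile a module calling the op with torch_tensorrt, exercising the auto-generated QDP plugin, and compare against eager mode:

```python
import torch

# Hypothetical (Tensor, Tensor) -> Tensor op; the names here are illustrative only.
@torch.library.custom_op("plugin_test::elementwise_mul", mutates_args=())
def elementwise_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a * b

@elementwise_mul.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(a)

def test_elementwise_mul() -> None:
    a, b = torch.randn(2, 3), torch.randn(2, 3)
    # The real test would compare the TensorRT-compiled output against eager mode.
    torch.testing.assert_close(torch.ops.plugin_test.elementwise_mul(a, b), a * b)
```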
    capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
    priority: ConverterPriority = ConverterPriority.STANDARD,
    supports_dynamic_shapes: bool = False,
):
Add a docstring as this is a user API
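For illustration, a hedged sketch of such a docstring; the function name is an assumption, the parameter semantics are guessed from the snippet above, and placeholder types are used to keep the sketch self-contained:

```python
from typing import Any, Callable, Optional


def generate_plugin_converter(  # function name assumed for illustration
    plugin_name: str,
    capability_validator: Optional[Callable[..., bool]] = None,
    priority: Any = None,  # stands in for ConverterPriority.STANDARD
    supports_dynamic_shapes: bool = False,
) -> None:
    """Generate and register a TensorRT converter for an auto-generated QDP plugin.

    Args:
        plugin_name: Fully qualified custom op name, e.g. "my_namespace::my_op".
        capability_validator: Optional predicate deciding whether a given FX node
            can be handled by this converter under the current compilation settings.
        priority: Priority of this converter relative to other registered converters.
        supports_dynamic_shapes: Whether the generated converter supports dynamic shapes.
    """
```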
)


def generate_plugin(plugin_name: str):
Add a docstring
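A possible docstring for generate_plugin above; the wording is a guess based on the surrounding snippets:

```python
def generate_plugin(plugin_name: str) -> None:
    """Generate and register a TensorRT QDP plugin for a registered custom op.

    Args:
        plugin_name: Custom op name in "<namespace>::<op>" form, matching its
            torch.library registration; the plugin's registration and impl
            functions are generated from the op's schema.
    """
```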
LGTM after docstrings are added
Added minor comments. LGTM
for tensor_arg in tensor_args:
    sample = {f"{i}": 5 for i in range(tensor_arg.ndim)}
What is 5 here?
If it's a default value or something similar, consider storing it in a global variable to make it clearer? (See the sketch below.)
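A sketch of that suggestion, mirroring the loop above; the constant name and value are illustrative:

```python
# Placeholder extent used for every dimension when building the sample shape
# dictionary; naming it makes the magic number self-documenting.
SAMPLE_DIM_SIZE = 5

for tensor_arg in tensor_args:
    sample = {f"{i}": SAMPLE_DIM_SIZE for i in range(tensor_arg.ndim)}
```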
    outputs: Tuple[trtp.Tensor], stream: int, *args: Any, **kwargs: Any
) -> None:
    tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)]
    print(args)
Is this necessary? If so, can you make this message more effective for users?
Looks like I forgot to delete the debugging lines. Thanks!
Added some minor comments
tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)]
print(args)
non_tensor_args = [elem for elem in args if not isinstance(elem, trtp.Tensor)]
Is it necessary to loop over the arguments twice?
I could write it as:

tensor_args, non_tensor_args = [], []
for elem in args:
    (tensor_args if isinstance(elem, trtp.Tensor) else non_tensor_args).append(elem)

Since args won't be long, I think the original two-comprehension version is easier to understand?
I think it's clearer and faster this way:

tensor_args, non_tensor_args = [], []
for elem in args:
    if isinstance(elem, trtp.Tensor):
        tensor_args.append(elem)
    else:
        non_tensor_args.append(elem)
@@ -58,13 +57,24 @@ def custom_kernel_converter(
    # Assuming TensorRT preserves kwargs order like PyTorch does
    non_tensor_inputs = plugin.input_attrs

    kwargs = {}
Do you intend to override kwargs here? If yes, it seems kwargs is not necessary in the arguments (line 46).
Yes, this kwargs cannot be removed, since it's required by the TorchTRT converter here.
Description
This PR introduces a new feature that enables automatic plugin generation using the TensorRT QDP (quickly deployable plugins) feature.
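A hypothetical end-to-end usage sketch; the import path and helper names follow the files touched in this PR but are assumptions, not a confirmed public API:

```python
from torch_tensorrt.dynamo.conversion import plugins

# Assuming "my_ns::my_op" has already been registered via torch.library,
# generate the QDP plugin and register a TorchTRT converter for it.
plugins.generate_plugin("my_ns::my_op")
plugins.generate_plugin_converter("my_ns::my_op")

# Modules calling torch.ops.my_ns.my_op can then be compiled with torch_tensorrt.
```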
Checklist: