🐛 [Bug] Error Code 3: API Usage Error on IConvolutionLayer::setPaddingNd (Paligemma2) #3409

Open
chohk88 opened this issue Feb 24, 2025 · 1 comment
chohk88 commented Feb 24, 2025

Bug Description

When compiling the google/paligemma2-3b-pt-224 model with torch.compile using the torch_tensorrt backend, I encountered the message below on both torch_tensorrt 2.6 and 2.7 dev.

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _to_copy [aten._to_copy.default] (Inputs: (arg0_1: (1, 3, 224, 224)@torch.float16) | Outputs: (_to_copy: (1, 3, 224, 224)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node arg1_1 (kind: arg1_1, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg1_1 [shape=[1152, 3, 14, 14], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node arg1_1 [arg1_1] (Inputs: () | Outputs: (arg1_1: (1152, 3, 14, 14)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node arg2_1 (kind: arg2_1, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg2_1 [shape=[1152], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node arg2_1 [arg2_1] (Inputs: () | Outputs: (arg2_1: (1152,)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node convolution (kind: aten.convolution.default, args: ('_to_copy <Node>', 'arg1_1 <Node>', 'arg2_1 <Node>', ['14 <int>', '14 <int>'], ['0 <int>'], ['1 <int>', '1 <int>'], 'False <bool>', ['0 <int>'], '1 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.convolution.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.convolution.default
DEBUG:torch_tensorrt [TensorRT Conversion Context]:Kernel weights are not set yet. Kernel weights must be set using setInput(1, kernel_tensor) API call.
ERROR:torch_tensorrt [TensorRT Conversion Context]:IConvolutionLayer::setPaddingNd: Error Code 3: API Usage Error (Parameter check failed, condition: (padding.nbDims == 2 || padding.nbDims == 3) && allDimsGtEq(padding, 0) && allDimsLtEq(padding, kMAX_PADDING). )
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node convolution [aten.convolution.default] (Inputs: (_to_copy: (1, 3, 224, 224)@torch.float16, arg1_1: (1152, 3, 14, 14)@torch.float16, arg2_1: (1152,)@torch.float16, [14, 14], [0], [1, 1], False, [0], 1) | Outputs: (convolution: (1, 1152, 16, 16)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node reshape_default (kind: aten.reshape.default, args: ('convolution <Node>', ['1 <int>', '1152 <int>', '256 <int>']))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.reshape.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.reshape.default
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node reshape_default [aten.reshape.default] (Inputs: (convolution: (1, 1152, 16, 16)@torch.float16, [1, 1152, 256]) | Outputs: (reshape_default: (1, 1152, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node permute (kind: aten.permute.default, args: ('reshape_default <Node>', ['0 <int>', '2 <int>', '1 <int>']))
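
The parameter check quoted in the error message can be mirrored in a few lines of Python to see why the one-element padding list `[0]` from the convolution node is rejected. This is only an illustrative sketch: `padding_is_valid` is a hypothetical helper, not a TensorRT or torch_tensorrt API, and the value used for `kMAX_PADDING` is a placeholder assumption because the log does not state it.

```python
from typing import Sequence

# Placeholder assumption: the log names kMAX_PADDING but does not give its value.
ASSUMED_MAX_PADDING = 2**16


def padding_is_valid(padding: Sequence[int], max_padding: int = ASSUMED_MAX_PADDING) -> bool:
    """Mirror the condition quoted in the setPaddingNd error message."""
    nb_dims = len(padding)
    return (
        nb_dims in (2, 3)                            # padding.nbDims == 2 || padding.nbDims == 3
        and all(p >= 0 for p in padding)             # allDimsGtEq(padding, 0)
        and all(p <= max_padding for p in padding)   # allDimsLtEq(padding, kMAX_PADDING)
    )


# Padding as it appears in the failing aten.convolution.default node: ['0 <int>']
print(padding_is_valid([0]))     # False
print(padding_is_valid([0, 0]))  # True
```

Under this reading, the values themselves are fine; only the number of dimensions fails the check.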

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

DEVICE = "cuda:0"

model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)


model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16).eval()
model.to(DEVICE).to(torch.float16)
# model.forward = model.forward.to(torch.float16).eval()

processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.float16).to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]

# model.config.token_healing = False

with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    pyt_generation_out = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
    print("=============================")
    print("pyt_generation whole text:")
    print(pyt_generation)
    print("=============================")
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

with torch_tensorrt.logging.debug():
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=1023)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            # "use_explicit_typing": True,
            # "use_fp32_acc": True,
            "debug": True,
            # "use_aot_joint_export":False,
        },
    )
    
    with torch.inference_mode():
        trt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False) 
        trt_generation_out = trt_generation[0][input_len:]
        trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
        print(trt_generation)
        print("TensorRT generated text:")
        print(trt_decoded)

chohk88 added the bug label on Feb 24, 2025

chohk88 commented Feb 25, 2025

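Running the code below reproduces the error above. The cause is the padding argument of aten.convolution.default (index 4 in the node's args): it may hold one value per spatial dimension, but here it contains a single zero ([0]), which does not satisfy the padding.nbDims == 2 || padding.nbDims == 3 check reported by IConvolutionLayer::setPaddingNd.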

import torch
import torch_tensorrt
import torch.nn.functional as F

class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(16, 3, 14, 14).half())
        self.bias = torch.nn.Parameter(torch.randn(16).half())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv2d(x, self.weight, self.bias, stride=1, padding=[0])

if __name__ == "__main__":
    with torch.inference_mode():
        model = MyModule().eval().cuda().half()
        inputs = [torch.randn(1, 3, 224, 224, dtype=torch.half, device="cuda")]

        trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1)

        torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
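
One way a converter could guard against this is to broadcast a one-element stride, padding, or dilation list to one entry per spatial dimension before handing it to TensorRT. The sketch below is an assumption about what such a normalization might look like; `expand_to_spatial_dims` is a hypothetical name, and this is not necessarily the fix adopted in torch_tensorrt.

```python
from typing import Sequence, Tuple, Union


def expand_to_spatial_dims(
    value: Union[int, Sequence[int]], num_spatial_dims: int
) -> Tuple[int, ...]:
    """Broadcast an int or one-element sequence to one entry per spatial dim."""
    if isinstance(value, int):
        return (value,) * num_spatial_dims
    values = tuple(value)
    if len(values) == 1:
        # A single value applies to every spatial dimension.
        return values * num_spatial_dims
    if len(values) != num_spatial_dims:
        raise ValueError(
            f"expected 1 or {num_spatial_dims} values, got {len(values)}"
        )
    return values


# The repro uses a 4-D input (N, C, H, W), so there are 2 spatial dims.
print(expand_to_spatial_dims([0], 2))       # (0, 0)
print(expand_to_spatial_dims([14, 14], 2))  # (14, 14)
```

With padding expanded to `(0, 0)`, the value has two dimensions and would satisfy the `nbDims` condition from the error message.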
