
Documentation for the expected input dimension of the model class #8777

Closed
hzhz2020 opened this issue Dec 2, 2024 · 2 comments

Comments

@hzhz2020

hzhz2020 commented Dec 2, 2024

📚 The doc issue

The built-in models are really convenient. However, the documentation usually does not specify the expected input dimension, so I always find it troublesome to confirm the correct input dimension for the model class I want to use.

For example:
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html
https://pytorch.org/vision/main/models/generated/torchvision.models.swin_t.html
https://pytorch.org/vision/main/models/generated/torchvision.models.video.swin3d_b.html

Is there clear documentation for this? Or is there a simple and clear rule I can use (e.g., a convention followed when developing these model classes in PyTorch that is consistent throughout)?

Suggest a potential alternative/fix

No response

@abhi-glitchhg
Contributor

The documentation mentions the transforms that need to be applied to the image. The ResNet weights, for example, have resize and crop transforms, which ultimately decide the shape of the input tensor to the model.

The inference transforms are available at ResNet18_Weights.IMAGENET1K_V1.transforms and perform the following preprocessing operations: Accepts PIL.Image, batched (B, C, H, W) and single (C, H, W) image torch.Tensor objects. The images are resized to resize_size=[256] using interpolation=InterpolationMode.BILINEAR, followed by a central crop of crop_size=[224]. Finally the values are first rescaled to [0.0, 1.0] and then normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].

So, when using a ResNet model:

from torchvision.models import resnet18, ResNet18_Weights

# Load the pretrained model and the matching inference transforms
resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
resnet.eval()
preprocess = ResNet18_Weights.IMAGENET1K_V1.transforms()

# Apply the transforms, then add a batch dimension: (C, H, W) -> (1, C, H, W)
input_tensor = preprocess(input_pil_image)
input_batch = input_tensor.unsqueeze(0)
output = resnet(input_batch)  # forward pass

Similarly, you can find the transforms for other models as well.
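
For instance, the same pattern works for the Swin Transformer model (a minimal sketch, assuming the Swin_T_Weights.IMAGENET1K_V1 weights; any torchvision weight enum exposes its transforms the same way):

from torchvision.models import swin_t, Swin_T_Weights

# The weight enum carries its own inference transforms
weights = Swin_T_Weights.IMAGENET1K_V1
model = swin_t(weights=weights)
model.eval()
preprocess = weights.transforms()

input_tensor = preprocess(input_pil_image)  # resize/crop/normalize per the weight docs
input_batch = input_tensor.unsqueeze(0)     # (C, H, W) -> (1, C, H, W)
output = model(input_batch)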

I hope I've answered your question.

@NicolasHug
Member

NicolasHug commented Dec 3, 2024

@abhi-glitchhg is correct - you'll find the necessary info in the weight documentation, at the bottom of the pages you linked: https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights
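
If you want to check the expected preprocessing programmatically, printing the transforms object attached to the weight enum shows the resize and crop sizes (a minimal sketch; the exact output format depends on your torchvision version):

from torchvision.models import ResNet18_Weights

# Inspect the inference transforms bundled with the weights
weights = ResNet18_Weights.IMAGENET1K_V1
print(weights.transforms())  # shows resize_size, crop_size, mean, std, interpolation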
