-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: Resnet support added #246
Changes from 3 commits
52fd9a2
86b39c5
48ba325
8a469d8
d99180c
cfb57cc
f750de1
0ce042d
c9d328b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "4bdb2a91-fa2a-4cee-ad5a-176cc957394d", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Feature vector shape: (2, 2048, 1, 1)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.onnx\n", | ||
"import torchvision.models as models\n", | ||
"import torchvision.transforms as transforms\n", | ||
"from PIL import Image\n", | ||
"import numpy as np\n", | ||
"from tests.config import TEST_MISC_DIR\n", | ||
"\n", | ||
"# Load pre-trained ResNet-50 model\n", | ||
"resnet = models.resnet50(pretrained=True)\n", | ||
"resnet = torch.nn.Sequential(*(list(resnet.children())[:-1])) # Remove the last fully connected layer\n", | ||
"resnet.eval()\n", | ||
"\n", | ||
"# Define preprocessing transform\n", | ||
"preprocess = transforms.Compose([\n", | ||
" transforms.Resize(256),\n", | ||
" transforms.CenterCrop(224),\n", | ||
" transforms.ToTensor(),\n", | ||
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", | ||
"])\n", | ||
"\n", | ||
"# Load and preprocess the image\n", | ||
"def preprocess_image(image_path):\n", | ||
" input_image = Image.open(image_path)\n", | ||
" input_tensor = preprocess(input_image)\n", | ||
" input_batch = input_tensor.unsqueeze(0) # Add batch dimension\n", | ||
" return input_batch\n", | ||
"\n", | ||
"# Example input for exporting\n", | ||
"input_image = preprocess_image('example.jpg')\n", | ||
"\n", | ||
"# Export the model to ONNX with dynamic axes\n", | ||
"torch.onnx.export(\n", | ||
" resnet, \n", | ||
" input_image, \n", | ||
" \"model.onnx\", \n", | ||
" export_params=True, \n", | ||
" opset_version=9, \n", | ||
" input_names=['input'], \n", | ||
" output_names=['output'],\n", | ||
" dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}\n", | ||
")\n", | ||
"\n", | ||
"# Load ONNX model\n", | ||
"import onnx\n", | ||
"import onnxruntime as ort\n", | ||
"\n", | ||
"onnx_model = onnx.load(\"model.onnx\")\n", | ||
"ort_session = ort.InferenceSession(\"model.onnx\")\n", | ||
"\n", | ||
"# Run inference and extract feature vectors\n", | ||
"def extract_feature_vectors(image_paths):\n", | ||
" input_images = [preprocess_image(image_path) for image_path in image_paths]\n", | ||
" input_batch = torch.cat(input_images, dim=0) # Combine images into a single batch\n", | ||
" ort_inputs = {ort_session.get_inputs()[0].name: input_batch.numpy()}\n", | ||
" ort_outs = ort_session.run(None, ort_inputs)\n", | ||
" return ort_outs[0]\n", | ||
"\n", | ||
"# Example usage\n", | ||
"images = [TEST_MISC_DIR / \"image.jpeg\", str(TEST_MISC_DIR / \"small_image.jpeg\")] # Replace with your image paths\n", | ||
"feature_vectors = extract_feature_vectors(image_paths)\n", | ||
"print(\"Feature vector shape:\", feature_vectors.shape)\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,6 +59,9 @@ def __init__(self, scale: float = 1 / 255): | |
def __call__(self, images: List[np.ndarray]) -> List[np.ndarray]: | ||
return [rescale(image, scale=self.scale) for image in images] | ||
|
||
class PILtoNDarray: | ||
def __call__(self, images: List[Union[Image.Image, np.ndarray]]) -> List[np.ndarray]: | ||
return [np.asarray(image).swapaxes(2, 0) if isinstance(image, Image.Image) else image for image in images] | ||
I8dNLo marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems to be (H, W, C) -> (C, W, H) but should be (H, W, C) -> (C, H, W) so we need to use transpose((2, 0, 1)) instead of swapaxes, should not we? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But it's literally the same thing, isn't it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a = np.random.random((3,4,5))
b = a.swapaxes(2, 0)
c = a.transpose((2, 0, 1))
print(a.shape, b.shape, c.shape)
>>> ((3, 4, 5), (5, 4, 3), (5, 3, 4)) |
||
|
||
class Compose: | ||
def __init__(self, transforms: List[Transform]): | ||
|
@@ -96,6 +99,9 @@ def from_config(cls, config: Dict[str, Any]) -> "Compose": | |
else: | ||
raise ValueError(f"Invalid crop size: {crop_size}") | ||
transforms.append(CenterCrop(size=crop_size)) | ||
|
||
transforms.append(PILtoNDarray()) | ||
|
||
if config.get("do_rescale", True): | ||
rescale_factor = config.get("rescale_factor", 1 / 255) | ||
transforms.append(Rescale(scale=rescale_factor)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just as a reminder: we might want to inspect other resnet models to have lower dimensionality