Add resnet (#246)
* Resnet support added

* Tests fixed
Shapes matching for Resnet50-onnx
Example of Resnet50 to onnx conversion (basic)

* Removed optional conversion from PIL to np.ndarray; it is now applied by default
Fixed test accordingly

* Refactoring of pil2ndarray

* Partial support of convnext preprocessing
Resize logic

* normalize canonical value

* Style changes for review

* new: update resnet repo

---------

Co-authored-by: d.rudenko <dimitriyrudenk@gmail.com>
Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
3 people authored May 31, 2024
1 parent dfd25d4 commit 85aaae4
Showing 6 changed files with 249 additions and 60 deletions.
122 changes: 122 additions & 0 deletions experiments/Example. Convert Resnet50 to ONNX.ipynb
@@ -0,0 +1,122 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4bdb2a91-fa2a-4cee-ad5a-176cc957394d",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T12:15:28.171586Z",
"start_time": "2024-05-23T12:15:28.076314Z"
}
},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torch'",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01monnx\u001B[39;00m\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorchvision\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m\n",
"\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'torch'"
]
}
],
"source": [
"import torch\n",
"import torch.onnx\n",
"import torchvision.models as models\n",
"import torchvision.transforms as transforms\n",
"from PIL import Image\n",
"import numpy as np\n",
"from tests.config import TEST_MISC_DIR\n",
"\n",
"# Load pre-trained ResNet-50 model\n",
"resnet = models.resnet50(pretrained=True)\n",
"resnet = torch.nn.Sequential(*(list(resnet.children())[:-1])) # Remove the last fully connected layer\n",
"resnet.eval()\n",
"\n",
"# Define preprocessing transform\n",
"preprocess = transforms.Compose([\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"])\n",
"\n",
"# Load and preprocess the image\n",
"def preprocess_image(image_path):\n",
" input_image = Image.open(image_path)\n",
" input_tensor = preprocess(input_image)\n",
" input_batch = input_tensor.unsqueeze(0) # Add batch dimension\n",
" return input_batch\n",
"\n",
"# Example input for exporting\n",
"input_image = preprocess_image('example.jpg')\n",
"\n",
"# Export the model to ONNX with dynamic axes\n",
"torch.onnx.export(\n",
" resnet, \n",
" input_image, \n",
" \"model.onnx\", \n",
" export_params=True, \n",
" opset_version=9, \n",
" input_names=['input'], \n",
" output_names=['output'],\n",
" dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}\n",
")\n",
"\n",
"# Load ONNX model\n",
"import onnx\n",
"import onnxruntime as ort\n",
"\n",
"onnx_model = onnx.load(\"model.onnx\")\n",
"ort_session = ort.InferenceSession(\"model.onnx\")\n",
"\n",
"# Run inference and extract feature vectors\n",
"def extract_feature_vectors(image_paths):\n",
" input_images = [preprocess_image(image_path) for image_path in image_paths]\n",
" input_batch = torch.cat(input_images, dim=0) # Combine images into a single batch\n",
" ort_inputs = {ort_session.get_inputs()[0].name: input_batch.numpy()}\n",
" ort_outs = ort_session.run(None, ort_inputs)\n",
" return ort_outs[0]\n",
"\n",
"# Example usage\n",
"images = [TEST_MISC_DIR / \"image.jpeg\", str(TEST_MISC_DIR / \"small_image.jpeg\")] # Replace with your image paths\n",
"feature_vectors = extract_feature_vectors(images)\n",
"print(\"Feature vector shape:\", feature_vectors.shape)\n"
]
},
{
"cell_type": "code",
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "baa650c4cb3e0e6d"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
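A quick consistency check could follow the export in the notebook above. This sketch is not part of the committed notebook; it reuses the notebook's resnet, input_image, and ort_session variables and only relies on standard torch/onnxruntime/numpy calls:

import numpy as np

# Compare the PyTorch features with the ONNX Runtime features on the same input.
with torch.no_grad():
    torch_features = resnet(input_image).numpy().reshape(1, -1)

onnx_features = ort_session.run(
    None, {ort_session.get_inputs()[0].name: input_image.numpy()}
)[0].reshape(1, -1)

# Minor numerical drift between backends is expected, hence the tolerance.
assert np.allclose(torch_features, onnx_features, atol=1e-4)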
12 changes: 11 additions & 1 deletion fastembed/image/onnx_embedding.py
@@ -18,7 +18,17 @@
"hf": "Qdrant/clip-ViT-B-32-vision",
},
"model_file": "model.onnx",
}
},
{
"model": "Qdrant/resnet50-onnx",
"dim": 2048,
"description": "ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.",
"size_in_GB": 0.1,
"sources": {
"hf": "Qdrant/resnet50-onnx",
},
"model_file": "model.onnx",
},
]


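For context, the entry above registers Qdrant/resnet50-onnx next to the existing CLIP vision model. A minimal usage sketch, assuming fastembed's public ImageEmbedding class accepts a model_name and exposes a generator-style embed() as in the released package:

from fastembed import ImageEmbedding  # assumed public import path

# Select the newly added 2048-dimensional ResNet-50 ONNX model.
model = ImageEmbedding(model_name="Qdrant/resnet50-onnx")

# embed() takes image paths and yields one numpy vector per image.
for vector in model.embed(["image.jpeg", "small_image.jpeg"]):
    print(vector.shape)  # expected: (2048,)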
11 changes: 7 additions & 4 deletions fastembed/image/onnx_image_model.py
@@ -46,15 +46,19 @@ def load_onnx_model(
)
self.processor = load_preprocessor(model_dir=model_dir)

def _build_onnx_input(self, encoded: np.ndarray) -> Dict[str, np.ndarray]:
return {node.name: encoded for node in self.model.get_inputs()}

def onnx_embed(self, images: List[PathInput]) -> OnnxOutputContext:
with contextlib.ExitStack():
image_files = [Image.open(image) for image in images]
encoded = self.processor(image_files)
onnx_input = {"pixel_values": encoded}
onnx_input = self._build_onnx_input(encoded)
onnx_input = self._preprocess_onnx_input(onnx_input)

model_output = self.model.run(None, onnx_input)
embeddings = model_output[0]

embeddings = model_output[0].reshape(len(images), -1)

return OnnxOutputContext(
model_output=embeddings
)
@@ -82,7 +86,6 @@ def _embed_images(

if parallel is None or is_small:
for batch in iter_batch(images, batch_size):
# open and preprocess images
yield from self._post_process_onnx_output(self.onnx_embed(batch))
else:
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
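The reshape introduced in onnx_embed above is needed because ResNet-50 with its classification head removed emits pooled features of shape (batch, 2048, 1, 1), while the CLIP vision model already returns (batch, 512); reshape(len(images), -1) flattens both to a plain (batch, dim) matrix. A small illustration (shapes assumed from the model architectures, not taken from this diff):

import numpy as np

resnet_output = np.zeros((4, 2048, 1, 1), dtype=np.float32)  # pooled ResNet-50 features
clip_output = np.zeros((4, 512), dtype=np.float32)           # CLIP vision features

# The same reshape flattens both outputs to (batch, dim).
print(resnet_output.reshape(4, -1).shape)  # (4, 2048)
print(clip_output.reshape(4, -1).shape)    # (4, 512)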
24 changes: 16 additions & 8 deletions fastembed/image/transform/functional.py
@@ -14,14 +14,17 @@ def convert_to_rgb(image: Image.Image) -> Image.Image:


def center_crop(
image: Image.Image,
image: Union[Image.Image, np.ndarray],
size: Tuple[int, int],
) -> np.ndarray:
orig_height, orig_width = image.height, image.width
crop_height, crop_width = size
if isinstance(image, np.ndarray):
_, orig_height, orig_width = image.shape
else:
orig_height, orig_width = image.height, image.width
# (H, W, C) -> (C, H, W)
image = np.array(image).transpose((2, 0, 1))

# (H, W, C) -> (C, H, W)
image = np.array(image).transpose((2, 0, 1))
crop_height, crop_width = size

# left upper corner (0, 0)
top = (orig_height - crop_height) // 2
@@ -96,7 +99,7 @@ def normalize(
def resize(
image: Image,
size: Union[int, Tuple[int, int]],
resample: Image.Resampling = Image.Resampling.BICUBIC,
resample: Image.Resampling = Image.Resampling.BILINEAR,
) -> Image:
if isinstance(size, tuple):
return image.resize(size, resample)
@@ -109,9 +112,14 @@ def resize(
new_size = (new_short, new_long)
else:
new_size = (new_long, new_short)

return image.resize(new_size, Image.Resampling.BICUBIC)
return image.resize(new_size, resample)


def rescale(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
return (image * scale).astype(dtype)


def pil2ndarray(image: Union[Image.Image, np.ndarray]):
if isinstance(image, Image.Image):
return np.asarray(image).transpose((2, 0, 1))
return image
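Taken together, the helpers in this module cover a CLIP-style preprocessing chain. A rough sketch of how they compose on one image, assuming the signatures shown above and a normalize(image, mean, std) helper already present in this module:

from PIL import Image

from fastembed.image.transform.functional import (
    center_crop,
    convert_to_rgb,
    normalize,
    rescale,
    resize,
)

image = convert_to_rgb(Image.open("example.jpg"))  # ensure three channels
image = resize(image, 256)                         # shortest edge -> 256, keep aspect ratio
image = center_crop(image, (224, 224))             # returns a (C, H, W) ndarray
image = rescale(image, scale=1 / 255)              # uint8 -> float32 in [0, 1]
image = normalize(image,                           # assumed signature
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225])
print(image.shape)  # expected: (3, 224, 224)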
136 changes: 89 additions & 47 deletions fastembed/image/transform/operators.py
@@ -2,13 +2,13 @@

import numpy as np
from PIL import Image

from fastembed.image.transform.functional import (
center_crop,
normalize,
resize,
convert_to_rgb,
rescale,
pil2ndarray
)


@@ -59,68 +59,110 @@ def __init__(self, scale: float = 1 / 255):
def __call__(self, images: List[np.ndarray]) -> List[np.ndarray]:
return [rescale(image, scale=self.scale) for image in images]

class PILtoNDarray(Transform):
def __call__(self, images: List[Union[Image.Image, np.ndarray]]) -> List[np.ndarray]:
return [pil2ndarray(image) for image in images]

class Compose:
def __init__(self, transforms: List[Transform]):
self.transforms = transforms

def __call__(
self, images: Union[List[Image.Image], List[np.ndarray]]
) -> Union[List[np.ndarray], List[Image.Image]]:
def __call__(self, images: Union[List[Image.Image], List[np.ndarray]]) -> Union[List[np.ndarray], List[Image.Image]]:
for transform in self.transforms:
images = transform(images)
return images

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Compose":
"""Creates processor from a config dict.
Args:
config (Dict[str, Any]): Configuration dictionary.
Valid keys:
- do_resize
- size
- do_center_crop
- crop_size
- do_rescale
- rescale_factor
- do_normalize
- image_mean
- image_std
Valid size keys (nested):
- {"height", "width"}
- {"shortest_edge"}
Returns:
Compose: Image processor.
Args:
config (Dict[str, Any]): Configuration dictionary.
Valid keys:
- do_resize
- size
- do_center_crop
- crop_size
- do_rescale
- rescale_factor
- do_normalize
- image_mean
- image_std
Valid size keys (nested):
- {"height", "width"}
- {"shortest_edge"}
Returns:
Compose: Image processor.
"""
transforms = [ConvertToRGB()]
if config.get("do_resize", False):
size = config["size"]
if "shortest_edge" in size:
size = size["shortest_edge"]
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError(
"Size must contain either 'shortest_edge' or 'height' and 'width'."
)
transforms.append(
Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC))
)
if config.get("do_center_crop", False):
crop_size = config["crop_size"]
if isinstance(crop_size, int):
crop_size = (crop_size, crop_size)
elif isinstance(crop_size, dict):
crop_size = (crop_size["height"], crop_size["width"])
transforms = []
cls._get_convert_to_rgb(transforms, config)
cls._get_resize(transforms, config)
cls._get_center_crop(transforms, config)
cls._get_pil2ndarray(transforms, config)
cls._get_rescale(transforms, config)
cls._get_normalize(transforms, config)
return cls(transforms=transforms)

@staticmethod
def _get_convert_to_rgb(transforms: List[Transform], config: Dict[str, Any]):
transforms.append(ConvertToRGB())

@staticmethod
def _get_resize(transforms: List[Transform], config: Dict[str, Any]):
mode = config.get('image_processor_type', 'CLIPImageProcessor')
if mode == 'CLIPImageProcessor':
if config.get("do_resize", False):
size = config["size"]
if "shortest_edge" in size:
size = size["shortest_edge"]
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
transforms.append(Resize(size=size, resample=config.get("resample", Image.Resampling.BICUBIC)))
elif mode == 'ConvNextFeatureExtractor':
if 'size' in config and "shortest_edge" not in config['size']:
raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {config['size'].keys()}")
shortest_edge = config['size']["shortest_edge"]
crop_pct = config.get("crop_pct", 0.875)
if shortest_edge < 384:
# maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
resize_shortest_edge = int(shortest_edge / crop_pct)
transforms.append(Resize(size=resize_shortest_edge, resample=config.get("resample", Image.Resampling.BICUBIC)))
transforms.append(CenterCrop(size=(shortest_edge, shortest_edge)))
else:
raise ValueError(f"Invalid crop size: {crop_size}")
transforms.append(CenterCrop(size=crop_size))
transforms.append(Resize(size=(shortest_edge, shortest_edge), resample=config.get("resample", Image.Resampling.BICUBIC)))

@staticmethod
def _get_center_crop(transforms: List[Transform], config: Dict[str, Any]):
mode = config.get('image_processor_type', 'CLIPImageProcessor')
if mode == 'CLIPImageProcessor':
if config.get("do_center_crop", False):
crop_size = config["crop_size"]
if isinstance(crop_size, int):
crop_size = (crop_size, crop_size)
elif isinstance(crop_size, dict):
crop_size = (crop_size["height"], crop_size["width"])
else:
raise ValueError(f"Invalid crop size: {crop_size}")
transforms.append(CenterCrop(size=crop_size))
elif mode == 'ConvNextFeatureExtractor':
pass
else:
raise ValueError(f"Preprocessor {mode} is not supported")

@staticmethod
def _get_pil2ndarray(transforms: List[Transform], config: Dict[str, Any]):
transforms.append(PILtoNDarray())

@staticmethod
def _get_rescale(transforms: List[Transform], config: Dict[str, Any]):
if config.get("do_rescale", True):
rescale_factor = config.get("rescale_factor", 1 / 255)
transforms.append(Rescale(scale=rescale_factor))

@staticmethod
def _get_normalize(transforms: List[Transform], config: Dict[str, Any]):
if config.get("do_normalize", False):
transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
return cls(transforms=transforms)
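To make the two _get_resize branches concrete, here is a sketch of building processors from config dicts using the keys documented in the from_config docstring; the exact mean/std values are illustrative, not taken from this diff:

from fastembed.image.transform.operators import Compose

# CLIP-style config (the default CLIPImageProcessor branch): resize then center crop.
clip_processor = Compose.from_config({
    "do_resize": True,
    "size": {"shortest_edge": 224},
    "do_center_crop": True,
    "crop_size": 224,
    "do_rescale": True,
    "rescale_factor": 1 / 255,
    "do_normalize": True,
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
})

# ConvNeXt-style config: shortest_edge < 384 triggers the crop_pct resize-and-crop path.
convnext_processor = Compose.from_config({
    "image_processor_type": "ConvNextFeatureExtractor",
    "size": {"shortest_edge": 224},
    "crop_pct": 0.875,
    "do_rescale": True,
    "do_normalize": True,
    "image_mean": [0.485, 0.456, 0.406],
    "image_std": [0.229, 0.224, 0.225],
})

# Either processor maps a list of PIL images to a list of normalized (C, H, W) arrays.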