-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
new: Added jina clip v1 #408
Changes from all commits
8dbc562
c1ff4b4
464f2f4
0c6bf70
ef673a2
98fc92e
7684805
34ae2e0
e7e0986
cf4ac9e
eb7b425
e8a15b9
2d2e708
b22abcc
0f9cbf6
377836c
67d3ef7
63e294b
7627d78
3a847f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import Any, Union | ||
from typing import Any, Union, Optional | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
@@ -10,6 +10,7 @@ | |
pil2ndarray, | ||
rescale, | ||
resize, | ||
pad2square, | ||
) | ||
|
||
|
||
|
@@ -66,6 +67,21 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndar | |
return [pil2ndarray(image) for image in images] | ||
|
||
|
||
class PadtoSquare(Transform):
    """Transform that runs ``pad2square`` on every image of a batch.

    The constructor arguments are stored as given and forwarded unchanged to
    ``pad2square`` for each image.
    NOTE(review): presumably the result is a ``size`` x ``size`` canvas filled
    with ``fill_color`` -- confirm against ``pad2square``.
    """

    def __init__(
        self,
        size: int,
        fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
    ):
        """Remember the padding parameters.

        Args:
            size: Value passed to ``pad2square`` as ``size``.
            fill_color: Value passed to ``pad2square`` as ``fill_color``.
        """
        self.size = size
        self.fill_color = fill_color

    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
        """Return a new list with ``pad2square`` applied to each image, in order."""
        pad_kwargs = {"size": self.size, "fill_color": self.fill_color}
        return [pad2square(image=image, **pad_kwargs) for image in images]
|
||
|
||
class Compose: | ||
def __init__(self, transforms: list[Transform]): | ||
self.transforms = transforms | ||
|
@@ -85,14 +101,20 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": | |
|
||
Valid keys: | ||
- do_resize | ||
- resize_mode | ||
- size | ||
- fill_color | ||
- do_center_crop | ||
- crop_size | ||
- do_rescale | ||
- rescale_factor | ||
- do_normalize | ||
- image_mean | ||
- mean | ||
- image_std | ||
- std | ||
- resample | ||
- interpolation | ||
Valid size keys (nested): | ||
- {"height", "width"} | ||
- {"shortest_edge"} | ||
|
@@ -103,6 +125,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": | |
transforms = [] | ||
cls._get_convert_to_rgb(transforms, config) | ||
cls._get_resize(transforms, config) | ||
cls._get_pad2square(transforms, config) | ||
cls._get_center_crop(transforms, config) | ||
cls._get_pil2ndarray(transforms, config) | ||
cls._get_rescale(transforms, config) | ||
|
@@ -113,8 +136,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": | |
def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]): | ||
transforms.append(ConvertToRGB()) | ||
|
||
@staticmethod | ||
def _get_resize(transforms: list[Transform], config: dict[str, Any]): | ||
@classmethod | ||
def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]): | ||
mode = config.get("image_processor_type", "CLIPImageProcessor") | ||
if mode == "CLIPImageProcessor": | ||
if config.get("do_resize", False): | ||
|
@@ -157,6 +180,24 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]): | |
resample=config.get("resample", Image.Resampling.BICUBIC), | ||
) | ||
) | ||
elif mode == "JinaCLIPImageProcessor": | ||
interpolation = config.get("interpolation") | ||
if isinstance(interpolation, str): | ||
resample = cls._interpolation_resolver(interpolation) | ||
else: | ||
resample = interpolation or Image.Resampling.BICUBIC | ||
|
||
if "size" in config: | ||
resize_mode = config.get("resize_mode", "shortest") | ||
if resize_mode == "shortest": | ||
transforms.append( | ||
Resize( | ||
size=config["size"], | ||
resample=resample, | ||
) | ||
) | ||
else: | ||
raise ValueError(f"Preprocessor {mode} is not supported") | ||
|
||
@staticmethod | ||
def _get_center_crop(transforms: list[Transform], config: dict[str, Any]): | ||
|
@@ -173,6 +214,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]): | |
transforms.append(CenterCrop(size=crop_size)) | ||
elif mode == "ConvNextFeatureExtractor": | ||
pass | ||
elif mode == "JinaCLIPImageProcessor": | ||
hh-space-invader marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pass | ||
else: | ||
raise ValueError(f"Preprocessor {mode} is not supported") | ||
|
||
|
@@ -190,3 +233,36 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]): | |
def _get_normalize(transforms: list[Transform], config: dict[str, Any]): | ||
if config.get("do_normalize", False): | ||
transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"])) | ||
elif "mean" in config and "std" in config: | ||
transforms.append(Normalize(mean=config["mean"], std=config["std"])) | ||
|
||
@staticmethod | ||
def _get_pad2square(transforms: list[Transform], config: dict[str, Any]): | ||
mode = config.get("image_processor_type", "CLIPImageProcessor") | ||
if mode == "CLIPImageProcessor": | ||
pass | ||
elif mode == "ConvNextFeatureExtractor": | ||
pass | ||
hh-space-invader marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elif mode == "JinaCLIPImageProcessor": | ||
transforms.append( | ||
PadtoSquare( | ||
size=config["size"], | ||
fill_color=config.get("fill_color", 0), | ||
) | ||
) | ||
|
||
@staticmethod | ||
def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling: | ||
interpolation_map = { | ||
"nearest": Image.Resampling.NEAREST, | ||
"lanczos": Image.Resampling.LANCZOS, | ||
"bilinear": Image.Resampling.BILINEAR, | ||
"bicubic": Image.Resampling.BICUBIC, | ||
"box": Image.Resampling.BOX, | ||
"hamming": Image.Resampling.HAMMING, | ||
} | ||
|
||
if resample and (method := interpolation_map.get(resample.lower())): | ||
return method | ||
|
||
raise ValueError(f"Unknown interpolation method: {resample}") | ||
Comment on lines +254 to +268
> feels like it should not be a part of [the class]

> I felt the same. Got any suggestions? `fastembed/common/utils`?

> Not sure; we can just move it out of the class.

> I've changed `_get_resize` to a classmethod, as `_interpolation_resolver` would only be used here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if we pass a grayscale image?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In our post processor, the first operation is to change the image to RGB, so it shouldn't happen