Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new: Added jina clip v1 #408

Merged
merged 20 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8dbc562
WIP: Added jina clip text embedding
hh-space-invader Nov 19, 2024
c1ff4b4
WIP: Added preprocess for jina clip
hh-space-invader Nov 19, 2024
464f2f4
WIP: Added jina clip vision (not sure if it works yet)
hh-space-invader Nov 19, 2024
0c6bf70
improve: Improved mean pooling if the output doesn't have seq length
hh-space-invader Nov 20, 2024
ef673a2
fix: Fixed jina clip text
hh-space-invader Nov 20, 2024
98fc92e
nit
hh-space-invader Nov 20, 2024
7684805
fix: Fixed jina clip image preprocessor
hh-space-invader Nov 20, 2024
34ae2e0
fix: Fix type hints
hh-space-invader Nov 20, 2024
e7e0986
tests: Add jina clip vision test case
hh-space-invader Nov 20, 2024
cf4ac9e
nit
hh-space-invader Nov 21, 2024
eb7b425
refactor: Update fastembed/image/transform/operators.py
hh-space-invader Nov 24, 2024
e8a15b9
fix: Fix indentation
hh-space-invader Nov 24, 2024
2d2e708
refactor: Refactored how we call padding for image
hh-space-invader Nov 24, 2024
b22abcc
fix: Fix pad to image when resized size larger than new square canvas
hh-space-invader Nov 25, 2024
0f9cbf6
refactor: minor refactor
hh-space-invader Nov 26, 2024
377836c
refactor: Refactor some functions in preprocess image
hh-space-invader Dec 4, 2024
67d3ef7
fix: Fix to pad image with specified fill color
hh-space-invader Dec 5, 2024
63e294b
refactor: Change resize to classmethod
hh-space-invader Dec 5, 2024
7627d78
fix: Fix jina clip text v1
hh-space-invader Dec 10, 2024
3a847f6
fix: fix pad to square for some rectangular images (#421)
joein Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions fastembed/image/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@
},
"model_file": "model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Image embeddings, Multimodal (text&image), 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.34,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/vision_model.onnx",
},
]


Expand Down
36 changes: 31 additions & 5 deletions fastembed/image/transform/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def center_crop(

def normalize(
image: np.ndarray,
mean=Union[float, np.ndarray],
std=Union[float, np.ndarray],
mean: Union[float, np.ndarray],
std: Union[float, np.ndarray],
) -> np.ndarray:
if not isinstance(image, np.ndarray):
raise ValueError("image must be a numpy array")
Expand Down Expand Up @@ -96,10 +96,10 @@ def normalize(


def resize(
image: Image,
image: Image.Image,
size: Union[int, tuple[int, int]],
resample: Image.Resampling = Image.Resampling.BILINEAR,
) -> Image:
resample: Union[int, Image.Resampling] = Image.Resampling.BILINEAR,
) -> Image.Image:
if isinstance(size, tuple):
return image.resize(size, resample)

Expand All @@ -122,3 +122,29 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]):
if isinstance(image, Image.Image):
return np.asarray(image).transpose((2, 0, 1))
return image


def pad2square(
    image: Image.Image,
    size: int,
    fill_color: Union[str, int, tuple[int, ...]] = 0,
) -> Image.Image:
    """Place ``image`` onto a ``size`` x ``size`` RGB canvas filled with ``fill_color``.

    Dimensions larger than ``size`` are center-cropped down to ``size``; smaller
    dimensions are left as-is, so the image ends up padded at the right/bottom
    (the paste anchors at the top-left corner of the canvas).

    Args:
        image: Source image. NOTE(review): assumes an RGB-compatible mode —
            upstream the pipeline converts to RGB first; confirm for other callers.
        size: Side length of the square output canvas.
        fill_color: Canvas background color passed to ``Image.new``.

    Returns:
        A new ``size`` x ``size`` RGB image; the input is not modified.
    """
    width, height = image.width, image.height

    # Compute a center-crop box only for the dimensions that overflow the canvas.
    crop_box = None
    if width > size or height > size:
        x0 = (width - size) // 2 if width > size else 0
        y0 = (height - size) // 2 if height > size else 0
        x1 = x0 + size if width > size else width
        y1 = y0 + size if height > size else height
        crop_box = (x0, y0, x1, y1)

    canvas = Image.new(mode="RGB", size=(size, size), color=fill_color)
    canvas.paste(image if crop_box is None else image.crop(crop_box))
    return canvas
82 changes: 79 additions & 3 deletions fastembed/image/transform/operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Union
from typing import Any, Union, Optional

import numpy as np
from PIL import Image
Expand All @@ -10,6 +10,7 @@
pil2ndarray,
rescale,
resize,
pad2square,
)


Expand Down Expand Up @@ -66,6 +67,21 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndar
return [pil2ndarray(image) for image in images]


class PadtoSquare(Transform):
    """Transform that pads (and center-crops, if oversized) images to a square canvas.

    Args:
        size: Side length of the square output canvas.
        fill_color: Background color for the padded area. When ``None`` (the
            default), black (``0``) is used.
    """

    def __init__(
        self,
        size: int,
        fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
    ):
        self.size = size
        # Fall back to black when no fill color is configured: forwarding None
        # to pad2square would reach Image.new(color=None), which leaves the
        # canvas uninitialized (and None is outside pad2square's declared type).
        self.fill_color = fill_color if fill_color is not None else 0

    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
        """Apply pad2square to each image and return the padded copies."""
        return [
            pad2square(image=image, size=self.size, fill_color=self.fill_color) for image in images
        ]


class Compose:
def __init__(self, transforms: list[Transform]):
self.transforms = transforms
Expand All @@ -85,14 +101,20 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":

Valid keys:
- do_resize
- resize_mode
- size
- fill_color
- do_center_crop
- crop_size
- do_rescale
- rescale_factor
- do_normalize
- image_mean
- mean
- image_std
- std
- resample
- interpolation
Valid size keys (nested):
- {"height", "width"}
- {"shortest_edge"}
Expand All @@ -103,6 +125,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
transforms = []
cls._get_convert_to_rgb(transforms, config)
cls._get_resize(transforms, config)
cls._get_pad2square(transforms, config)
cls._get_center_crop(transforms, config)
cls._get_pil2ndarray(transforms, config)
cls._get_rescale(transforms, config)
Expand All @@ -113,8 +136,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]):
transforms.append(ConvertToRGB())

@staticmethod
def _get_resize(transforms: list[Transform], config: dict[str, Any]):
@classmethod
def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]):
mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode == "CLIPImageProcessor":
if config.get("do_resize", False):
Expand Down Expand Up @@ -157,6 +180,24 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
resample=config.get("resample", Image.Resampling.BICUBIC),
)
)
elif mode == "JinaCLIPImageProcessor":
interpolation = config.get("interpolation")
if isinstance(interpolation, str):
resample = cls._interpolation_resolver(interpolation)
else:
resample = interpolation or Image.Resampling.BICUBIC

if "size" in config:
resize_mode = config.get("resize_mode", "shortest")
if resize_mode == "shortest":
transforms.append(
Resize(
size=config["size"],
resample=resample,
)
)
else:
raise ValueError(f"Preprocessor {mode} is not supported")

@staticmethod
def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
Expand All @@ -173,6 +214,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
transforms.append(CenterCrop(size=crop_size))
elif mode == "ConvNextFeatureExtractor":
pass
elif mode == "JinaCLIPImageProcessor":
pass
else:
raise ValueError(f"Preprocessor {mode} is not supported")

Expand All @@ -190,3 +233,36 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]):
def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
    """Append a Normalize transform when the config provides normalization stats.

    Prefers the HF-style ``do_normalize`` + ``image_mean``/``image_std`` keys;
    falls back to bare ``mean``/``std`` keys (JinaCLIP-style configs).
    """
    if config.get("do_normalize", False):
        mean, std = config["image_mean"], config["image_std"]
    elif "mean" in config and "std" in config:
        mean, std = config["mean"], config["std"]
    else:
        return
    transforms.append(Normalize(mean=mean, std=std))

@staticmethod
def _get_pad2square(transforms: list[Transform], config: dict[str, Any]):
    """Append a PadtoSquare transform for preprocessors that require it.

    Only JinaCLIPImageProcessor pads to a square; every other processor type
    (including unrecognized ones) is deliberately left untouched here.
    """
    mode = config.get("image_processor_type", "CLIPImageProcessor")
    if mode == "JinaCLIPImageProcessor":
        transforms.append(
            PadtoSquare(
                size=config["size"],
                fill_color=config.get("fill_color", 0),
            )
        )

@staticmethod
def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling:
    """Translate an interpolation name (e.g. "bicubic") into a PIL resampling filter.

    Lookup is case-insensitive.

    Raises:
        ValueError: if ``resample`` is None, empty, or not a known method name.
    """
    if resample:
        method = {
            "nearest": Image.Resampling.NEAREST,
            "lanczos": Image.Resampling.LANCZOS,
            "bilinear": Image.Resampling.BILINEAR,
            "bicubic": Image.Resampling.BICUBIC,
            "box": Image.Resampling.BOX,
            "hamming": Image.Resampling.HAMMING,
        }.get(resample.lower())
        if method is not None:
            return method

    raise ValueError(f"Unknown interpolation method: {resample}")
Comment on lines +254 to +268
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feels like it should not be a part of Compose class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I felt the same. Got any suggestions ? fastembed/common/utils ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idk for sure, we can just move it out of the class
at least, this Compose._interpolation_resolver is super ugly, if we keep this method here, we need to make get_resize a class method, not a static, it is only used inside of Compose class anyway

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed get_resize to cls method as _interpolation_resolver would only be used in here

19 changes: 18 additions & 1 deletion fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,17 @@
},
"model_file": "onnx/model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/text_model.onnx",
},
]


Expand Down Expand Up @@ -285,7 +296,13 @@ def _preprocess_onnx_input(

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
    """L2-normalize the raw ONNX embeddings as float32.

    Token-level outputs of shape (batch_size, seq_len, embedding_dim) are
    pooled by taking the first (CLS) token; already-pooled outputs of shape
    (batch_size, embedding_dim) are used directly.

    Raises:
        ValueError: if the model output has any other rank.
    """
    embeddings = output.model_output
    if embeddings.ndim == 2:  # (batch_size, embedding_dim) — already pooled
        pooled = embeddings
    elif embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim) — take CLS token
        pooled = embeddings[:, 0]
    else:
        raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
    return normalize(pooled).astype(np.float32)

def load_onnx_model(self) -> None:
self._load_onnx_model(
Expand Down
2 changes: 1 addition & 1 deletion fastembed/text/text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
return

raise ValueError(
f"Model {model_name} is not supported in TextEmbedding."
f"Model {model_name} is not supported in TextEmbedding. "
"Please check the supported models using `TextEmbedding.list_supported_models()`"
)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_image_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
"Qdrant/Unicom-ViT-B-32": np.array(
[0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186]
),
"jinaai/jina-clip-v1": np.array(
[-0.029, 0.0216, 0.0396, 0.0283, -0.0023, 0.0151, 0.011, -0.0235, 0.0251, -0.0343]
),
}


Expand Down
1 change: 1 addition & 0 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
}


Expand Down
Loading