new: Added jina clip v1 (#408)
* WIP: Added jina clip text embedding

* WIP: Added preprocess for jina clip

* WIP: Added jina clip vision (not sure if it works yet)

* improve: Improved mean pooling if the output doesn't have seq length

* fix: Fixed jina clip text

* nit

* fix: Fixed jina clip image preprocessor

* fix: Fix type hints
new: added resize2square

* tests: Add jina clip vision test case

* nit

* refactor: Update fastembed/image/transform/operators.py

Co-authored-by: George <george.panchuk@qdrant.tech>

* fix: Fix indentation

* refactor: Refactored how we call padding for image

* fix: Fix pad to image when resized size larger than new square canvas

* refactor: minor refactor

* refactor: Refactor some functions in preprocess image

* fix: Fix to pad image with specified fill color

* refactor: Change resize to classmethod

* fix: Fix jina clip text v1

* fix: fix pad to square for some rectangular images (#421)

---------

Co-authored-by: George <george.panchuk@qdrant.tech>
hh-space-invader and joein authored Dec 16, 2024
1 parent 516170c commit 3b5e4c8
Showing 7 changed files with 144 additions and 10 deletions.
11 changes: 11 additions & 0 deletions fastembed/image/onnx_embedding.py
@@ -53,6 +53,17 @@
        },
        "model_file": "model.onnx",
    },
    {
        "model": "jinaai/jina-clip-v1",
        "dim": 768,
        "description": "Image embeddings, Multimodal (text&image), 2024 year",
        "license": "apache-2.0",
        "size_in_GB": 0.34,
        "sources": {
            "hf": "jinaai/jina-clip-v1",
        },
        "model_file": "onnx/vision_model.onnx",
    },
]


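With this registry entry, the jina-clip-v1 vision tower becomes selectable through fastembed's image API. A minimal usage sketch, assuming the fastembed build from this PR; the image path is a placeholder, and the output shape follows the "dim": 768 field above:

from fastembed import ImageEmbedding

model = ImageEmbedding(model_name="jinaai/jina-clip-v1")
# embed() accepts image paths or PIL images and lazily yields numpy arrays
embeddings = list(model.embed(["images/cat.jpg"]))  # placeholder path
print(embeddings[0].shape)  # expected (768,), matching the "dim" field above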
36 changes: 31 additions & 5 deletions fastembed/image/transform/functional.py
@@ -62,8 +62,8 @@ def center_crop(

def normalize(
    image: np.ndarray,
-    mean=Union[float, np.ndarray],
-    std=Union[float, np.ndarray],
+    mean: Union[float, np.ndarray],
+    std: Union[float, np.ndarray],
) -> np.ndarray:
    if not isinstance(image, np.ndarray):
        raise ValueError("image must be a numpy array")
@@ -96,10 +96,10 @@ def normalize(


def resize(
-    image: Image,
+    image: Image.Image,
    size: Union[int, tuple[int, int]],
-    resample: Image.Resampling = Image.Resampling.BILINEAR,
-) -> Image:
+    resample: Union[int, Image.Resampling] = Image.Resampling.BILINEAR,
+) -> Image.Image:
    if isinstance(size, tuple):
        return image.resize(size, resample)

@@ -122,3 +122,29 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]):
    if isinstance(image, Image.Image):
        return np.asarray(image).transpose((2, 0, 1))
    return image


def pad2square(
    image: Image.Image,
    size: int,
    fill_color: Union[str, int, tuple[int, ...]] = 0,
) -> Image.Image:
    height, width = image.height, image.width

    left, right = 0, width
    top, bottom = 0, height

    crop_required = False
    if width > size:
        left = (width - size) // 2
        right = left + size
        crop_required = True

    if height > size:
        top = (height - size) // 2
        bottom = top + size
        crop_required = True

    new_image = Image.new(mode="RGB", size=(size, size), color=fill_color)
    new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image)
    return new_image
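A quick sketch of how pad2square behaves on a rectangular input (sizes are illustrative): any dimension larger than size is center-cropped first, then the image is pasted at the top-left of a size x size canvas filled with fill_color.

from PIL import Image

from fastembed.image.transform.functional import pad2square

rect = Image.new("RGB", (300, 100), color="white")
square = pad2square(rect, size=224, fill_color=(0, 0, 0))
print(square.size)  # (224, 224): width 300 is center-cropped to 224,
                    # height 100 is kept, pasted at the top-left, and the rest stays fill_color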
82 changes: 79 additions & 3 deletions fastembed/image/transform/operators.py
@@ -1,4 +1,4 @@
-from typing import Any, Union
+from typing import Any, Union, Optional

import numpy as np
from PIL import Image
@@ -10,6 +10,7 @@
    pil2ndarray,
    rescale,
    resize,
    pad2square,
)


@@ -66,6 +67,21 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndar
        return [pil2ndarray(image) for image in images]


class PadtoSquare(Transform):
    def __init__(
        self,
        size: int,
        fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
    ):
        self.size = size
        self.fill_color = fill_color

    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
        return [
            pad2square(image=image, size=self.size, fill_color=self.fill_color) for image in images
        ]


class Compose:
    def __init__(self, transforms: list[Transform]):
        self.transforms = transforms
Expand All @@ -85,14 +101,20 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
        Valid keys:
            - do_resize
            - resize_mode
            - size
            - fill_color
            - do_center_crop
            - crop_size
            - do_rescale
            - rescale_factor
            - do_normalize
            - image_mean
            - mean
            - image_std
            - std
            - resample
            - interpolation
        Valid size keys (nested):
            - {"height", "width"}
            - {"shortest_edge"}
@@ -103,6 +125,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
        transforms = []
        cls._get_convert_to_rgb(transforms, config)
        cls._get_resize(transforms, config)
        cls._get_pad2square(transforms, config)
        cls._get_center_crop(transforms, config)
        cls._get_pil2ndarray(transforms, config)
        cls._get_rescale(transforms, config)
@@ -113,8 +136,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
    def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]):
        transforms.append(ConvertToRGB())

-    @staticmethod
-    def _get_resize(transforms: list[Transform], config: dict[str, Any]):
+    @classmethod
+    def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]):
        mode = config.get("image_processor_type", "CLIPImageProcessor")
        if mode == "CLIPImageProcessor":
            if config.get("do_resize", False):
@@ -157,6 +180,24 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
                        resample=config.get("resample", Image.Resampling.BICUBIC),
                    )
                )
        elif mode == "JinaCLIPImageProcessor":
            interpolation = config.get("interpolation")
            if isinstance(interpolation, str):
                resample = cls._interpolation_resolver(interpolation)
            else:
                resample = interpolation or Image.Resampling.BICUBIC

            if "size" in config:
                resize_mode = config.get("resize_mode", "shortest")
                if resize_mode == "shortest":
                    transforms.append(
                        Resize(
                            size=config["size"],
                            resample=resample,
                        )
                    )
        else:
            raise ValueError(f"Preprocessor {mode} is not supported")

@staticmethod
def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
@@ -173,6 +214,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
            transforms.append(CenterCrop(size=crop_size))
        elif mode == "ConvNextFeatureExtractor":
            pass
        elif mode == "JinaCLIPImageProcessor":
            pass
        else:
            raise ValueError(f"Preprocessor {mode} is not supported")

@@ -190,3 +233,36 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]):
    def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
        if config.get("do_normalize", False):
            transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
        elif "mean" in config and "std" in config:
            transforms.append(Normalize(mean=config["mean"], std=config["std"]))

    @staticmethod
    def _get_pad2square(transforms: list[Transform], config: dict[str, Any]):
        mode = config.get("image_processor_type", "CLIPImageProcessor")
        if mode == "CLIPImageProcessor":
            pass
        elif mode == "ConvNextFeatureExtractor":
            pass
        elif mode == "JinaCLIPImageProcessor":
            transforms.append(
                PadtoSquare(
                    size=config["size"],
                    fill_color=config.get("fill_color", 0),
                )
            )

    @staticmethod
    def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling:
        interpolation_map = {
            "nearest": Image.Resampling.NEAREST,
            "lanczos": Image.Resampling.LANCZOS,
            "bilinear": Image.Resampling.BILINEAR,
            "bicubic": Image.Resampling.BICUBIC,
            "box": Image.Resampling.BOX,
            "hamming": Image.Resampling.HAMMING,
        }

        if resample and (method := interpolation_map.get(resample.lower())):
            return method

        raise ValueError(f"Unknown interpolation method: {resample}")
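Taken together, Compose.from_config now assembles the jina-clip preprocessing chain when image_processor_type is "JinaCLIPImageProcessor". A sketch with an illustrative config; the size, interpolation, fill_color, rescale, and mean/std values below are placeholders, not copied from the model's preprocessor_config.json:

from fastembed.image.transform.operators import Compose

config = {
    "image_processor_type": "JinaCLIPImageProcessor",
    "size": 224,                 # shortest-edge target and square canvas size (placeholder)
    "resize_mode": "shortest",
    "interpolation": "bicubic",  # resolved via _interpolation_resolver
    "fill_color": 0,
    "do_rescale": True,
    "rescale_factor": 1 / 255,
    "mean": [0.5, 0.5, 0.5],     # placeholder normalization stats
    "std": [0.5, 0.5, 0.5],
}
processor = Compose.from_config(config)
# Rough resulting order: ConvertToRGB -> Resize -> PadtoSquare -> pil2ndarray -> Rescale -> Normalize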
19 changes: 18 additions & 1 deletion fastembed/text/onnx_embedding.py
@@ -164,6 +164,17 @@
        },
        "model_file": "onnx/model.onnx",
    },
    {
        "model": "jinaai/jina-clip-v1",
        "dim": 768,
        "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
        "license": "apache-2.0",
        "size_in_GB": 0.55,
        "sources": {
            "hf": "jinaai/jina-clip-v1",
        },
        "model_file": "onnx/text_model.onnx",
    },
]


@@ -285,7 +296,13 @@ def _preprocess_onnx_input(

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
        embeddings = output.model_output
-        return normalize(embeddings[:, 0]).astype(np.float32)
+        if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
+            processed_embeddings = embeddings[:, 0]
+        elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
+            processed_embeddings = embeddings
+        else:
+            raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
+        return normalize(processed_embeddings).astype(np.float32)

    def load_onnx_model(self) -> None:
        self._load_onnx_model(
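On the text side, the same model name resolves to the CLIP text tower, and the reworked _post_process_onnx_output accepts either a token-level (batch, seq, dim) output (taking the first token) or an already pooled (batch, dim) output before normalizing. A minimal usage sketch, assuming the fastembed build from this PR:

from fastembed import TextEmbedding

model = TextEmbedding(model_name="jinaai/jina-clip-v1")
embeddings = list(model.embed(["A photo of a cat", "A photo of a dog"]))
print(embeddings[0].shape)  # expected (768,), shared space with the image embeddings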
2 changes: 1 addition & 1 deletion fastembed/text/text_embedding.py
@@ -78,7 +78,7 @@ def __init__(
return

        raise ValueError(
-            f"Model {model_name} is not supported in TextEmbedding."
+            f"Model {model_name} is not supported in TextEmbedding. "
            "Please check the supported models using `TextEmbedding.list_supported_models()`"
        )

3 changes: 3 additions & 0 deletions tests/test_image_onnx_embeddings.py
@@ -21,6 +21,9 @@
"Qdrant/Unicom-ViT-B-32": np.array(
[0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186]
),
"jinaai/jina-clip-v1": np.array(
[-0.029, 0.0216, 0.0396, 0.0283, -0.0023, 0.0151, 0.011, -0.0235, 0.0251, -0.0343]
),
}


1 change: 1 addition & 0 deletions tests/test_text_onnx_embeddings.py
@@ -65,6 +65,7 @@
"snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
}


