diff --git a/NOTICE b/NOTICE index caa664b7..2d3d3b29 100644 --- a/NOTICE +++ b/NOTICE @@ -12,5 +12,11 @@ This distribution includes the following Jina AI models, each with its respectiv These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms. +This distribution includes the following Google models, each with its respective license: +- vidore/colpali-v1.3 + - License: gemma + +Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms + Additional Notes: This project also includes third-party libraries with their respective licenses. Please refer to the documentation of each library for details regarding its usage and licensing terms. diff --git a/fastembed/image/image_embedding.py b/fastembed/image/image_embedding.py index 01157c1c..7c4140de 100644 --- a/fastembed/image/image_embedding.py +++ b/fastembed/image/image_embedding.py @@ -79,8 +79,7 @@ def embed( **kwargs: Any, ) -> Iterable[NumpyArray]: """ - Encode a list of documents into list of embeddings. - We use mean pooling with attention so that the model can handle variable-length inputs. + Encode a list of images into list of embeddings. Args: images: Iterator of image paths or single image path to embed diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 58355b43..55a45074 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -139,7 +139,7 @@ def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]) -> @classmethod def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None: mode = config.get("image_processor_type", "CLIPImageProcessor") - if mode == "CLIPImageProcessor": + if mode in ("CLIPImageProcessor", "SiglipImageProcessor"): if config.get("do_resize", False): size = config["size"] if "shortest_edge" in size: @@ -202,7 +202,7 @@ def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> Non @staticmethod def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> None: mode = config.get("image_processor_type", "CLIPImageProcessor") - if mode == "CLIPImageProcessor": + if mode in ("CLIPImageProcessor", "SiglipImageProcessor"): if config.get("do_center_crop", False): crop_size_raw = config["crop_size"] crop_size: tuple[int, int] diff --git a/fastembed/late_interaction_multimodal/__init__.py b/fastembed/late_interaction_multimodal/__init__.py new file mode 100644 index 00000000..50588cde --- /dev/null +++ b/fastembed/late_interaction_multimodal/__init__.py @@ -0,0 +1,5 @@ +from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding import ( + LateInteractionMultimodalEmbedding, +) + +__all__ = ["LateInteractionMultimodalEmbedding"] diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py new file mode 100644 index 00000000..5944ed5f --- /dev/null +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -0,0 +1,301 @@ +from typing import Any, Iterable, Optional, Sequence, Type, Union + +import numpy as np +from tokenizers import Encoding + +from fastembed.common import OnnxProvider, ImageInput +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.utils import define_cache_dir +from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( + LateInteractionMultimodalEmbeddingBase, +) +from fastembed.late_interaction_multimodal.onnx_multimodal_model 
import (
+    OnnxMultimodalModel,
+    TextEmbeddingWorker,
+    ImageEmbeddingWorker,
+)
+
+supported_colpali_models = [
+    {
+        "model": "Qdrant/colpali-v1.3-fp16",
+        "dim": 128,
+        "description": "Text embeddings, Multimodal (text&image), English, 50 tokens query length truncation, 2024.",
+        "license": "mit",
+        "size_in_GB": 6.5,
+        "sources": {
+            "hf": "Qdrant/colpali-v1.3-fp16",
+        },
+        "additional_files": [
+            "model.onnx_data",
+        ],
+        "model_file": "model.onnx",
+    },
+]
+
+
+class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]):
+    QUERY_PREFIX = "Query: "
+    BOS_TOKEN = "<bos>"
+    PAD_TOKEN = "<pad>"
+    QUERY_MARKER_TOKEN_ID = [2, 5098]
+    IMAGE_PLACEHOLDER_SIZE = (3, 448, 448)
+    EMPTY_TEXT_PLACEHOLDER = np.array(
+        [257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]
+    )  # This is a tokenization of '<image>' * 1024 + '<bos>Describe the image.\n', used as a placeholder
+    # while processing an image
+    EVEN_ATTENTION_MASK = np.array([1] * 1030)
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[list[int]] = None,
+        lazy_load: bool = False,
+        device_id: Optional[int] = None,
+        specific_model_path: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            model_name (str): The name of the model to use.
+            cache_dir (str, optional): The path to the cache directory.
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                Defaults to `fastembed_cache` in the system's temp directory.
+            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                Defaults to False.
+            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multiple GPUs and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model>, e.g. BAAI/bge-base-en.
+ """ + + super().__init__(model_name, cache_dir, threads, **kwargs) + self.providers = providers + self.lazy_load = lazy_load + + # List of device ids, that can be used for data parallel processing in workers + self.device_ids = device_ids + self.cuda = cuda + + # This device_id will be used if we need to load model in current process + if device_id is not None: + self.device_id = device_id + elif self.device_ids is not None: + self.device_id = self.device_ids[0] + else: + self.device_id = None + + self.model_description = self._get_model_description(model_name) + self.cache_dir = define_cache_dir(cache_dir) + + self._model_dir = self.download_model( + self.model_description, + self.cache_dir, + local_files_only=self._local_files_only, + specific_model_path=specific_model_path, + ) + self.mask_token_id = None + self.pad_token_id = None + + if not self.lazy_load: + self.load_onnx_model() + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_colpali_models + + def load_onnx_model(self) -> None: + self._load_onnx_model( + model_dir=self._model_dir, + model_file=self.model_description["model_file"], + threads=self.threads, + providers=self.providers, + cuda=self.cuda, + device_id=self.device_id, + ) + + def _post_process_onnx_image_output( + self, + output: OnnxOutputContext, + ) -> Iterable[np.ndarray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[np.ndarray]: Post-processed output as NumPy arrays. + """ + return output.model_output.reshape( + output.model_output.shape[0], -1, self.model_description["dim"] + ).astype(np.float32) + + def _post_process_onnx_text_output( + self, + output: OnnxOutputContext, + ) -> Iterable[np.ndarray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[np.ndarray]: Post-processed output as NumPy arrays. + """ + return output.model_output.astype(np.float32) + + def tokenize(self, documents: list[str], **_) -> list[Encoding]: + texts_query: list[str] = [] + for query in documents: + query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 + query += "\n" + + texts_query.append(query) + encoded = self.tokenizer.encode_batch(texts_query) + return encoded + + def _preprocess_onnx_text_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + onnx_input["input_ids"] = np.array( + [ + self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist() + for input_ids in onnx_input["input_ids"] + ] + ) + empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32) + onnx_input["pixel_values"] = np.array( + [empty_image_placeholder for _ in onnx_input["input_ids"]] + ) + return onnx_input + + def _preprocess_onnx_image_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + """ + Add placeholders for text input when processing image data for ONNX. + Args: + onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs. + **kwargs: Additional arguments. + Returns: + Dict[str, np.ndarray]: ONNX input with text placeholders. 
+ """ + + onnx_input["input_ids"] = np.array( + [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array( + [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]] + ) + return onnx_input + + def embed_text( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self._embed_documents( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + documents=documents, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + + def embed_image( + self, + images: ImageInput, + batch_size: int = 16, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of images into list of embeddings. + + Args: + images: Iterator of image paths or single image path to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. 
+ + Returns: + List of embeddings, one per document + """ + yield from self._embed_images( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + images=images, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + + @classmethod + def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]: + return ColPaliTextEmbeddingWorker + + @classmethod + def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]: + return ColPaliImageEmbeddingWorker + + +class ColPaliTextEmbeddingWorker(TextEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: + return ColPali( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) + + +class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: + return ColPali( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py new file mode 100644 index 00000000..f1c7b794 --- /dev/null +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -0,0 +1,125 @@ +from typing import Any, Iterable, Optional, Sequence, Type, Union + +import numpy as np + +from fastembed.common import OnnxProvider, ImageInput +from fastembed.late_interaction_multimodal.colpali import ColPali + +from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( + LateInteractionMultimodalEmbeddingBase, +) + + +class LateInteractionMultimodalEmbedding(LateInteractionMultimodalEmbeddingBase): + EMBEDDINGS_REGISTRY: list[Type[LateInteractionMultimodalEmbeddingBase]] = [ColPali] + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """ + Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + + Example: + ``` + [ + { + "model": "AndrewOgn/colpali-v1.3-merged-onnx", + "dim": 128, + "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.", + "license": "mit", + "size_in_GB": 6.06, + "sources": { + "hf": "AndrewOgn/colpali-v1.3-merged-onnx", + }, + "additional_files": [ + "model.onnx_data", + ], + "model_file": "model.onnx", + }, + ] + ``` + """ + result = [] + for embedding in cls.EMBEDDINGS_REGISTRY: + result.extend(embedding.list_supported_models()) + return result + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + **kwargs, + ): + super().__init__(model_name, cache_dir, threads, **kwargs) + for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: + supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() + if any(model_name.lower() == model["model"].lower() for model in supported_models): + self.model = EMBEDDING_MODEL_TYPE( + model_name, + cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + **kwargs, + ) + return + + raise ValueError( + f"Model {model_name} is not supported in LateInteractionMultimodalEmbedding." 
+ "Please check the supported models using `LateInteractionMultimodalEmbedding.list_supported_models()`" + ) + + def embed_text( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self.model.embed_text(documents, batch_size, parallel, **kwargs) + + def embed_image( + self, + images: Union[ImageInput, Iterable[ImageInput]], + batch_size: int = 16, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of images into list of embeddings. + + Args: + images: Iterator of image paths or single image path to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per image + """ + yield from self.model.embed_image(images, batch_size, parallel, **kwargs) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py new file mode 100644 index 00000000..5cfe45ba --- /dev/null +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -0,0 +1,66 @@ +from typing import Iterable, Optional, Union + +import numpy as np + +from fastembed.common import ImageInput +from fastembed.common.model_management import ModelManagement + + +class LateInteractionMultimodalEmbeddingBase(ModelManagement): + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs, + ): + self.model_name = model_name + self.cache_dir = cache_dir + self.threads = threads + self._local_files_only = kwargs.pop("local_files_only", False) + + def embed_text( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Embeds a list of documents into a list of embeddings. + + Args: + documents (Iterable[str]): The list of texts to embed. + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + **kwargs: Additional keyword argument to pass to the embed method. + + Yields: + Iterable[np.ndarray]: The embeddings. + """ + raise NotImplementedError() + + def embed_image( + self, + images: Union[ImageInput, Iterable[ImageInput]], + batch_size: int = 16, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of images into list of embeddings. 
+ Args: + images: Iterator of image paths or single image path to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per image + """ + raise NotImplementedError() diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py new file mode 100644 index 00000000..75031cfa --- /dev/null +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -0,0 +1,236 @@ +import contextlib +import os +from multiprocessing import get_all_start_methods +from pathlib import Path +from typing import Any, Iterable, Optional, Sequence, Type, Union, get_args + +import numpy as np +from PIL import Image +from tokenizers import Encoding + +from fastembed.common import OnnxProvider, ImageInput +from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T +from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor +from fastembed.common.utils import iter_batch +from fastembed.parallel_processor import ParallelWorkerPool + + +class OnnxMultimodalModel(OnnxModel[T]): + ONNX_OUTPUT_NAMES: Optional[list[str]] = None + + def __init__(self) -> None: + super().__init__() + self.tokenizer = None + self.processor = None + self.special_token_to_id = {} + + def _preprocess_onnx_text_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + """ + Preprocess the onnx input. + """ + return onnx_input + + def _preprocess_onnx_image_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + """ + Preprocess the onnx input. 
+ """ + return onnx_input + + @classmethod + def _get_text_worker_class(cls) -> Type["TextEmbeddingWorker"]: + raise NotImplementedError("Subclasses must implement this method") + + @classmethod + def _get_image_worker_class(cls) -> Type["ImageEmbeddingWorker"]: + raise NotImplementedError("Subclasses must implement this method") + + def _post_process_onnx_image_output(self, output: OnnxOutputContext) -> Iterable[T]: + raise NotImplementedError("Subclasses must implement this method") + + def _post_process_onnx_text_output(self, output: OnnxOutputContext) -> Iterable[T]: + raise NotImplementedError("Subclasses must implement this method") + + def _load_onnx_model( + self, + model_dir: Path, + model_file: str, + threads: Optional[int], + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_id: Optional[int] = None, + ) -> None: + super()._load_onnx_model( + model_dir=model_dir, + model_file=model_file, + threads=threads, + providers=providers, + cuda=cuda, + device_id=device_id, + ) + self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + self.processor = load_preprocessor(model_dir=model_dir) + + def load_onnx_model(self) -> None: + raise NotImplementedError("Subclasses must implement this method") + + def tokenize(self, documents: list[str], **kwargs) -> list[Encoding]: + return self.tokenizer.encode_batch(documents) + + def onnx_embed_text( + self, + documents: list[str], + **kwargs, + ) -> OnnxOutputContext: + encoded = self.tokenize(documents, **kwargs) + input_ids = np.array([e.ids for e in encoded]) + attention_mask = np.array([e.attention_mask for e in encoded]) + input_names = {node.name for node in self.model.get_inputs()} + onnx_input = { + "input_ids": np.array(input_ids, dtype=np.int64), + } + if "attention_mask" in input_names: + onnx_input["attention_mask"] = np.array(attention_mask, dtype=np.int64) + if "token_type_ids" in input_names: + onnx_input["token_type_ids"] = np.array( + [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 + ) + + onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs) + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) + return OnnxOutputContext( + model_output=model_output[0], + attention_mask=onnx_input.get("attention_mask", attention_mask), + input_ids=onnx_input.get("input_ids", input_ids), + ) + + def _embed_documents( + self, + model_name: str, + cache_dir: str, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + **kwargs, + ) -> Iterable[T]: + is_small = False + + if isinstance(documents, str): + documents = [documents] + is_small = True + + if isinstance(documents, list): + if len(documents) < batch_size: + is_small = True + + if parallel is None or is_small: + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + for batch in iter_batch(documents, batch_size): + yield from self._post_process_onnx_text_output(self.onnx_embed_text(batch)) + else: + if parallel == 0: + parallel = os.cpu_count() + + start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" + params = { + "model_name": model_name, + "cache_dir": cache_dir, + "providers": providers, + **kwargs, + } + + pool = ParallelWorkerPool( + num_workers=parallel or 1, + worker=self._get_text_worker_class(), + cuda=cuda, + device_ids=device_ids, + start_method=start_method, + 
) + for batch in pool.ordered_map(iter_batch(documents, batch_size), **params): + yield from self._post_process_onnx_text_output(batch) + + def _build_onnx_image_input(self, encoded: np.ndarray) -> dict[str, np.ndarray]: + return {node.name: encoded for node in self.model.get_inputs()} + + def onnx_embed_image(self, images: list[ImageInput], **kwargs) -> OnnxOutputContext: + with contextlib.ExitStack(): + image_files = [ + Image.open(image) if not isinstance(image, Image.Image) else image + for image in images + ] + encoded = self.processor(image_files) + onnx_input = self._build_onnx_image_input(encoded) + onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs) + model_output = self.model.run(None, onnx_input) + embeddings = model_output[0].reshape(len(images), -1) + return OnnxOutputContext(model_output=embeddings) + + def _embed_images( + self, + model_name: str, + cache_dir: str, + images: Union[Iterable[ImageInput], ImageInput], + batch_size: int = 256, + parallel: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + **kwargs, + ) -> Iterable[T]: + is_small = False + + if isinstance(images, get_args(ImageInput)): + images = [images] + is_small = True + + if isinstance(images, list) and len(images) < batch_size: + is_small = True + + if parallel is None or is_small: + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for batch in iter_batch(images, batch_size): + yield from self._post_process_onnx_image_output(self.onnx_embed_image(batch)) + else: + if parallel == 0: + parallel = os.cpu_count() + + start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" + params = { + "model_name": model_name, + "cache_dir": cache_dir, + "providers": providers, + **kwargs, + } + + pool = ParallelWorkerPool( + num_workers=parallel or 1, + worker=self._get_image_worker_class(), + cuda=cuda, + device_ids=device_ids, + start_method=start_method, + ) + for batch in pool.ordered_map(iter_batch(images, batch_size), **params): + yield from self._post_process_onnx_image_output(batch) + + +class TextEmbeddingWorker(EmbeddingWorker): + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: + for idx, batch in items: + onnx_output = self.model.onnx_embed_text(batch) + yield idx, onnx_output + + +class ImageEmbeddingWorker(EmbeddingWorker): + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: + for idx, batch in items: + embeddings = self.model.onnx_embed_image(batch) + yield idx, embeddings diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py new file mode 100644 index 00000000..0083ee98 --- /dev/null +++ b/tests/test_late_interaction_multimodal.py @@ -0,0 +1,84 @@ +import os + +from PIL import Image +import numpy as np + +from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding +from tests.config import TEST_MISC_DIR + + +# vectors are abridged and rounded for brevity +CANONICAL_IMAGE_VALUES = { + "Qdrant/colpali-v1.3-fp16": np.array( + [ + [ + [-0.0345, -0.022, 0.0567, -0.0518, -0.0782, 0.1714, -0.1738], + [-0.1181, -0.099, 0.0268, 0.0774, 0.0228, 0.0563, -0.1021], + [-0.117, -0.0683, 0.0371, 0.0921, 0.0107, 0.0659, -0.0666], + [-0.1393, -0.0948, 0.037, 0.0951, -0.0126, 0.0678, -0.087], + [-0.0957, -0.081, 0.0404, 0.052, 0.0409, 0.0335, -0.064], + [-0.0626, -0.0445, 0.056, 0.0592, -0.0229, 0.0409, -0.0301], + [-0.1299, 
-0.0691, 0.1097, 0.0728, 0.0123, 0.0519, 0.0122], + ] + ] + ), +} + +CANONICAL_QUERY_VALUES = { + "Qdrant/colpali-v1.3-fp16": np.array( + [ + [-0.0023, 0.1477, 0.1594, 0.046, -0.0196, 0.0554, 0.1567], + [-0.0139, -0.0057, 0.0932, 0.0052, -0.0678, 0.0131, 0.0537], + [0.0054, 0.0364, 0.2078, -0.074, 0.0355, 0.061, 0.1593], + [-0.0076, -0.0154, 0.2266, 0.0103, 0.0089, -0.024, 0.098], + [-0.0274, 0.0098, 0.2106, -0.0634, 0.0616, -0.0021, 0.0708], + [0.0074, 0.0025, 0.1631, -0.0802, 0.0418, -0.0219, 0.1022], + [-0.0165, -0.0106, 0.1672, -0.0768, 0.0389, -0.0038, 0.1137], + ] + ), +} + +queries = ["hello world", "flag embedding"] +images = [ + TEST_MISC_DIR / "image.jpeg", + str(TEST_MISC_DIR / "image.jpeg"), + Image.open((TEST_MISC_DIR / "image.jpeg")), +] + + +def test_batch_embedding(): + is_ci = os.getenv("CI") + + if not is_ci: + for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionMultimodalEmbedding(model_name=model_name) + result = list(model.embed_image(images, batch_size=2)) + + for value in result: + batch_size, token_num, abridged_dim = expected_result.shape + assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=1e-3) + + +def test_single_embedding(): + is_ci = os.getenv("CI") + if not is_ci: + for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionMultimodalEmbedding(model_name=model_name) + result = next(iter(model.embed_image(images, batch_size=6))) + batch_size, token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) + + +def test_single_embedding_query(): + is_ci = os.getenv("CI") + if not is_ci: + queries_to_embed = queries + + for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionMultimodalEmbedding(model_name=model_name) + result = next(iter(model.embed_text(queries_to_embed))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3)
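
For reference, a minimal usage sketch of the API this diff introduces (illustrative only: "page.jpeg" is a placeholder path, and the first call downloads the ~6.5 GB ONNX weights listed in supported_colpali_models):

from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding

model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16")

# Each query yields one multi-vector embedding, roughly shaped (num_tokens, 128).
query_embeddings = list(model.embed_text(["what is shown in the figure?"]))

# Images may be passed as Path objects, strings, or PIL images, mirroring the tests above.
image_embeddings = list(model.embed_image(["page.jpeg"], batch_size=1))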
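
The diff produces multi-vector (late-interaction) embeddings but does not ship a scorer; comparing a query against an image is commonly done with ColBERT-style MaxSim. A minimal NumPy sketch, assuming both inputs are (num_tokens, dim) arrays like those returned above:

import numpy as np

def maxsim_score(query_emb: np.ndarray, image_emb: np.ndarray) -> float:
    # For each query token vector, keep its best dot-product match among the
    # image token vectors, then sum those maxima into one relevance score.
    similarity = query_emb @ image_emb.T  # (query_tokens, image_tokens)
    return float(similarity.max(axis=1).sum())

# e.g. for one query embedding q: scores = [maxsim_score(q, img) for img in image_embeddings]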
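
For multi-GPU, data-parallel encoding, the constructor and embed_* arguments added here combine as below. This is a sketch of the intended wiring (worker/device assignment happens inside ParallelWorkerPool, which is outside this diff), and image_paths is a hypothetical list of files:

model = LateInteractionMultimodalEmbedding(
    model_name="Qdrant/colpali-v1.3-fp16",
    cuda=True,          # mutually exclusive with passing explicit providers
    device_ids=[0, 1],  # GPUs made available to the worker processes
    lazy_load=True,     # defer ONNX session creation to the workers
)
# parallel > 1 enables the worker pool; parallel=0 would use all available CPU cores.
image_embeddings = list(model.embed_image(image_paths, batch_size=4, parallel=2))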