colpali v1.3 by AndrewOgn (#427)
* wip: design draft

* Operators fix

* Fix model inputs

* Import from fastembed.late_interaction_multimodal

* Fixed method misspelling

* Tests, which do not run in CI
Docstring improvements

* Fix tests

* Bump colpali to version v1.3

* Remove colpali v1.2

* Remove colpali v1.2 from tests

* partial fix of change requests:
descriptions
docs
black

* query_max_length

* black colpali

* Added comment for EMPTY_TEXT_PLACEHOLDER

* Review fixes

* Removed redundant VISUAL_PROMPT_PREFIX

* type fix + model info

* new: add specific model path to colpali

* fix: revert accidental renaming

* fix: remove max_length from encode_batch

* refactoring: remove redundant QUERY_MAX_LENGTH variable

* refactoring: remove redundant document marker token id

* fix: fix type hints, fix tests, handle single image path embed, rename model, update description

* license: add gemma to NOTICE

* fix: do not run colpali test in ci

* fix: fix colpali test

---------

Co-authored-by: d.rudenko <dmitrii.rudenko@qdrant.com>
joein and I8dNLo authored Feb 6, 2025
1 parent 969ea29 commit 2fe33c5
Showing 9 changed files with 826 additions and 4 deletions.
6 changes: 6 additions & 0 deletions NOTICE
@@ -12,5 +12,11 @@ This distribution includes the following Jina AI models, each with its respective license:

These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms.

This distribution includes the following Google models, each with its respective license:
- vidore/colpali-v1.3
  - License: gemma

Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms

Additional Notes:
This project also includes third-party libraries with their respective licenses. Please refer to the documentation of each library for details regarding its usage and licensing terms.
3 changes: 1 addition & 2 deletions fastembed/image/image_embedding.py
@@ -79,8 +79,7 @@ def embed(
         **kwargs: Any,
     ) -> Iterable[NumpyArray]:
         """
-        Encode a list of documents into list of embeddings.
-        We use mean pooling with attention so that the model can handle variable-length inputs.
+        Encode a list of images into list of embeddings.
 
         Args:
             images: Iterator of image paths or single image path to embed
4 changes: 2 additions & 2 deletions fastembed/image/transform/operators.py
@@ -139,7 +139,7 @@ def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]) -> None:
     @classmethod
     def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
         mode = config.get("image_processor_type", "CLIPImageProcessor")
-        if mode == "CLIPImageProcessor":
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
             if config.get("do_resize", False):
                 size = config["size"]
                 if "shortest_edge" in size:
@@ -202,7 +202,7 @@ def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
     @staticmethod
     def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> None:
         mode = config.get("image_processor_type", "CLIPImageProcessor")
-        if mode == "CLIPImageProcessor":
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
             if config.get("do_center_crop", False):
                 crop_size_raw = config["crop_size"]
                 crop_size: tuple[int, int]
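Before this change, a preprocessor config with "image_processor_type": "SiglipImageProcessor" fell through both branches and silently skipped resize and center-crop, which ColPali's SigLIP-based vision tower relies on. A minimal stand-alone illustration of the dispatch (config values are made up for the example; real values come from the model's preprocessor_config.json):

config = {
    "image_processor_type": "SiglipImageProcessor",
    "do_resize": True,
    "size": {"height": 448, "width": 448},
}
mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
    if config.get("do_resize", False):
        print("resize to:", config["size"])  # previously unreachable for SigLIP configs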
5 changes: 5 additions & 0 deletions fastembed/late_interaction_multimodal/__init__.py
@@ -0,0 +1,5 @@
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding import (
    LateInteractionMultimodalEmbedding,
)

__all__ = ["LateInteractionMultimodalEmbedding"]
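For context, a minimal usage sketch of the new entry point (assuming the wrapper class exposes the same embed_text/embed_image methods as the ColPali class below; the exact call signature is an assumption, not part of this diff):

from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding

model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16")

# Each embedding is a (num_tokens, 128) multivector, one per input.
query_embeddings = list(model.embed_text(["what is late interaction?"]))
image_embeddings = list(model.embed_image(["path/to/page.png"]))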
301 changes: 301 additions & 0 deletions fastembed/late_interaction_multimodal/colpali.py
@@ -0,0 +1,301 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np
from tokenizers import Encoding

from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
    LateInteractionMultimodalEmbeddingBase,
)
from fastembed.late_interaction_multimodal.onnx_multimodal_model import (
    OnnxMultimodalModel,
    TextEmbeddingWorker,
    ImageEmbeddingWorker,
)
supported_colpali_models = [
    {
        "model": "Qdrant/colpali-v1.3-fp16",
        "dim": 128,
        "description": "Text embeddings, Multimodal (text&image), English, 50 tokens query length truncation, 2024.",
        "license": "mit",
        "size_in_GB": 6.5,
        "sources": {
            "hf": "Qdrant/colpali-v1.3-fp16",
        },
        "additional_files": [
            "model.onnx_data",
        ],
        "model_file": "model.onnx",
    },
]


class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]):
    QUERY_PREFIX = "Query: "
    BOS_TOKEN = "<s>"
    PAD_TOKEN = "<pad>"
    QUERY_MARKER_TOKEN_ID = [2, 5098]
    IMAGE_PLACEHOLDER_SIZE = (3, 448, 448)
    # Tokenization of "<image>" * 1024 + "<bos>Describe the image.\n", used as a
    # text placeholder while processing an image.
    EMPTY_TEXT_PLACEHOLDER = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108])
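    # Attention mask matching the placeholder: 1024 image tokens + 6 text tokens = 1030.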
    EVEN_ATTENTION_MASK = np.array([1] * 1030)

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
                Defaults to False.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
            specific_model_path (Optional[str], optional): The path to load the onnx model from,
                instead of the cache directory. Defaults to None.

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """

        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load

        # List of device ids that can be used for data-parallel processing in workers
        self.device_ids = device_ids
        self.cuda = cuda

        # This device_id will be used if we need to load the model in the current process
        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]
        else:
            self.device_id = None

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = define_cache_dir(cache_dir)

        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=specific_model_path,
        )
        self.mask_token_id = None
        self.pad_token_id = None

        if not self.lazy_load:
            self.load_onnx_model()

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_colpali_models

    def load_onnx_model(self) -> None:
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description["model_file"],
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
        )

    def _post_process_onnx_image_output(
        self,
        output: OnnxOutputContext,
    ) -> Iterable[np.ndarray]:
        """
        Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.

        Returns:
            Iterable[np.ndarray]: Post-processed output as NumPy arrays.
        """
        return output.model_output.reshape(
            output.model_output.shape[0], -1, self.model_description["dim"]
        ).astype(np.float32)

    def _post_process_onnx_text_output(
        self,
        output: OnnxOutputContext,
    ) -> Iterable[np.ndarray]:
        """
        Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.

        Returns:
            Iterable[np.ndarray]: Post-processed output as NumPy arrays.
        """
        return output.model_output.astype(np.float32)

    def tokenize(self, documents: list[str], **_) -> list[Encoding]:
        texts_query: list[str] = []
        for query in documents:
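            # Query-augmentation buffer: the trailing <pad> tokens give the model
            # room to expand the query (mirroring the upstream ColPali processor).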
            query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10
            query += "\n"

            texts_query.append(query)
        encoded = self.tokenizer.encode_batch(texts_query)
        return encoded

    def _preprocess_onnx_text_input(
        self, onnx_input: dict[str, np.ndarray], **kwargs
    ) -> dict[str, np.ndarray]:
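        # Replace the first two encoded tokens with the query-marker ids
        # (assumed to encode "<bos>Query"), keeping the rest of the query.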
        onnx_input["input_ids"] = np.array(
            [
                self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist()
                for input_ids in onnx_input["input_ids"]
            ]
        )
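        # The ONNX graph always expects pixel_values, so provide a zeroed
        # image placeholder for text-only input.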
        empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32)
        onnx_input["pixel_values"] = np.array(
            [empty_image_placeholder for _ in onnx_input["input_ids"]]
        )
        return onnx_input

    def _preprocess_onnx_image_input(
        self, onnx_input: dict[str, np.ndarray], **kwargs
    ) -> dict[str, np.ndarray]:
        """
        Add placeholders for text input when processing image data for ONNX.

        Args:
            onnx_input (dict[str, np.ndarray]): Preprocessed image inputs.
            **kwargs: Additional arguments.

        Returns:
            dict[str, np.ndarray]: ONNX input with text placeholders.
        """

        onnx_input["input_ids"] = np.array(
            [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]]
        )
        onnx_input["attention_mask"] = np.array(
            [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]]
        )
        return onnx_input

    def embed_text(
        self,
        documents: Union[str, Iterable[str]],
        batch_size: int = 256,
        parallel: Optional[int] = None,
        **kwargs,
    ) -> Iterable[np.ndarray]:
        """
        Encode a list of documents into list of embeddings.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            **kwargs,
        )

    def embed_image(
        self,
        images: ImageInput,
        batch_size: int = 16,
        parallel: Optional[int] = None,
        **kwargs,
    ) -> Iterable[np.ndarray]:
        """
        Encode a list of images into list of embeddings.

        Args:
            images: Iterator of image paths or single image path to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per image
        """
        yield from self._embed_images(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            images=images,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            **kwargs,
        )

    @classmethod
    def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]:
        return ColPaliTextEmbeddingWorker

    @classmethod
    def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]:
        return ColPaliImageEmbeddingWorker


class ColPaliTextEmbeddingWorker(TextEmbeddingWorker):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
        return ColPali(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )


class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
        return ColPali(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
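Because these are late-interaction multivectors, relevance between a query and an image is scored with MaxSim rather than a single dot product. A minimal NumPy sketch of that scoring step (illustrative only; not part of fastembed or this diff):

import numpy as np

def maxsim_score(query_emb: np.ndarray, image_emb: np.ndarray) -> float:
    # (query_tokens, dim) @ (dim, image_patches) -> pairwise similarities;
    # take each query token's best match, then sum over query tokens.
    sim = query_emb @ image_emb.T
    return float(sim.max(axis=1).sum())

# e.g. maxsim_score(query_embeddings[0], image_embeddings[0]) using the
# multivectors produced by embed_text / embed_image above.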