colpali v1.3 by AndrewOgn #427

Merged Feb 6, 2025 (26 commits)

Commits
cdfa64e
wip: design draft
joein Dec 18, 2024
3ae11ae
Operators fix
I8dNLo Dec 19, 2024
287639f
Fix model inputs
I8dNLo Dec 20, 2024
bcc2685
Import from fastembed.late_interaction_multimodal
I8dNLo Dec 20, 2024
27082d9
Fixed method misspelling
I8dNLo Dec 20, 2024
de95e4b
Tests, which do not run in CI
I8dNLo Dec 23, 2024
cc28ee0
Fix tests
I8dNLo Dec 27, 2024
9b46f4d
Bump colpali to version v1.3
I8dNLo Jan 13, 2025
e738d50
Remove colpali v1.2
I8dNLo Jan 13, 2025
e2628ce
Remove colpali v1.2 from tests
I8dNLo Jan 13, 2025
c89b93d
partial fix of change requests:
I8dNLo Jan 13, 2025
4a283bb
query_max_length
I8dNLo Jan 13, 2025
afd0831
black colpali
I8dNLo Jan 15, 2025
d9b6bd7
Added comment for EMPTY_TEXT_PLACEHOLDER
I8dNLo Jan 16, 2025
68195a6
Review fixes
I8dNLo Jan 16, 2025
fe55cd4
Removed redundant VISUAL_PROMPT_PREFIX
I8dNLo Jan 20, 2025
b7dd679
type fix + model info
I8dNLo Jan 23, 2025
363dfef
new: add specific model path to colpali
joein Feb 2, 2025
f32a132
fix: revert accidental renaming
joein Feb 3, 2025
8f4e9f8
fix: remove max_length from encode_batch
joein Feb 6, 2025
c0fedd7
refactoring: remove redundant QUERY_MAX_LENGTH variable
joein Feb 6, 2025
9062cfe
refactoring: remove redundant document marker token id
joein Feb 6, 2025
13f7a82
fix: fix type hints, fix tests, handle single image path embed, renam…
joein Feb 6, 2025
31bb5f4
license: add gemma to NOTICE
joein Feb 6, 2025
837ed40
fix: do not run colpali test in ci
joein Feb 6, 2025
7ebe90d
fix: fix colpali test
joein Feb 6, 2025
6 changes: 6 additions & 0 deletions NOTICE
@@ -12,5 +12,11 @@ This distribution includes the following Jina AI models, each with its respective license:
 
 These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms.
 
+This distribution includes the following Google models, each with its respective license:
+- vidore/colpali-v1.3
+  - License: gemma
+
+Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms
+
 Additional Notes:
 This project also includes third-party libraries with their respective licenses. Please refer to the documentation of each library for details regarding its usage and licensing terms.
3 changes: 1 addition & 2 deletions fastembed/image/image_embedding.py
@@ -79,8 +79,7 @@ def embed(
         **kwargs: Any,
     ) -> Iterable[NumpyArray]:
         """
-        Encode a list of documents into list of embeddings.
-        We use mean pooling with attention so that the model can handle variable-length inputs.
+        Encode a list of images into list of embeddings.
 
         Args:
             images: Iterator of image paths or single image path to embed
4 changes: 2 additions & 2 deletions fastembed/image/transform/operators.py
@@ -139,7 +139,7 @@ def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]) -> None:
     @classmethod
     def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
         mode = config.get("image_processor_type", "CLIPImageProcessor")
-        if mode == "CLIPImageProcessor":
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
             if config.get("do_resize", False):
                 size = config["size"]
                 if "shortest_edge" in size:
@@ -202,7 +202,7 @@ def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
     @staticmethod
     def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> None:
         mode = config.get("image_processor_type", "CLIPImageProcessor")
-        if mode == "CLIPImageProcessor":
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
             if config.get("do_center_crop", False):
                 crop_size_raw = config["crop_size"]
                 crop_size: tuple[int, int]
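For context, the resize and center-crop handlers are driven by the model's preprocessor config; before this change, SigLIP-style configs fell through the CLIPImageProcessor-only check. A minimal sketch of the new dispatch, using a hypothetical config dict (the values below are illustrative, not taken from a real preprocessor_config.json):

# Hypothetical preprocessor config; field names follow the HF convention used above.
config = {
    "image_processor_type": "SiglipImageProcessor",
    "do_resize": True,
    "size": {"height": 448, "width": 448},
}

mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):  # SigLIP now takes the CLIP path
    if config.get("do_resize", False):
        size = config["size"]
        print("resize to", (size["height"], size["width"]))  # -> resize to (448, 448)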
5 changes: 5 additions & 0 deletions fastembed/late_interaction_multimodal/__init__.py
@@ -0,0 +1,5 @@
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding import (
    LateInteractionMultimodalEmbedding,
)

__all__ = ["LateInteractionMultimodalEmbedding"]
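The new subpackage exposes LateInteractionMultimodalEmbedding as the public entry point. A minimal usage sketch, assuming the wrapper forwards to the ColPali implementation below and that "page.png" stands in for a real image path:

from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding

model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16")

# Each method yields one multivector of shape (num_tokens, 128) per input.
query_embeddings = list(model.embed_text(["what is late interaction?"]))
image_embeddings = list(model.embed_image(["page.png"]))  # hypothetical image path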
301 changes: 301 additions & 0 deletions fastembed/late_interaction_multimodal/colpali.py
@@ -0,0 +1,301 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np
from tokenizers import Encoding

from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
    LateInteractionMultimodalEmbeddingBase,
)
from fastembed.late_interaction_multimodal.onnx_multimodal_model import (
    OnnxMultimodalModel,
    TextEmbeddingWorker,
    ImageEmbeddingWorker,
)

supported_colpali_models = [
    {
        "model": "Qdrant/colpali-v1.3-fp16",
        "dim": 128,
        "description": "Text embeddings, Multimodal (text&image), English, 50 tokens query length truncation, 2024.",
        "license": "mit",
        "size_in_GB": 6.5,
        "sources": {
            "hf": "Qdrant/colpali-v1.3-fp16",
        },
        "additional_files": [
            "model.onnx_data",
        ],
        "model_file": "model.onnx",
    },
]


class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]):
    QUERY_PREFIX = "Query: "
    BOS_TOKEN = "<s>"
    PAD_TOKEN = "<pad>"
    QUERY_MARKER_TOKEN_ID = [2, 5098]
    IMAGE_PLACEHOLDER_SIZE = (3, 448, 448)
    # Tokenization of '<image>' * 1024 + '<bos>Describe the image.\n', used as a text
    # placeholder while processing an image.
    EMPTY_TEXT_PLACEHOLDER = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108])
    EVEN_ATTENTION_MASK = np.array([1] * 1030)

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
                Defaults to False.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
            specific_model_path (Optional[str], optional): A specific path to load the onnx model from, instead of
                the downloaded one. Defaults to None.

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """

        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load

        # List of device ids that can be used for data parallel processing in workers
        self.device_ids = device_ids
        self.cuda = cuda

        # This device_id will be used if we need to load the model in the current process
        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]
        else:
            self.device_id = None

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = define_cache_dir(cache_dir)

        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=specific_model_path,
        )
        self.mask_token_id = None
        self.pad_token_id = None

        if not self.lazy_load:
            self.load_onnx_model()

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_colpali_models

    def load_onnx_model(self) -> None:
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description["model_file"],
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
        )

    def _post_process_onnx_image_output(
        self,
        output: OnnxOutputContext,
    ) -> Iterable[np.ndarray]:
        """
        Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.

        Returns:
            Iterable[np.ndarray]: Post-processed output as NumPy arrays.
        """
        return output.model_output.reshape(
            output.model_output.shape[0], -1, self.model_description["dim"]
        ).astype(np.float32)

    def _post_process_onnx_text_output(
        self,
        output: OnnxOutputContext,
    ) -> Iterable[np.ndarray]:
        """
        Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.

        Returns:
            Iterable[np.ndarray]: Post-processed output as NumPy arrays.
        """
        return output.model_output.astype(np.float32)

    def tokenize(self, documents: list[str], **_) -> list[Encoding]:
        # Queries follow the ColPali format: '<s>Query: <text>' padded with 10 '<pad>'
        # tokens and a trailing newline.
        texts_query: list[str] = []
        for query in documents:
            query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10
            query += "\n"

            texts_query.append(query)
        encoded = self.tokenizer.encode_batch(texts_query)
        return encoded

    def _preprocess_onnx_text_input(
        self, onnx_input: dict[str, np.ndarray], **kwargs
    ) -> dict[str, np.ndarray]:
        # Replace the first two token ids with the query marker token ids, and feed an
        # empty image placeholder, since the model always expects pixel values.
        onnx_input["input_ids"] = np.array(
            [
                self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist()
                for input_ids in onnx_input["input_ids"]
            ]
        )
        empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32)
        onnx_input["pixel_values"] = np.array(
            [empty_image_placeholder for _ in onnx_input["input_ids"]]
        )
        return onnx_input

    def _preprocess_onnx_image_input(
        self, onnx_input: dict[str, np.ndarray], **kwargs
    ) -> dict[str, np.ndarray]:
        """
        Add placeholders for text input when processing image data for ONNX.

        Args:
            onnx_input (dict[str, np.ndarray]): Preprocessed image inputs.
            **kwargs: Additional arguments.

        Returns:
            dict[str, np.ndarray]: ONNX input with text placeholders.
        """
        onnx_input["input_ids"] = np.array(
            [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]]
        )
        onnx_input["attention_mask"] = np.array(
            [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]]
        )
        return onnx_input

    def embed_text(
        self,
        documents: Union[str, Iterable[str]],
        batch_size: int = 256,
        parallel: Optional[int] = None,
        **kwargs,
    ) -> Iterable[np.ndarray]:
        """
        Encode a list of documents into a list of embeddings.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            **kwargs,
        )

    def embed_image(
        self,
        images: ImageInput,
        batch_size: int = 16,
        parallel: Optional[int] = None,
        **kwargs,
    ) -> Iterable[np.ndarray]:
        """
        Encode a list of images into a list of embeddings.

        Args:
            images: Iterator of image paths or single image path to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per image
        """
        yield from self._embed_images(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            images=images,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            **kwargs,
        )

    @classmethod
    def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]:
        return ColPaliTextEmbeddingWorker

    @classmethod
    def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]:
        return ColPaliImageEmbeddingWorker


class ColPaliTextEmbeddingWorker(TextEmbeddingWorker):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
        return ColPali(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )


class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
        return ColPali(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
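Note that this PR adds embedding only; scoring the resulting multivectors is left to the consumer (e.g. a vector database with multivector support). For reference, ColPali-style late interaction compares a query and an image with MaxSim: each query token vector is matched against its best image patch vector, and the per-token maxima are summed. A minimal numpy sketch of that scoring, not part of this diff:

import numpy as np

def maxsim_score(query_emb: np.ndarray, image_emb: np.ndarray) -> float:
    # query_emb: (num_query_tokens, 128); image_emb: (num_patches, 128)
    similarities = query_emb @ image_emb.T  # (num_query_tokens, num_patches)
    # For each query token, keep its best-matching patch, then sum over tokens.
    return float(similarities.max(axis=1).sum())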