Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: type hints for colpali #469

Merged
merged 4 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions fastembed/common/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import onnxruntime as ort

from numpy.typing import NDArray
from tokenizers import Tokenizer

from fastembed.common.types import OnnxProvider, NumpyArray
from fastembed.parallel_processor import Worker
Expand All @@ -31,8 +32,8 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
raise NotImplementedError("Subclasses must implement this method")

def __init__(self) -> None:
self.model = None
self.tokenizer = None
self.model: Optional[ort.InferenceSession] = None
self.tokenizer: Optional[Tokenizer] = None

def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
Expand Down
4 changes: 2 additions & 2 deletions fastembed/image/onnx_image_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_onnx_model(self) -> None:
raise NotImplementedError("Subclasses must implement this method")

def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
input_name = self.model.get_inputs()[0].name # type: ignore
input_name = self.model.get_inputs()[0].name # type: ignore[union-attr]
return {input_name: encoded}

def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
Expand All @@ -74,7 +74,7 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputConte
encoded = np.array(self.processor(image_files))
onnx_input = self._build_onnx_input(encoded)
onnx_input = self._preprocess_onnx_input(onnx_input)
model_output = self.model.run(None, onnx_input) # type: ignore
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
embeddings = model_output[0].reshape(len(images), -1)
return OnnxOutputContext(model_output=embeddings)

Expand Down
4 changes: 1 addition & 3 deletions fastembed/late_interaction/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) ->

def _tokenize_query(self, query: str) -> list[Encoding]:
assert self.tokenizer is not None

encoded = self.tokenizer.encode_batch([query])
# colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
Expand All @@ -109,8 +108,7 @@ def _tokenize_query(self, query: str) -> list[Encoding]:
return encoded

def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
assert self.tokenizer is not None
encoded = self.tokenizer.encode_batch(documents)
encoded = self.tokenizer.encode_batch(documents) # type: ignore[union-attr]
return encoded

@classmethod
Expand Down
62 changes: 32 additions & 30 deletions fastembed/late_interaction_multimodal/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.common.utils import define_cache_dir
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
LateInteractionMultimodalEmbeddingBase,
Expand Down Expand Up @@ -33,7 +34,7 @@
]


class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]):
class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyArray]):
QUERY_PREFIX = "Query: "
BOS_TOKEN = "<s>"
PAD_TOKEN = "<pad>"
Expand All @@ -56,7 +57,7 @@ def __init__(
lazy_load: bool = False,
device_id: Optional[int] = None,
specific_model_path: Optional[str] = None,
**kwargs,
**kwargs: Any,
):
"""
Args:
Expand Down Expand Up @@ -88,15 +89,14 @@ def __init__(
self.cuda = cuda

# This device_id will be used if we need to load model in current process
self.device_id: Optional[int] = None
if device_id is not None:
self.device_id = device_id
elif self.device_ids is not None:
self.device_id = self.device_ids[0]
else:
self.device_id = None

self.model_description = self._get_model_description(model_name)
self.cache_dir = define_cache_dir(cache_dir)
self.cache_dir = str(define_cache_dir(cache_dir))

self._model_dir = self.download_model(
self.model_description,
Expand Down Expand Up @@ -132,15 +132,15 @@ def load_onnx_model(self) -> None:
def _post_process_onnx_image_output(
self,
output: OnnxOutputContext,
) -> Iterable[np.ndarray]:
) -> Iterable[NumpyArray]:
"""
Post-process the ONNX model output to convert it into a usable format.

Args:
output (OnnxOutputContext): The raw output from the ONNX model.

Returns:
Iterable[np.ndarray]: Post-processed output as NumPy arrays.
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
"""
return output.model_output.reshape(
output.model_output.shape[0], -1, self.model_description["dim"]
Expand All @@ -149,53 +149,55 @@ def _post_process_onnx_image_output(
def _post_process_onnx_text_output(
self,
output: OnnxOutputContext,
) -> Iterable[np.ndarray]:
) -> Iterable[NumpyArray]:
"""
Post-process the ONNX model output to convert it into a usable format.

Args:
output (OnnxOutputContext): The raw output from the ONNX model.

Returns:
Iterable[np.ndarray]: Post-processed output as NumPy arrays.
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
"""
return output.model_output.astype(np.float32)

def tokenize(self, documents: list[str], **_) -> list[Encoding]:
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
texts_query: list[str] = []
for query in documents:
query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10
query += "\n"

texts_query.append(query)
encoded = self.tokenizer.encode_batch(texts_query)
encoded = self.tokenizer.encode_batch(texts_query) # type: ignore[union-attr]
return encoded

def _preprocess_onnx_text_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
) -> dict[str, NumpyArray]:
onnx_input["input_ids"] = np.array(
[
self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist()
for input_ids in onnx_input["input_ids"]
]
)
empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32)
empty_image_placeholder: NumpyArray = np.zeros(
self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32
)
onnx_input["pixel_values"] = np.array(
[empty_image_placeholder for _ in onnx_input["input_ids"]]
[empty_image_placeholder for _ in onnx_input["input_ids"]],
)
return onnx_input

def _preprocess_onnx_image_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
self, onnx_input: dict[str, np.ndarray], **kwargs: Any
) -> dict[str, NumpyArray]:
"""
Add placeholders for text input when processing image data for ONNX.
Args:
onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs.
onnx_input (Dict[str, NumpyArray]): Preprocessed image inputs.
**kwargs: Additional arguments.
Returns:
Dict[str, np.ndarray]: ONNX input with text placeholders.
Dict[str, NumpyArray]: ONNX input with text placeholders.
"""

onnx_input["input_ids"] = np.array(
Expand All @@ -211,8 +213,8 @@ def embed_text(
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of documents into list of embeddings.

Expand Down Expand Up @@ -241,11 +243,11 @@ def embed_text(

def embed_image(
self,
images: ImageInput,
images: Union[ImageInput, Iterable[ImageInput]],
batch_size: int = 16,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of images into list of embeddings.

Expand Down Expand Up @@ -273,16 +275,16 @@ def embed_image(
)

@classmethod
def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]:
def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
return ColPaliTextEmbeddingWorker

@classmethod
def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]:
def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker[NumpyArray]]:
return ColPaliImageEmbeddingWorker


class ColPaliTextEmbeddingWorker(TextEmbeddingWorker):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
class ColPaliTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColPali:
return ColPali(
model_name=model_name,
cache_dir=cache_dir,
Expand All @@ -291,8 +293,8 @@ def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
)


class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColPali:
return ColPali(
model_name=model_name,
cache_dir=cache_dir,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.types import NumpyArray
from fastembed.late_interaction_multimodal.colpali import ColPali

from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
Expand Down Expand Up @@ -41,7 +40,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
]
```
"""
result = []
result: list[dict[str, Any]] = []
for embedding in cls.EMBEDDINGS_REGISTRY:
result.extend(embedding.list_supported_models())
return result
Expand All @@ -55,7 +54,7 @@ def __init__(
cuda: bool = False,
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
**kwargs,
**kwargs: Any,
):
super().__init__(model_name, cache_dir, threads, **kwargs)
for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
Expand Down Expand Up @@ -83,8 +82,8 @@ def embed_text(
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of documents into list of embeddings.

Expand All @@ -106,8 +105,8 @@ def embed_image(
images: Union[ImageInput, Iterable[ImageInput]],
batch_size: int = 16,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of images into list of embeddings.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Iterable, Optional, Union
from typing import Iterable, Optional, Union, Any

import numpy as np

from fastembed.common import ImageInput
from fastembed.common.model_management import ModelManagement
from fastembed.common.types import NumpyArray


class LateInteractionMultimodalEmbeddingBase(ModelManagement):
Expand All @@ -12,7 +12,7 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs,
**kwargs: Any,
):
self.model_name = model_name
self.cache_dir = cache_dir
Expand All @@ -24,8 +24,8 @@ def embed_text(
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Embeds a list of documents into a list of embeddings.

Expand All @@ -39,7 +39,7 @@ def embed_text(
**kwargs: Additional keyword argument to pass to the embed method.

Yields:
Iterable[np.ndarray]: The embeddings.
Iterable[NumpyArray]: The embeddings.
"""
raise NotImplementedError()

Expand All @@ -48,8 +48,8 @@ def embed_image(
images: Union[ImageInput, Iterable[ImageInput]],
batch_size: int = 16,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of images into list of embeddings.
Args:
Expand Down
Loading