Skip to content

Commit

Permalink
wip: type hints for colpali (#469)
Browse files Browse the repository at this point in the history
* wip: type hints for colpali

* new: Add colpali type hints

* refactor: Remove redundant type ignore

* fix: address remaining mypy issues

---------

Co-authored-by: hh-space-invader <h.hagag.ali@gmail.com>
  • Loading branch information
joein and hh-space-invader authored Feb 6, 2025
1 parent 0fa1596 commit 4599b93
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 96 deletions.
5 changes: 3 additions & 2 deletions fastembed/common/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import onnxruntime as ort

from numpy.typing import NDArray
from tokenizers import Tokenizer

from fastembed.common.types import OnnxProvider, NumpyArray
from fastembed.parallel_processor import Worker
Expand All @@ -31,8 +32,8 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
raise NotImplementedError("Subclasses must implement this method")

def __init__(self) -> None:
self.model = None
self.tokenizer = None
self.model: Optional[ort.InferenceSession] = None
self.tokenizer: Optional[Tokenizer] = None

def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
Expand Down
4 changes: 2 additions & 2 deletions fastembed/image/onnx_image_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_onnx_model(self) -> None:
raise NotImplementedError("Subclasses must implement this method")

def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
input_name = self.model.get_inputs()[0].name # type: ignore
input_name = self.model.get_inputs()[0].name # type: ignore[union-attr]
return {input_name: encoded}

def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
Expand All @@ -74,7 +74,7 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputConte
encoded = np.array(self.processor(image_files))
onnx_input = self._build_onnx_input(encoded)
onnx_input = self._preprocess_onnx_input(onnx_input)
model_output = self.model.run(None, onnx_input) # type: ignore
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
embeddings = model_output[0].reshape(len(images), -1)
return OnnxOutputContext(model_output=embeddings)

Expand Down
4 changes: 1 addition & 3 deletions fastembed/late_interaction/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) ->

def _tokenize_query(self, query: str) -> list[Encoding]:
assert self.tokenizer is not None

encoded = self.tokenizer.encode_batch([query])
# colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
Expand All @@ -109,8 +108,7 @@ def _tokenize_query(self, query: str) -> list[Encoding]:
return encoded

def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
assert self.tokenizer is not None
encoded = self.tokenizer.encode_batch(documents)
encoded = self.tokenizer.encode_batch(documents) # type: ignore[union-attr]
return encoded

@classmethod
Expand Down
62 changes: 32 additions & 30 deletions fastembed/late_interaction_multimodal/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.common.utils import define_cache_dir
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
LateInteractionMultimodalEmbeddingBase,
Expand Down Expand Up @@ -33,7 +34,7 @@
]


class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]):
class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyArray]):
QUERY_PREFIX = "Query: "
BOS_TOKEN = "<s>"
PAD_TOKEN = "<pad>"
Expand All @@ -56,7 +57,7 @@ def __init__(
lazy_load: bool = False,
device_id: Optional[int] = None,
specific_model_path: Optional[str] = None,
**kwargs,
**kwargs: Any,
):
"""
Args:
Expand Down Expand Up @@ -88,15 +89,14 @@ def __init__(
self.cuda = cuda

# This device_id will be used if we need to load model in current process
self.device_id: Optional[int] = None
if device_id is not None:
self.device_id = device_id
elif self.device_ids is not None:
self.device_id = self.device_ids[0]
else:
self.device_id = None

self.model_description = self._get_model_description(model_name)
self.cache_dir = define_cache_dir(cache_dir)
self.cache_dir = str(define_cache_dir(cache_dir))

self._model_dir = self.download_model(
self.model_description,
Expand Down Expand Up @@ -132,15 +132,15 @@ def load_onnx_model(self) -> None:
def _post_process_onnx_image_output(
self,
output: OnnxOutputContext,
) -> Iterable[np.ndarray]:
) -> Iterable[NumpyArray]:
"""
Post-process the ONNX model output to convert it into a usable format.
Args:
output (OnnxOutputContext): The raw output from the ONNX model.
Returns:
Iterable[np.ndarray]: Post-processed output as NumPy arrays.
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
"""
return output.model_output.reshape(
output.model_output.shape[0], -1, self.model_description["dim"]
Expand All @@ -149,53 +149,55 @@ def _post_process_onnx_image_output(
def _post_process_onnx_text_output(
self,
output: OnnxOutputContext,
) -> Iterable[np.ndarray]:
) -> Iterable[NumpyArray]:
"""
Post-process the ONNX model output to convert it into a usable format.
Args:
output (OnnxOutputContext): The raw output from the ONNX model.
Returns:
Iterable[np.ndarray]: Post-processed output as NumPy arrays.
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
"""
return output.model_output.astype(np.float32)

def tokenize(self, documents: list[str], **_) -> list[Encoding]:
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
texts_query: list[str] = []
for query in documents:
query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10
query += "\n"

texts_query.append(query)
encoded = self.tokenizer.encode_batch(texts_query)
encoded = self.tokenizer.encode_batch(texts_query) # type: ignore[union-attr]
return encoded

def _preprocess_onnx_text_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
) -> dict[str, NumpyArray]:
onnx_input["input_ids"] = np.array(
[
self.QUERY_MARKER_TOKEN_ID + input_ids[2:].tolist()
for input_ids in onnx_input["input_ids"]
]
)
empty_image_placeholder = np.zeros(self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32)
empty_image_placeholder: NumpyArray = np.zeros(
self.IMAGE_PLACEHOLDER_SIZE, dtype=np.float32
)
onnx_input["pixel_values"] = np.array(
[empty_image_placeholder for _ in onnx_input["input_ids"]]
[empty_image_placeholder for _ in onnx_input["input_ids"]],
)
return onnx_input

def _preprocess_onnx_image_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
self, onnx_input: dict[str, np.ndarray], **kwargs: Any
) -> dict[str, NumpyArray]:
"""
Add placeholders for text input when processing image data for ONNX.
Args:
onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs.
onnx_input (Dict[str, NumpyArray]): Preprocessed image inputs.
**kwargs: Additional arguments.
Returns:
Dict[str, np.ndarray]: ONNX input with text placeholders.
Dict[str, NumpyArray]: ONNX input with text placeholders.
"""

onnx_input["input_ids"] = np.array(
Expand All @@ -211,8 +213,8 @@ def embed_text(
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of documents into a list of embeddings.
Expand Down Expand Up @@ -241,11 +243,11 @@ def embed_text(

def embed_image(
self,
images: ImageInput,
images: Union[ImageInput, Iterable[ImageInput]],
batch_size: int = 16,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of images into a list of embeddings.
Expand Down Expand Up @@ -273,16 +275,16 @@ def embed_image(
)

@classmethod
def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]:
def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
return ColPaliTextEmbeddingWorker

@classmethod
def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]:
def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker[NumpyArray]]:
return ColPaliImageEmbeddingWorker


class ColPaliTextEmbeddingWorker(TextEmbeddingWorker):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
class ColPaliTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColPali:
return ColPali(
model_name=model_name,
cache_dir=cache_dir,
Expand All @@ -291,8 +293,8 @@ def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
)


class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali:
class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColPali:
return ColPali(
model_name=model_name,
cache_dir=cache_dir,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.types import NumpyArray
from fastembed.late_interaction_multimodal.colpali import ColPali

from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
Expand Down Expand Up @@ -41,7 +40,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
]
```
"""
result = []
result: list[dict[str, Any]] = []
for embedding in cls.EMBEDDINGS_REGISTRY:
result.extend(embedding.list_supported_models())
return result
Expand All @@ -55,7 +54,7 @@ def __init__(
cuda: bool = False,
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
**kwargs,
**kwargs: Any,
):
super().__init__(model_name, cache_dir, threads, **kwargs)
for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
Expand Down Expand Up @@ -83,8 +82,8 @@ def embed_text(
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of documents into a list of embeddings.
Expand All @@ -106,8 +105,8 @@ def embed_image(
images: Union[ImageInput, Iterable[ImageInput]],
batch_size: int = 16,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of images into a list of embeddings.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Iterable, Optional, Union
from typing import Iterable, Optional, Union, Any

import numpy as np

from fastembed.common import ImageInput
from fastembed.common.model_management import ModelManagement
from fastembed.common.types import NumpyArray


class LateInteractionMultimodalEmbeddingBase(ModelManagement):
Expand All @@ -12,7 +12,7 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs,
**kwargs: Any,
):
self.model_name = model_name
self.cache_dir = cache_dir
Expand All @@ -24,8 +24,8 @@ def embed_text(
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Embeds a list of documents into a list of embeddings.
Expand All @@ -39,7 +39,7 @@ def embed_text(
**kwargs: Additional keyword arguments to pass to the embed method.
Yields:
Iterable[np.ndarray]: The embeddings.
Iterable[NumpyArray]: The embeddings.
"""
raise NotImplementedError()

Expand All @@ -48,8 +48,8 @@ def embed_image(
images: Union[ImageInput, Iterable[ImageInput]],
batch_size: int = 16,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
**kwargs: Any,
) -> Iterable[NumpyArray]:
"""
Encode a list of images into a list of embeddings.
Args:
Expand Down
Loading

0 comments on commit 4599b93

Please sign in to comment.