chore: Add missing type hints in functions
hh-space-invader committed Jan 28, 2025
1 parent 6cbf60b commit 8865b7d
Showing 12 changed files with 43 additions and 31 deletions.
2 changes: 1 addition & 1 deletion fastembed/common/types.py
@@ -9,7 +9,7 @@
 from typing_extensions import TypeAlias


-PathInput: TypeAlias = Union[str, os.PathLike]
+PathInput: TypeAlias = Union[str, os.PathLike[str]]
 PilInput: TypeAlias = Union[Image.Image, Iterable[Image.Image]]
 ImageInput: TypeAlias = Union[PathInput, Iterable[PathInput], PilInput]
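Parameterizing the alias as os.PathLike[str] narrows it to path objects whose __fspath__ returns str, so bytes-based paths no longer satisfy PathInput. A minimal standalone sketch of what the alias now accepts (not part of the commit):

    import os
    from pathlib import Path
    from typing import Union
    from typing_extensions import TypeAlias

    PathInput: TypeAlias = Union[str, os.PathLike[str]]

    def to_str(path: PathInput) -> str:
        # os.fspath returns str for both str and str-based PathLike inputs
        return os.fspath(path)

    to_str("model/config.json")        # ok: plain str
    to_str(Path("model/config.json"))  # ok: Path implements os.PathLike[str]
    # to_str(b"model")                 # a type checker now rejects bytes paths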
12 changes: 9 additions & 3 deletions fastembed/common/utils.py
@@ -5,20 +5,26 @@
 import unicodedata
 from pathlib import Path
 from itertools import islice
-from typing import Generator, Iterable, Optional, Union
+from typing import Iterable, Optional, TypeVar

 import numpy as np
+from numpy.typing import NDArray
+
+
+T = TypeVar("T")


-def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
+def normalize(
+    input_array: NDArray[np.float32], p: int = 2, dim: int = 1, eps: float = 1e-12
+) -> np.ndarray:
     # Calculate the Lp norm along the specified dimension
     norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
     norm = np.maximum(norm, eps)  # Avoid division by zero
     normalized_array = input_array / norm
     return normalized_array


-def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
+def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
     """
     >>> list(iter_batch([1,2,3,4,5], 3))
     [[1, 2, 3], [4, 5]]
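A short sketch of what the new hints buy in practice; the import path is taken from the file shown above, and the printed values are illustrative:

    import numpy as np
    from fastembed.common.utils import iter_batch, normalize

    # With the TypeVar, a checker infers list[str] batches from an Iterable[str]
    for batch in iter_batch(["a", "b", "c", "d", "e"], 2):
        print(batch)  # ['a', 'b'], then ['c', 'd'], then ['e']

    # normalize now documents NDArray[np.float32] input; rows come out unit-length
    vecs = normalize(np.ones((2, 4), dtype=np.float32))
    print(np.linalg.norm(vecs, axis=1))  # [1. 1.]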
2 changes: 1 addition & 1 deletion fastembed/image/transform/functional.py
@@ -114,7 +114,7 @@ def resize(
     return image.resize(new_size, resample)


-def rescale(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
+def rescale(image: np.ndarray, scale: float, dtype: type = np.float32) -> np.ndarray:
     return (image * scale).astype(dtype)
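Annotating dtype as type works because NumPy scalar types such as np.float32 are classes, though a checker would reject dtype strings like "float32"; numpy.typing.DTypeLike would be the looser alternative. A quick usage sketch, with the import path taken from the diff and the input invented:

    import numpy as np
    from fastembed.image.transform.functional import rescale

    pixels = np.array([[0, 128, 255]], dtype=np.uint8)
    scaled = rescale(pixels, 1 / 255)  # uint8 [0, 255] -> float32 [0, 1]
    print(scaled.dtype)                # float32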
10 changes: 5 additions & 5 deletions tests/test_attention_embeddings.py
@@ -8,7 +8,7 @@


 @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
-def test_attention_embeddings(model_name):
+def test_attention_embeddings(model_name: str):
     is_ci = os.getenv("CI")
     model = SparseTextEmbedding(model_name=model_name)

@@ -71,7 +71,7 @@ def test_attention_embeddings(model_name):


 @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
-def test_parallel_processing(model_name):
+def test_parallel_processing(model_name: str):
     is_ci = os.getenv("CI")

     model = SparseTextEmbedding(model_name=model_name)

@@ -96,7 +96,7 @@ def test_parallel_processing(model_name):


 @pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
-def test_multilanguage(model_name):
+def test_multilanguage(model_name: str):
     is_ci = os.getenv("CI")

     docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"]

@@ -122,7 +122,7 @@ def test_multilanguage(model_name):


 @pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
-def test_special_characters(model_name):
+def test_special_characters(model_name: str):
     is_ci = os.getenv("CI")

     docs = [

@@ -145,7 +145,7 @@ def test_special_characters(model_name):


 @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions"])
-def test_lazy_load(model_name):
+def test_lazy_load(model_name: str):
     model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
     assert not hasattr(model.model, "model")
     docs = ["hello world", "flag embedding"]
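Most of the test changes in this commit are this same pattern: annotating the argument that pytest.mark.parametrize injects. The hints are not enforced at runtime; they document the parameter and let a static checker verify the test body. A minimal sketch:

    import pytest

    @pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
    def test_model_name_is_str(model_name: str) -> None:
        # pytest passes each parametrized value through unchanged
        assert isinstance(model_name, str)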
6 changes: 3 additions & 3 deletions tests/test_image_onnx_embeddings.py
@@ -61,7 +61,7 @@ def test_embedding():


 @pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
-def test_batch_embedding(n_dims, model_name):
+def test_batch_embedding(n_dims: int, model_name: str):
     is_ci = os.getenv("CI")
     model = ImageEmbedding(model_name=model_name)
     n_images = 32

@@ -81,7 +81,7 @@ def test_batch_embedding(n_dims, model_name):


 @pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
-def test_parallel_processing(n_dims, model_name):
+def test_parallel_processing(n_dims: int, model_name: str):
     is_ci = os.getenv("CI")
     model = ImageEmbedding(model_name=model_name)

@@ -109,7 +109,7 @@ def test_parallel_processing(n_dims, model_name):


 @pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
-def test_lazy_load(model_name):
+def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
     model = ImageEmbedding(model_name=model_name, lazy_load=True)
     assert not hasattr(model.model, "model")
2 changes: 1 addition & 1 deletion tests/test_late_interaction_embeddings.py
@@ -226,7 +226,7 @@ def test_parallel_processing():
     "model_name",
     ["colbert-ir/colbertv2.0"],
 )
-def test_lazy_load(model_name):
+def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")

     model = LateInteractionTextEmbedding(model_name=model_name, lazy_load=True)
7 changes: 4 additions & 3 deletions tests/test_multi_gpu.py
@@ -1,4 +1,5 @@
 import pytest
+from typing import Optional
 from fastembed import (
     TextEmbedding,
     SparseTextEmbedding,

@@ -13,7 +14,7 @@

 @pytest.mark.skip(reason="Requires a multi-gpu server")
 @pytest.mark.parametrize("device_id", [None, 0, 1])
-def test_gpu_via_providers(device_id):
+def test_gpu_via_providers(device_id: Optional[int]):
     docs = ["hello world", "flag embedding"]

     device_id = device_id if device_id is not None else 0

@@ -85,7 +86,7 @@ def test_gpu_via_providers(device_id):

 @pytest.mark.skip(reason="Requires a multi-gpu server")
 @pytest.mark.parametrize("device_ids", [None, [0], [1], [0, 1]])
-def test_gpu_cuda_device_ids(device_ids):
+def test_gpu_cuda_device_ids(device_ids: Optional[list[int]]):
     docs = ["hello world", "flag embedding"]
     device_id = device_ids[0] if device_ids else 0
     embedding_model = TextEmbedding(

@@ -170,7 +171,7 @@ def test_gpu_cuda_device_ids(device_ids):
 @pytest.mark.parametrize(
     "device_ids,parallel", [(None, None), (None, 2), ([1], None), ([1], 1), ([1], 2), ([0, 1], 2)]
 )
-def test_multi_gpu_parallel_inference(device_ids, parallel):
+def test_multi_gpu_parallel_inference(device_ids: Optional[list[int]], parallel: Optional[int]):
     docs = ["hello world", "flag embedding"] * 100
     batch_size = 5
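Optional[list[int]] matches the parametrized values (None or a list of GPU ids), and each test falls back to device 0 when no ids are given. The guard pattern in isolation, with the function name invented for illustration:

    from typing import Optional

    def pick_device(device_ids: Optional[list[int]]) -> int:
        # None and [] both fall back to GPU 0, as in the tests above
        return device_ids[0] if device_ids else 0

    assert pick_device(None) == 0
    assert pick_device([1, 0]) == 1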
8 changes: 4 additions & 4 deletions tests/test_sparse_embeddings.py
@@ -119,7 +119,7 @@ def bm25_instance():
     delete_model_cache(model._model_dir)


-def test_stem_with_stopwords_and_punctuation(bm25_instance):
+def test_stem_with_stopwords_and_punctuation(bm25_instance: Bm25):
     # Setup
     bm25_instance.stopwords = {"the", "is", "a"}
     bm25_instance.punctuation = {".", ",", "!"}

@@ -135,7 +135,7 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance):
     assert result == expected, f"Expected {expected}, but got {result}"


-def test_stem_case_insensitive_stopwords(bm25_instance):
+def test_stem_case_insensitive_stopwords(bm25_instance: Bm25):
     # Setup
     bm25_instance.stopwords = {"the", "is", "a"}
     bm25_instance.punctuation = {".", ",", "!"}

@@ -152,7 +152,7 @@ def test_stem_case_insensitive_stopwords(bm25_instance):

 @pytest.mark.parametrize("disable_stemmer", [True, False])
-def test_disable_stemmer_behavior(disable_stemmer):
+def test_disable_stemmer_behavior(disable_stemmer: bool):
     # Setup
     model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
     model.stopwords = {"the", "is", "a"}

@@ -176,7 +176,7 @@ def test_disable_stemmer_behavior(disable_stemmer):
     "model_name",
     ["prithivida/Splade_PP_en_v1"],
 )
-def test_lazy_load(model_name):
+def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
     model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
     assert not hasattr(model.model, "model")
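Unlike the parametrize cases, bm25_instance is a pytest fixture, so the annotation names the fixture's return type. A hedged sketch of the pattern; the import path is assumed and the fixture body is invented, since the real one sits outside this hunk:

    import pytest
    from fastembed.sparse.bm25 import Bm25  # import path assumed

    @pytest.fixture
    def bm25_instance() -> Bm25:
        return Bm25("Qdrant/bm25", language="english")

    def test_uses_fixture(bm25_instance: Bm25) -> None:
        # the parameter annotation mirrors the fixture's return type
        assert isinstance(bm25_instance, Bm25)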
8 changes: 4 additions & 4 deletions tests/test_text_cross_encoder.py
@@ -26,7 +26,7 @@
     "model_name",
     [model_name for model_name in CANONICAL_SCORE_VALUES],
 )
-def test_rerank(model_name):
+def test_rerank(model_name: str):
     is_ci = os.getenv("CI")

     model = TextCrossEncoder(model_name=model_name)

@@ -53,7 +53,7 @@ def test_rerank(model_name):
     "model_name",
     [model_name for model_name in SELECTED_MODELS.values()],
 )
-def test_batch_rerank(model_name):
+def test_batch_rerank(model_name: str):
     is_ci = os.getenv("CI")

     model = TextCrossEncoder(model_name=model_name)

@@ -82,7 +82,7 @@ def test_batch_rerank(model_name):
     "model_name",
     ["Xenova/ms-marco-MiniLM-L-6-v2"],
 )
-def test_lazy_load(model_name):
+def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
     model = TextCrossEncoder(model_name=model_name, lazy_load=True)
     assert not hasattr(model.model, "model")

@@ -99,7 +99,7 @@ def test_lazy_load(model_name):
     "model_name",
     [model_name for model_name in SELECTED_MODELS.values()],
 )
-def test_rerank_pairs_parallel(model_name):
+def test_rerank_pairs_parallel(model_name: str):
     is_ci = os.getenv("CI")

     model = TextCrossEncoder(model_name=model_name)
2 changes: 1 addition & 1 deletion tests/test_text_multitask_embeddings.py
@@ -241,7 +241,7 @@ def test_task_assignment():
     "model_name",
     ["jinaai/jina-embeddings-v3"],
 )
-def test_lazy_load(model_name):
+def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
     model = TextEmbedding(model_name=model_name, lazy_load=True)
     assert not hasattr(model.model, "model")
6 changes: 3 additions & 3 deletions tests/test_text_onnx_embeddings.py
@@ -104,7 +104,7 @@ def test_embedding():
     "n_dims,model_name",
     [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
 )
-def test_batch_embedding(n_dims, model_name):
+def test_batch_embedding(n_dims: int, model_name: str):
     is_ci = os.getenv("CI")
     model = TextEmbedding(model_name=model_name)

@@ -121,7 +121,7 @@ def test_batch_embedding(n_dims, model_name):
     "n_dims,model_name",
     [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
 )
-def test_parallel_processing(n_dims, model_name):
+def test_parallel_processing(n_dims: int, model_name: str):
     is_ci = os.getenv("CI")
     model = TextEmbedding(model_name=model_name)

@@ -147,7 +147,7 @@ def test_parallel_processing(n_dims, model_name):
     "model_name",
     ["BAAI/bge-small-en-v1.5"],
 )
-def test_lazy_load(model_name):
+def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
     model = TextEmbedding(model_name=model_name, lazy_load=True)
     assert not hasattr(model.model, "model")
9 changes: 7 additions & 2 deletions tests/utils.py
@@ -2,7 +2,8 @@
 import traceback

 from pathlib import Path
-from typing import Union
+from types import TracebackType
+from typing import Union, Callable, Any, Type


 def delete_model_cache(model_dir: Union[str, Path]) -> None:

@@ -16,7 +17,11 @@ def delete_model_cache(model_dir: Union[str, Path]) -> None:
         model_dir (Union[str, Path]): The path to the model cache directory.
     """

-    def on_error(func, path, exc_info):
+    def on_error(
+        func: Callable[..., Any],
+        path: str,
+        exc_info: tuple[Type[BaseException], BaseException, TracebackType],
+    ) -> None:
         print("Failed to remove: ", path)
         print("Exception: ", exc_info)
         traceback.print_exception(*exc_info)
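The three-argument signature matches the onerror callback that shutil.rmtree calls with a sys.exc_info()-style tuple, which is presumably how delete_model_cache wires it (the call site sits outside this hunk). A standalone sketch; note that Python 3.12 deprecates onerror in favor of onexc:

    import shutil
    import traceback
    from types import TracebackType
    from typing import Any, Callable, Type

    def on_error(
        func: Callable[..., Any],
        path: str,
        exc_info: tuple[Type[BaseException], BaseException, TracebackType],
    ) -> None:
        # Log the failure and let rmtree keep deleting instead of raising
        print("Failed to remove: ", path)
        traceback.print_exception(*exc_info)

    shutil.rmtree("path/to/model_cache", onerror=on_error)  # directory invented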
