Skip to content

Commit

Permalink
chore: Add missing type hints in functions (#453)
Browse files Browse the repository at this point in the history
* chore: Add missing type hints in functions

* add missing import, small type refactor

---------

Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
  • Loading branch information
hh-space-invader and joein authored Jan 29, 2025
1 parent 73e1e5e commit 993dcd5
Show file tree
Hide file tree
Showing 13 changed files with 41 additions and 33 deletions.
2 changes: 1 addition & 1 deletion fastembed/common/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import shutil
import tarfile
from pathlib import Path
from typing import Any
from typing import Any, Optional

import requests
from huggingface_hub import snapshot_download, model_info, list_repo_tree
Expand Down
4 changes: 2 additions & 2 deletions fastembed/common/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
from pathlib import Path
import sys
from PIL import Image
from typing import Any, Iterable, Union
Expand All @@ -9,7 +9,7 @@
from typing_extensions import TypeAlias


PathInput: TypeAlias = Union[str, os.PathLike]
PathInput: TypeAlias = Union[str, Path]
PilInput: TypeAlias = Union[Image.Image, Iterable[Image.Image]]
ImageInput: TypeAlias = Union[PathInput, Iterable[PathInput], PilInput]

Expand Down
8 changes: 5 additions & 3 deletions fastembed/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@
import unicodedata
from pathlib import Path
from itertools import islice
from typing import Generator, Iterable, Optional, Union
from typing import Iterable, Optional, TypeVar

import numpy as np

T = TypeVar("T")

def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:

def normalize(input_array: np.ndarray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> np.ndarray:
    """Scale `input_array` to unit Lp norm along axis `dim`.

    Args:
        input_array: Array to normalize.
        p: Order of the norm, forwarded to `np.linalg.norm` as `ord`.
        dim: Axis along which the norm is computed.
        eps: Lower bound applied to the norm so all-zero slices do not
            trigger a division by zero.

    Returns:
        `input_array` divided by its (clamped) norm, broadcast over `dim`.
    """
    # keepdims=True keeps the reduced axis so the norm broadcasts against the input.
    lp_norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
    return input_array / np.maximum(lp_norm, eps)


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
"""
>>> list(iter_batch([1,2,3,4,5], 3))
[[1, 2, 3], [4, 5]]
Expand Down
2 changes: 1 addition & 1 deletion fastembed/image/transform/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def resize(
return image.resize(new_size, resample)


def rescale(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
def rescale(image: np.ndarray, scale: float, dtype: type = np.float32) -> np.ndarray:
    """Multiply every element of `image` by `scale` and cast the result to `dtype`."""
    scaled = image * scale
    return scaled.astype(dtype)


Expand Down
10 changes: 5 additions & 5 deletions tests/test_attention_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
def test_attention_embeddings(model_name) -> None:
def test_attention_embeddings(model_name: str) -> None:
is_ci = os.getenv("CI")
model = SparseTextEmbedding(model_name=model_name)

Expand Down Expand Up @@ -71,7 +71,7 @@ def test_attention_embeddings(model_name) -> None:


@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
def test_parallel_processing(model_name) -> None:
def test_parallel_processing(model_name: str) -> None:
is_ci = os.getenv("CI")

model = SparseTextEmbedding(model_name=model_name)
Expand All @@ -96,7 +96,7 @@ def test_parallel_processing(model_name) -> None:


@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
def test_multilanguage(model_name) -> None:
def test_multilanguage(model_name: str) -> None:
is_ci = os.getenv("CI")

docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"]
Expand All @@ -122,7 +122,7 @@ def test_multilanguage(model_name) -> None:


@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
def test_special_characters(model_name) -> None:
def test_special_characters(model_name: str) -> None:
is_ci = os.getenv("CI")

docs = [
Expand All @@ -145,7 +145,7 @@ def test_special_characters(model_name) -> None:


@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions"])
def test_lazy_load(model_name) -> None:
def test_lazy_load(model_name: str) -> None:
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")
docs = ["hello world", "flag embedding"]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_image_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_embedding() -> None:


@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
def test_batch_embedding(n_dims, model_name) -> None:
def test_batch_embedding(n_dims: int, model_name: str) -> None:
is_ci = os.getenv("CI")
model = ImageEmbedding(model_name=model_name)
n_images = 32
Expand All @@ -81,7 +81,7 @@ def test_batch_embedding(n_dims, model_name) -> None:


@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
def test_parallel_processing(n_dims, model_name) -> None:
def test_parallel_processing(n_dims: int, model_name: str) -> None:
is_ci = os.getenv("CI")
model = ImageEmbedding(model_name=model_name)

Expand Down Expand Up @@ -109,7 +109,7 @@ def test_parallel_processing(n_dims, model_name) -> None:


@pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
def test_lazy_load(model_name) -> None:
def test_lazy_load(model_name: str) -> None:
is_ci = os.getenv("CI")
model = ImageEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_late_interaction_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_parallel_processing():
"model_name",
["colbert-ir/colbertv2.0"],
)
def test_lazy_load(model_name):
def test_lazy_load(model_name: str):
is_ci = os.getenv("CI")

model = LateInteractionTextEmbedding(model_name=model_name, lazy_load=True)
Expand Down
7 changes: 4 additions & 3 deletions tests/test_multi_gpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from typing import Optional
from fastembed import (
TextEmbedding,
SparseTextEmbedding,
Expand All @@ -13,7 +14,7 @@

@pytest.mark.skip(reason="Requires a multi-gpu server")
@pytest.mark.parametrize("device_id", [None, 0, 1])
def test_gpu_via_providers(device_id) -> None:
def test_gpu_via_providers(device_id: Optional[int]) -> None:
docs = ["hello world", "flag embedding"]

device_id = device_id if device_id is not None else 0
Expand Down Expand Up @@ -85,7 +86,7 @@ def test_gpu_via_providers(device_id) -> None:

@pytest.mark.skip(reason="Requires a multi-gpu server")
@pytest.mark.parametrize("device_ids", [None, [0], [1], [0, 1]])
def test_gpu_cuda_device_ids(device_ids) -> None:
def test_gpu_cuda_device_ids(device_ids: Optional[list[int]]) -> None:
docs = ["hello world", "flag embedding"]
device_id = device_ids[0] if device_ids else 0
embedding_model = TextEmbedding(
Expand Down Expand Up @@ -170,7 +171,7 @@ def test_gpu_cuda_device_ids(device_ids) -> None:
@pytest.mark.parametrize(
"device_ids,parallel", [(None, None), (None, 2), ([1], None), ([1], 1), ([1], 2), ([0, 1], 2)]
)
def test_multi_gpu_parallel_inference(device_ids, parallel) -> None:
def test_multi_gpu_parallel_inference(device_ids: Optional[list[int]], parallel: int) -> None:
docs = ["hello world", "flag embedding"] * 100
batch_size = 5

Expand Down
8 changes: 4 additions & 4 deletions tests/test_sparse_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def bm25_instance() -> None:
delete_model_cache(model._model_dir)


def test_stem_with_stopwords_and_punctuation(bm25_instance) -> None:
def test_stem_with_stopwords_and_punctuation(bm25_instance: Bm25) -> None:
# Setup
bm25_instance.stopwords = {"the", "is", "a"}
bm25_instance.punctuation = {".", ",", "!"}
Expand All @@ -135,7 +135,7 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance) -> None:
assert result == expected, f"Expected {expected}, but got {result}"


def test_stem_case_insensitive_stopwords(bm25_instance) -> None:
def test_stem_case_insensitive_stopwords(bm25_instance: Bm25) -> None:
# Setup
bm25_instance.stopwords = {"the", "is", "a"}
bm25_instance.punctuation = {".", ",", "!"}
Expand All @@ -152,7 +152,7 @@ def test_stem_case_insensitive_stopwords(bm25_instance) -> None:


@pytest.mark.parametrize("disable_stemmer", [True, False])
def test_disable_stemmer_behavior(disable_stemmer) -> None:
def test_disable_stemmer_behavior(disable_stemmer: bool) -> None:
# Setup
model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
model.stopwords = {"the", "is", "a"}
Expand All @@ -176,7 +176,7 @@ def test_disable_stemmer_behavior(disable_stemmer) -> None:
"model_name",
["prithivida/Splade_PP_en_v1"],
)
def test_lazy_load(model_name) -> None:
def test_lazy_load(model_name: str) -> None:
is_ci = os.getenv("CI")
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_text_cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"model_name",
[model_name for model_name in CANONICAL_SCORE_VALUES],
)
def test_rerank(model_name) -> None:
def test_rerank(model_name: str) -> None:
is_ci = os.getenv("CI")

model = TextCrossEncoder(model_name=model_name)
Expand All @@ -53,7 +53,7 @@ def test_rerank(model_name) -> None:
"model_name",
[model_name for model_name in SELECTED_MODELS.values()],
)
def test_batch_rerank(model_name) -> None:
def test_batch_rerank(model_name: str) -> None:
is_ci = os.getenv("CI")

model = TextCrossEncoder(model_name=model_name)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_batch_rerank(model_name) -> None:
"model_name",
["Xenova/ms-marco-MiniLM-L-6-v2"],
)
def test_lazy_load(model_name) -> None:
def test_lazy_load(model_name: str) -> None:
is_ci = os.getenv("CI")
model = TextCrossEncoder(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")
Expand All @@ -99,7 +99,7 @@ def test_lazy_load(model_name) -> None:
"model_name",
[model_name for model_name in SELECTED_MODELS.values()],
)
def test_rerank_pairs_parallel(model_name) -> None:
def test_rerank_pairs_parallel(model_name: str) -> None:
is_ci = os.getenv("CI")

model = TextCrossEncoder(model_name=model_name)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_text_multitask_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_task_assignment():
"model_name",
["jinaai/jina-embeddings-v3"],
)
def test_lazy_load(model_name):
def test_lazy_load(model_name: str):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_embedding() -> None:
"n_dims,model_name",
[(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
)
def test_batch_embedding(n_dims, model_name) -> None:
def test_batch_embedding(n_dims: int, model_name: str) -> None:
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name)

Expand All @@ -121,7 +121,7 @@ def test_batch_embedding(n_dims, model_name) -> None:
"n_dims,model_name",
[(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
)
def test_parallel_processing(n_dims, model_name) -> None:
def test_parallel_processing(n_dims: int, model_name: str) -> None:
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name)

Expand All @@ -147,7 +147,7 @@ def test_parallel_processing(n_dims, model_name) -> None:
"model_name",
["BAAI/bge-small-en-v1.5"],
)
def test_lazy_load(model_name) -> None:
def test_lazy_load(model_name: str) -> None:
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")
Expand Down
9 changes: 7 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import traceback

from pathlib import Path
from typing import Union
from types import TracebackType
from typing import Union, Callable, Any, Type


def delete_model_cache(model_dir: Union[str, Path]) -> None:
Expand All @@ -16,7 +17,11 @@ def delete_model_cache(model_dir: Union[str, Path]) -> None:
model_dir (Union[str, Path]): The path to the model cache directory.
"""

def on_error(func, path, exc_info) -> None:
def on_error(
func: Callable[..., Any],
path: str,
exc_info: tuple[Type[BaseException], BaseException, TracebackType],
) -> None:
print("Failed to remove: ", path)
print("Exception: ", exc_info)
traceback.print_exception(*exc_info)
Expand Down

0 comments on commit 993dcd5

Please sign in to comment.