new: Add sparse type hints (#460)
* new: Add sparse type hints

* fix: ndarray -> numpyarray

---------

Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
hh-space-invader and joein authored Feb 4, 2025
1 parent b08febb commit 37a66d9
Showing 5 changed files with 30 additions and 29 deletions.
12 changes: 6 additions & 6 deletions fastembed/sparse/bm25.py
@@ -123,7 +123,7 @@ def __init__(
         self.avg_len = avg_len
 
         model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             model_description,
@@ -137,7 +137,7 @@ def __init__(
         self.disable_stemmer = disable_stemmer
 
         if disable_stemmer:
-            self.stopwords = set()
+            self.stopwords: set[str] = set()
             self.stemmer = None
         else:
             self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
@@ -239,7 +239,7 @@ def embed(
         )
 
     def _stem(self, tokens: list[str]) -> list[str]:
-        stemmed_tokens = []
+        stemmed_tokens: list[str] = []
         for token in tokens:
             lower_token = token.lower()
 
@@ -262,7 +262,7 @@ def raw_embed(
         self,
         documents: list[str],
     ) -> list[SparseEmbedding]:
-        embeddings = []
+        embeddings: list[SparseEmbedding] = []
         for document in documents:
             document = remove_non_alphanumeric(document)
             tokens = self.tokenizer.tokenize(document)
@@ -286,8 +286,8 @@ def _term_frequency(self, tokens: list[str]) -> dict[int, float]:
         Returns:
             dict[int, float]: The token_id to term frequency mapping.
         """
-        tf_map = {}
-        counter = defaultdict(int)
+        tf_map: dict[int, float] = {}
+        counter: defaultdict[str, int] = defaultdict(int)
         for stemmed_token in tokens:
            counter[stemmed_token] += 1
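
An aside on why these hunks matter: a type checker cannot infer element types for empty containers (`{}`, `set()`, `defaultdict(int)`), which under strict mode forces an explicit annotation. A minimal, self-contained sketch of the annotated term-frequency pattern above (simplified: the real method maps hashed token ids to BM25-weighted frequencies, and `101` below is a hypothetical token id):

```python
from collections import defaultdict

# Without annotations, the key types of these empty containers are
# ambiguous; the hints added in this commit pin them down.
tf_map: dict[int, float] = {}
counter: defaultdict[str, int] = defaultdict(int)

for stemmed_token in ["fast", "embed", "fast"]:
    counter[stemmed_token] += 1

# Map a (hypothetical) token id to its relative frequency.
tf_map[101] = counter["fast"] / 3
print(tf_map)  # {101: 0.666...}
```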
30 changes: 15 additions & 15 deletions fastembed/sparse/bm42.py
@@ -110,7 +110,7 @@ def __init__(
         self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             self.model_description,
@@ -119,10 +119,10 @@ def __init__(
             specific_model_path=specific_model_path,
         )
 
-        self.invert_vocab = {}
+        self.invert_vocab: dict[int, str] = {}
 
-        self.special_tokens = set()
-        self.special_tokens_ids = set()
+        self.special_tokens: set[str] = set()
+        self.special_tokens_ids: set[int] = set()
         self.punctuation = set(string.punctuation)
         self.stopwords = set(self._load_stopwords(self._model_dir))
         self.stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[model_name])
@@ -147,15 +147,15 @@ def load_onnx_model(self) -> None:
         self.stopwords = set(self._load_stopwords(self._model_dir))
 
     def _filter_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             if token in self.stopwords or token in self.punctuation:
                 continue
             result.append((token, value))
         return result
 
     def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             processed_token = self.stemmer.stem_word(token)
             result.append((processed_token, value))
@@ -165,7 +165,7 @@ def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
     def _aggregate_weights(
         cls, tokens: list[tuple[str, list[int]]], weights: list[float]
     ) -> list[tuple[str, float]]:
-        result = []
+        result: list[tuple[str, float]] = []
         for token, idxs in tokens:
             sum_weight = sum(weights[idx] for idx in idxs)
             result.append((token, sum_weight))
@@ -174,9 +174,9 @@ def _aggregate_weights(
     def _reconstruct_bpe(
         self, bpe_tokens: Iterable[tuple[int, str]]
     ) -> list[tuple[str, list[int]]]:
-        result = []
-        acc = ""
-        acc_idx = []
+        result: list[tuple[str, list[int]]] = []
+        acc: str = ""
+        acc_idx: list[int] = []
 
         continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix
         continuing_subword_prefix_len = len(continuing_subword_prefix)
@@ -206,7 +206,7 @@ def _rescore_vector(self, vector: dict[str, float]) -> dict[int, float]:
         So that the scoring doesn't depend on absolute values assigned by the model, but on the relative importance.
         """
 
-        new_vector = {}
+        new_vector: dict[int, float] = {}
 
         for token, value in vector.items():
             token_id = abs(mmh3.hash(token))
@@ -241,7 +241,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[SparseEmbedding]:
 
         weighted = self._aggregate_weights(stemmed, attention_value)
 
-        max_token_weight = {}
+        max_token_weight: dict[str, float] = {}
 
         for token, weight in weighted:
             max_token_weight[token] = max(max_token_weight.get(token, 0), weight)
@@ -304,7 +304,7 @@ def embed(
 
     @classmethod
     def _query_rehash(cls, tokens: Iterable[str]) -> dict[int, float]:
-        result = {}
+        result: dict[int, float] = {}
         for token in tokens:
             token_id = abs(mmh3.hash(token))
             result[token_id] = 1.0
@@ -334,11 +334,11 @@ def query_embed(
         yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))
 
     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
         return Bm42TextEmbeddingWorker
 
 
-class Bm42TextEmbeddingWorker(TextEmbeddingWorker):
+class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
         return Bm42(
             model_name=model_name,
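
For context on the `_query_rehash` hunk above: each distinct query token is hashed to a stable integer id with weight 1.0. A standalone sketch of that logic, lifted from the diff (requires the `mmh3` package; the function name is ours, not the library's public API):

```python
from typing import Iterable

import mmh3


def query_rehash(tokens: Iterable[str]) -> dict[int, float]:
    # Same body as the annotated classmethod: hash each token to a
    # stable 32-bit id and give every query token equal weight.
    result: dict[int, float] = {}
    for token in tokens:
        token_id = abs(mmh3.hash(token))
        result[token_id] = 1.0
    return result


print(query_rehash(["sparse", "embeddings"]))  # {<id>: 1.0, <id>: 1.0}
```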
9 changes: 5 additions & 4 deletions fastembed/sparse/sparse_embedding_base.py
@@ -3,15 +3,16 @@
 
 import numpy as np
 
+from fastembed.common.types import NumpyArray
 from fastembed.common.model_management import ModelManagement
 
 
 @dataclass
 class SparseEmbedding:
-    values: np.ndarray
-    indices: np.ndarray
+    values: NumpyArray
+    indices: NumpyArray
 
-    def as_object(self) -> dict[str, np.ndarray]:
+    def as_object(self) -> dict[str, NumpyArray]:
         return {
             "values": self.values,
             "indices": self.indices,
@@ -81,5 +82,5 @@ def query_embed(
         # This is model-specific, so that different models can have specialized implementations
         if isinstance(query, str):
             yield from self.embed([query], **kwargs)
-        if isinstance(query, Iterable):
+        else:
             yield from self.embed(query, **kwargs)
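
Two things happen in this file: the public dataclass switches from `np.ndarray` to fastembed's `NumpyArray` alias, and the second `isinstance` check becomes an `else` — a `str` is itself `Iterable`, so a string query previously matched both branches and was embedded twice. A quick usage sketch of the dataclass as typed here (values chosen arbitrarily):

```python
import numpy as np

from fastembed.sparse.sparse_embedding_base import SparseEmbedding

# Parallel arrays: indices are token ids, values are their weights.
emb = SparseEmbedding(
    values=np.array([0.4, 1.2]),
    indices=np.array([17, 42]),
)
print(emb.as_object())  # {"values": array([0.4, 1.2]), "indices": array([17, 42])}
```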
2 changes: 1 addition & 1 deletion fastembed/sparse/sparse_text_embedding.py
@@ -38,7 +38,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         ]
         ```
         """
-        result = []
+        result: list[dict[str, Any]] = []
         for embedding in cls.EMBEDDINGS_REGISTRY:
             result.extend(embedding.list_supported_models())
         return result
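
A short usage sketch of the method annotated above; the top-level import and the `"model"` metadata key are assumptions based on fastembed's usual conventions:

```python
from fastembed import SparseTextEmbedding

# Each registry entry is a dict of model metadata, hence list[dict[str, Any]].
for model in SparseTextEmbedding.list_supported_models():
    print(model["model"])
```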
6 changes: 3 additions & 3 deletions fastembed/sparse/splade_pp.py
@@ -114,7 +114,7 @@ def __init__(
         self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             self.model_description,
@@ -171,11 +171,11 @@ def embed(
         )
 
     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
         return SpladePPEmbeddingWorker
 
 
-class SpladePPEmbeddingWorker(TextEmbeddingWorker):
+class SpladePPEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> SpladePP:
         return SpladePP(
             model_name=model_name,
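
The recurring worker change in this commit — `TextEmbeddingWorker` becoming `TextEmbeddingWorker[SparseEmbedding]` — is the standard `typing.Generic` parametrization pattern. A self-contained sketch of why it helps; the class names below are stand-ins, not fastembed internals:

```python
from dataclasses import dataclass
from typing import Generic, Iterable, Type, TypeVar


@dataclass
class SparseEmbedding:  # minimal stand-in for fastembed's dataclass
    values: list[float]
    indices: list[int]


T = TypeVar("T")


class Worker(Generic[T]):
    # Stand-in for TextEmbeddingWorker: T is the embedding type the
    # worker's output stream yields.
    def process(self, batch: list[str]) -> Iterable[T]:
        raise NotImplementedError


class SparseWorker(Worker[SparseEmbedding]):
    def process(self, batch: list[str]) -> Iterable[SparseEmbedding]:
        for doc in batch:
            yield SparseEmbedding(values=[1.0], indices=[len(doc)])


def get_worker_class() -> Type[Worker[SparseEmbedding]]:
    # Mirrors the annotated _get_worker_class: with the parameter, a type
    # checker knows the returned worker yields SparseEmbedding, not Any.
    return SparseWorker
```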
