Skip to content

Commit

Permalink
new: preserve embeddings in a type set by their model (#492)
Browse files Browse the repository at this point in the history
* new: preserve embeddings in a type set by their model

* fix: remove type coercion

* fix: remove redundant type

* fix: fix random data type in tests
  • Loading branch information
joein authored Mar 3, 2025
1 parent 42fca3b commit 1729aab
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 19 deletions.
3 changes: 1 addition & 2 deletions fastembed/image/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

from fastembed.common.types import NumpyArray
from fastembed.common import ImageInput, OnnxProvider
Expand Down Expand Up @@ -195,7 +194,7 @@ def _preprocess_onnx_input(
return onnx_input

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
return normalize(output.model_output).astype(np.float32)
return normalize(output.model_output)


class OnnxImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
Expand Down
6 changes: 3 additions & 3 deletions fastembed/late_interaction/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _post_process_onnx_output(
self, output: OnnxOutputContext, is_doc: bool = True
) -> Iterable[NumpyArray]:
if not is_doc:
return output.model_output.astype(np.float32)
return output.model_output

if output.input_ids is None or output.attention_mask is None:
raise ValueError(
Expand All @@ -58,11 +58,11 @@ def _post_process_onnx_output(
if token_id in self.skip_list or token_id == self.pad_token_id:
output.attention_mask[i, j] = 0

output.model_output *= np.expand_dims(output.attention_mask, 2).astype(np.float32)
output.model_output *= np.expand_dims(output.attention_mask, 2)
norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
norm_clamped = np.maximum(norm, 1e-12)
output.model_output /= norm_clamped
return output.model_output.astype(np.float32)
return output.model_output

def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
Expand Down
4 changes: 2 additions & 2 deletions fastembed/late_interaction_multimodal/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _post_process_onnx_image_output(
assert self.model_description.dim is not None, "Model dim is not defined"
return output.model_output.reshape(
output.model_output.shape[0], -1, self.model_description.dim
).astype(np.float32)
)

def _post_process_onnx_text_output(
self,
Expand All @@ -157,7 +157,7 @@ def _post_process_onnx_text_output(
Returns:
Iterable[NumpyArray]: Post-processed output as NumPy arrays.
"""
return output.model_output.astype(np.float32)
return output.model_output

def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
texts_query: list[str] = []
Expand Down
3 changes: 1 addition & 2 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np
from fastembed.common.types import NumpyArray, OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir, normalize
Expand Down Expand Up @@ -313,7 +312,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy
processed_embeddings = embeddings
else:
raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
return normalize(processed_embeddings).astype(np.float32)
return normalize(processed_embeddings)

def load_onnx_model(self) -> None:
self._load_onnx_model(
Expand Down
2 changes: 1 addition & 1 deletion fastembed/text/pooled_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy

embeddings = output.model_output
attn_mask = output.attention_mask
return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
return self.mean_pooling(embeddings, attn_mask)


class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
Expand Down
3 changes: 1 addition & 2 deletions fastembed/text/pooled_normalized_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Iterable, Type

import numpy as np

from fastembed.common.types import NumpyArray
from fastembed.common.onnx_model import OnnxOutputContext
Expand Down Expand Up @@ -145,7 +144,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy

embeddings = output.model_output
attn_mask = output.attention_mask
return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
return normalize(self.mean_pooling(embeddings, attn_mask))


class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
Expand Down
10 changes: 3 additions & 7 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,11 @@ def test_mock_add_custom_models():
expected_output = {
f"{PoolingType.MEAN.lower()}-normalized": normalize(
mean_pooling(dummy_token_embedding, dummy_attention_mask)
).astype(np.float32),
f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype(
np.float32
),
f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask),
f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]),
f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0],
f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype(
np.float32
),
f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding),
f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding,
}

Expand Down

0 comments on commit 1729aab

Please sign in to comment.