diff --git a/fastembed/text/mini_lm_embedding.py b/fastembed/text/mini_lm_embedding.py
deleted file mode 100644
index f12c1d81..00000000
--- a/fastembed/text/mini_lm_embedding.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from typing import Any, Dict, Iterable, List, Type
-
-import numpy as np
-
-from fastembed.common.onnx_model import OnnxOutputContext
-from fastembed.common.utils import normalize
-from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
-from fastembed.text.onnx_text_model import TextEmbeddingWorker
-
-supported_mini_lm_models = [
-    {
-        "model": "sentence-transformers/all-MiniLM-L6-v2",
-        "dim": 384,
-        "description": "Sentence Transformer model, MiniLM-L6-v2",
-        "size_in_GB": 0.09,
-        "sources": {
-            "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
-            "hf": "qdrant/all-MiniLM-L6-v2-onnx",
-        },
-        "model_file": "model.onnx",
-    }
-]
-
-
-class MiniLMOnnxEmbedding(OnnxTextEmbedding):
-    @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
-        return MiniLMEmbeddingWorker
-
-    @classmethod
-    def mean_pooling(cls, model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
-        token_embeddings = model_output
-        input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
-        input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
-        input_mask_expanded = input_mask_expanded.astype(float)
-        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
-        sum_mask = np.sum(input_mask_expanded, axis=1)
-        pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
-        return pooled_embeddings
-
-    @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Any]]:
-        """Lists the supported models.
-
-        Returns:
-            List[Dict[str, Any]]: A list of dictionaries containing the model information.
-        """
-        return supported_mini_lm_models
-
-    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
-        embeddings = output.model_output
-        attn_mask = output.attention_mask
-        return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
-
-
-class MiniLMEmbeddingWorker(OnnxTextEmbeddingWorker):
-    def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxTextEmbedding:
-        return MiniLMOnnxEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index d5cf4569..90002c62 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -80,36 +80,6 @@
         },
         "model_file": "model_optimized.onnx",
     },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1",
-        "dim": 768,
-        "description": "8192 context length english model",
-        "size_in_GB": 0.52,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1",
-        },
-        "model_file": "onnx/model.onnx",
-    },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1.5",
-        "dim": 768,
-        "description": "8192 context length english model",
-        "size_in_GB": 0.52,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1.5",
-        },
-        "model_file": "onnx/model.onnx",
-    },
-    {
-        "model": "nomic-ai/nomic-embed-text-v1.5-Q",
-        "dim": 768,
-        "description": "Quantized 8192 context length english model",
-        "size_in_GB": 0.13,
-        "sources": {
-            "hf": "nomic-ai/nomic-embed-text-v1.5",
-        },
-        "model_file": "onnx/model_quantized.onnx",
-    },
     {
         "model": "thenlper/gte-large",
         "dim": 1024,
@@ -274,7 +244,9 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
-    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext
+    ) -> Iterable[np.ndarray]:
         embeddings = output.model_output
         return normalize(embeddings[:, 0]).astype(np.float32)
 
@@ -286,4 +258,6 @@ def init_embedding(
         cache_dir: str,
         **kwargs,
     ) -> OnnxTextEmbedding:
-        return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
+        return OnnxTextEmbedding(
+            model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
+        )
diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py
new file mode 100644
index 00000000..881fe4b0
--- /dev/null
+++ b/fastembed/text/pooled_embedding.py
@@ -0,0 +1,87 @@
+from typing import Any, Dict, Iterable, List, Type
+
+import numpy as np
+
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.utils import normalize
+from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+from fastembed.text.onnx_text_model import TextEmbeddingWorker
+
+supported_pooled_models = [
+    {
+        "model": "nomic-ai/nomic-embed-text-v1.5",
+        "dim": 768,
+        "description": "8192 context length english model",
+        "size_in_GB": 0.52,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1.5",
+        },
+        "model_file": "onnx/model.onnx",
+    },
+    {
+        "model": "nomic-ai/nomic-embed-text-v1.5-Q",
+        "dim": 768,
+        "description": "Quantized 8192 context length english model",
+        "size_in_GB": 0.13,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1.5",
+        },
+        "model_file": "onnx/model_quantized.onnx",
+    },
+    {
+        "model": "nomic-ai/nomic-embed-text-v1",
+        "dim": 768,
+        "description": "8192 context length english model",
+        "size_in_GB": 0.52,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1",
+        },
+        "model_file": "onnx/model.onnx",
+    },
+]
+
+
+class PooledEmbedding(OnnxTextEmbedding):
+    @classmethod
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+        return PooledEmbeddingWorker
+
+    @classmethod
+    def mean_pooling(
+        cls, model_output: np.ndarray, attention_mask: np.ndarray
+    ) -> np.ndarray:
+        token_embeddings = model_output
+        input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
+        input_mask_expanded = np.tile(
+            input_mask_expanded, (1, 1, token_embeddings.shape[-1])
+        )
+        input_mask_expanded = input_mask_expanded.astype(float)
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
+        sum_mask = np.sum(input_mask_expanded, axis=1)
+        pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
+        return pooled_embeddings
+
+    @classmethod
+    def list_supported_models(cls) -> List[Dict[str, Any]]:
+        """Lists the supported models.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
+        """
+        return supported_pooled_models
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext
+    ) -> Iterable[np.ndarray]:
+        embeddings = output.model_output
+        attn_mask = output.attention_mask
+        return self.mean_pooling(embeddings, attn_mask).astype(np.float32)
+
+
+class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
+    def init_embedding(
+        self, model_name: str, cache_dir: str, **kwargs
+    ) -> OnnxTextEmbedding:
+        return PooledEmbedding(
+            model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
+        )
diff --git a/fastembed/text/jina_onnx_embedding.py b/fastembed/text/pooled_normalized_embedding.py
similarity index 72%
rename from fastembed/text/jina_onnx_embedding.py
rename to fastembed/text/pooled_normalized_embedding.py
index 9f70fef9..5ac932f8 100644
--- a/fastembed/text/jina_onnx_embedding.py
+++ b/fastembed/text/pooled_normalized_embedding.py
@@ -6,8 +6,20 @@
 from fastembed.common.utils import normalize
 from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
 from fastembed.text.onnx_text_model import TextEmbeddingWorker
+from fastembed.text.pooled_embedding import PooledEmbedding
 
-supported_jina_models = [
+supported_pooled_normalized_models = [
+    {
+        "model": "sentence-transformers/all-MiniLM-L6-v2",
+        "dim": 384,
+        "description": "Sentence Transformer model, MiniLM-L6-v2",
+        "size_in_GB": 0.09,
+        "sources": {
+            "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
+            "hf": "qdrant/all-MiniLM-L6-v2-onnx",
+        },
+        "model_file": "model.onnx",
+    },
     {
         "model": "jinaai/jina-embeddings-v2-base-en",
         "dim": 768,
@@ -35,20 +47,10 @@
 ]
 
 
-class JinaOnnxEmbedding(OnnxTextEmbedding):
+class PooledNormalizedEmbedding(PooledEmbedding):
     @classmethod
     def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
-        return JinaEmbeddingWorker
-
-    @classmethod
-    def mean_pooling(cls, model_output, attention_mask) -> np.ndarray:
-        token_embeddings = model_output
-        input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)
-
-        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
-        mask_sum = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
-
-        return sum_embeddings / mask_sum
+        return PooledNormalizedEmbeddingWorker
 
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Any]]:
@@ -57,7 +59,7 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
         Returns:
             List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return supported_jina_models
+        return supported_pooled_normalized_models
 
     def _post_process_onnx_output(
         self, output: OnnxOutputContext
@@ -67,10 +69,10 @@ def _post_process_onnx_output(
         return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
 
 
-class JinaEmbeddingWorker(OnnxTextEmbeddingWorker):
+class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
     def init_embedding(
         self, model_name: str, cache_dir: str, **kwargs
     ) -> OnnxTextEmbedding:
-        return JinaOnnxEmbedding(
+        return PooledNormalizedEmbedding(
             model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
         )
diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index 02d1b52e..3141043c 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -5,8 +5,8 @@
 from fastembed.common import OnnxProvider
 from fastembed.text.clip_embedding import CLIPOnnxEmbedding
 from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
-from fastembed.text.jina_onnx_embedding import JinaOnnxEmbedding
-from fastembed.text.mini_lm_embedding import MiniLMOnnxEmbedding
+from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
+from fastembed.text.pooled_embedding import PooledEmbedding
 from fastembed.text.onnx_embedding import OnnxTextEmbedding
 from fastembed.text.text_embedding_base import TextEmbeddingBase
 
@@ -15,9 +15,9 @@ class TextEmbedding(TextEmbeddingBase):
     EMBEDDINGS_REGISTRY: List[Type[TextEmbeddingBase]] = [
         OnnxTextEmbedding,
         E5OnnxEmbedding,
-        JinaOnnxEmbedding,
         CLIPOnnxEmbedding,
-        MiniLMOnnxEmbedding,
+        PooledNormalizedEmbedding,
+        PooledEmbedding,
     ]
 
     @classmethod
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index 7ca4005c..945ad265 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -45,16 +45,16 @@
         [-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]
     ),
     "jinaai/jina-embeddings-v2-base-de": np.array(
-        [-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
+        [-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
     ),
     "nomic-ai/nomic-embed-text-v1": np.array(
-        [0.0061, 0.0103, -0.0296, -0.0242, -0.0170]
+        [0.3708, 0.2031, -0.3406, -0.2114, -0.3230]
     ),
     "nomic-ai/nomic-embed-text-v1.5": np.array(
-        [-1.6531514e-02, 8.5380634e-05, -1.8171231e-01, -3.9333291e-03, 1.2763254e-02]
+        [-0.15407836, -0.03053198, -3.9138033, 0.1910364, 0.13224715]
     ),
     "nomic-ai/nomic-embed-text-v1.5-Q": np.array(
-        [-0.01554983, 0.0129992, -0.17909265, -0.01062993, 0.00512859]
+        [-0.12525563, 0.38030425, -3.961622, 0.04176439, -0.0758301]
     ),
     "thenlper/gte-large": np.array(
         [-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]