Nomic-embeddings-support (#280)
* Nomic-embeddings-support

* Jina models moved to pooled-normalized embeddings

* Canonical vector for nomic-ai/nomic-embed-text-v1.5-Q

* Moved all nomics to pooled_embeddings

---------

Co-authored-by: d.rudenko <dimitriyrudenk@gmail.com>
I8dNLo and d.rudenko authored Jul 10, 2024
1 parent 9387ca3 commit d09af55
Showing 6 changed files with 119 additions and 114 deletions.
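Taken together, the changes route the Nomic models through a new pooled-embedding path while keeping the public entry point unchanged. A minimal usage sketch, assuming fastembed's standard TextEmbedding API (the model name is taken from this diff):

from fastembed import TextEmbedding

# "nomic-ai/nomic-embed-text-v1.5-Q" is one of the entries registered in pooled_embedding.py below.
model = TextEmbedding(model_name="nomic-ai/nomic-embed-text-v1.5-Q")
embeddings = list(model.embed(["FastEmbed is a lightweight embedding library."]))
print(embeddings[0].shape)  # expected (768,), per the "dim" field in the model entry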
58 changes: 0 additions & 58 deletions fastembed/text/mini_lm_embedding.py

This file was deleted.

38 changes: 6 additions & 32 deletions fastembed/text/onnx_embedding.py
@@ -80,36 +80,6 @@
},
"model_file": "model_optimized.onnx",
},
{
"model": "nomic-ai/nomic-embed-text-v1",
"dim": 768,
"description": "8192 context length english model",
"size_in_GB": 0.52,
"sources": {
"hf": "nomic-ai/nomic-embed-text-v1",
},
"model_file": "onnx/model.onnx",
},
{
"model": "nomic-ai/nomic-embed-text-v1.5",
"dim": 768,
"description": "8192 context length english model",
"size_in_GB": 0.52,
"sources": {
"hf": "nomic-ai/nomic-embed-text-v1.5",
},
"model_file": "onnx/model.onnx",
},
{
"model": "nomic-ai/nomic-embed-text-v1.5-Q",
"dim": 768,
"description": "Quantized 8192 context length english model",
"size_in_GB": 0.13,
"sources": {
"hf": "nomic-ai/nomic-embed-text-v1.5",
},
"model_file": "onnx/model_quantized.onnx",
},
{
"model": "thenlper/gte-large",
"dim": 1024,
@@ -274,7 +244,9 @@ def _preprocess_onnx_input(
"""
return onnx_input

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
def _post_process_onnx_output(
self, output: OnnxOutputContext
) -> Iterable[np.ndarray]:
embeddings = output.model_output
return normalize(embeddings[:, 0]).astype(np.float32)

@@ -286,4 +258,6 @@ def init_embedding(
cache_dir: str,
**kwargs,
) -> OnnxTextEmbedding:
return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
return OnnxTextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
)
87 changes: 87 additions & 0 deletions fastembed/text/pooled_embedding.py
@@ -0,0 +1,87 @@
from typing import Any, Dict, Iterable, List, Type

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import normalize
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
from fastembed.text.onnx_text_model import TextEmbeddingWorker

supported_pooled_models = [
{
"model": "nomic-ai/nomic-embed-text-v1.5",
"dim": 768,
"description": "8192 context length english model",
"size_in_GB": 0.52,
"sources": {
"hf": "nomic-ai/nomic-embed-text-v1.5",
},
"model_file": "onnx/model.onnx",
},
{
"model": "nomic-ai/nomic-embed-text-v1.5-Q",
"dim": 768,
"description": "Quantized 8192 context length english model",
"size_in_GB": 0.13,
"sources": {
"hf": "nomic-ai/nomic-embed-text-v1.5",
},
"model_file": "onnx/model_quantized.onnx",
},
{
"model": "nomic-ai/nomic-embed-text-v1",
"dim": 768,
"description": "8192 context length english model",
"size_in_GB": 0.52,
"sources": {
"hf": "nomic-ai/nomic-embed-text-v1",
},
"model_file": "onnx/model.onnx",
},
]


class PooledEmbedding(OnnxTextEmbedding):
@classmethod
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
return PooledEmbeddingWorker

@classmethod
def mean_pooling(
cls, model_output: np.ndarray, attention_mask: np.ndarray
) -> np.ndarray:
token_embeddings = model_output
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
input_mask_expanded = np.tile(
input_mask_expanded, (1, 1, token_embeddings.shape[-1])
)
input_mask_expanded = input_mask_expanded.astype(float)
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
sum_mask = np.sum(input_mask_expanded, axis=1)
pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
return pooled_embeddings

@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
"""Lists the supported models.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_pooled_models

def _post_process_onnx_output(
self, output: OnnxOutputContext
) -> Iterable[np.ndarray]:
embeddings = output.model_output
attn_mask = output.attention_mask
return self.mean_pooling(embeddings, attn_mask).astype(np.float32)


class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
def init_embedding(
self, model_name: str, cache_dir: str, **kwargs
) -> OnnxTextEmbedding:
return PooledEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
)
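The mean_pooling classmethod above averages token embeddings while zeroing out padding positions via the attention mask. A self-contained sketch of the same computation on toy inputs (hypothetical shapes; broadcasting is used instead of np.tile, which is equivalent here):

import numpy as np

# Toy batch: 1 sequence, 3 tokens, embedding dim 2; the last token is padding.
token_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
attention_mask = np.array([[1, 1, 0]])

mask = np.expand_dims(attention_mask, axis=-1).astype(float)  # shape (1, 3, 1)
summed = np.sum(token_embeddings * mask, axis=1)              # shape (1, 2)
pooled = summed / np.maximum(np.sum(mask, axis=1), 1e-9)      # guard against all-zero masks
print(pooled)  # [[2. 3.]], the mean of the two unmasked tokens

Note that PooledEmbedding._post_process_onnx_output returns this pooled vector without normalization, which is why the canonical test vectors for the nomic models change magnitude later in this diff.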
34 changes: 18 additions & 16 deletions fastembed/text/pooled_normalized_embedding.py
@@ -6,8 +6,20 @@
from fastembed.common.utils import normalize
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
from fastembed.text.onnx_text_model import TextEmbeddingWorker
from fastembed.text.pooled_embedding import PooledEmbedding

supported_jina_models = [
supported_pooled_normalized_models = [
{
"model": "sentence-transformers/all-MiniLM-L6-v2",
"dim": 384,
"description": "Sentence Transformer model, MiniLM-L6-v2",
"size_in_GB": 0.09,
"sources": {
"url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
"hf": "qdrant/all-MiniLM-L6-v2-onnx",
},
"model_file": "model.onnx",
},
{
"model": "jinaai/jina-embeddings-v2-base-en",
"dim": 768,
@@ -35,20 +47,10 @@
]


class JinaOnnxEmbedding(OnnxTextEmbedding):
class PooledNormalizedEmbedding(PooledEmbedding):
@classmethod
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
return JinaEmbeddingWorker

@classmethod
def mean_pooling(cls, model_output, attention_mask) -> np.ndarray:
token_embeddings = model_output
input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)

sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
mask_sum = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)

return sum_embeddings / mask_sum
return PooledNormalizedEmbeddingWorker

@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
@@ -57,7 +59,7 @@ def list_supported_models(cls) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_jina_models
return supported_pooled_normalized_models

def _post_process_onnx_output(
self, output: OnnxOutputContext
@@ -67,10 +69,10 @@ def _post_process_onnx_output(
return normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)


class JinaEmbeddingWorker(OnnxTextEmbeddingWorker):
class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
def init_embedding(
self, model_name: str, cache_dir: str, **kwargs
) -> OnnxTextEmbedding:
return JinaOnnxEmbedding(
return PooledNormalizedEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
)
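The renamed subclass keeps the inherited mean pooling and differs only in the final step: the pooled vector is L2-normalized before being returned. A standalone sketch of that last step (assuming fastembed.common.utils.normalize performs row-wise L2 normalization, as its use on a 2D batch suggests):

import numpy as np

def l2_normalize_rows(v: np.ndarray) -> np.ndarray:
    # Scale each row to unit Euclidean length.
    norms = np.linalg.norm(v, axis=1, keepdims=True)
    return v / np.maximum(norms, 1e-12)

pooled = np.array([[3.0, 4.0]])
print(l2_normalize_rows(pooled))  # [[0.6 0.8]]

This split matches the commit message: the Jina models (and MiniLM) land in the pooled-normalized list, while the Nomic models stay on plain pooled embeddings.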
8 changes: 4 additions & 4 deletions fastembed/text/text_embedding.py
@@ -5,8 +5,8 @@
from fastembed.common import OnnxProvider
from fastembed.text.clip_embedding import CLIPOnnxEmbedding
from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
from fastembed.text.jina_onnx_embedding import JinaOnnxEmbedding
from fastembed.text.mini_lm_embedding import MiniLMOnnxEmbedding
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase

@@ -15,9 +15,9 @@ class TextEmbedding(TextEmbeddingBase):
EMBEDDINGS_REGISTRY: List[Type[TextEmbeddingBase]] = [
OnnxTextEmbedding,
E5OnnxEmbedding,
JinaOnnxEmbedding,
CLIPOnnxEmbedding,
MiniLMOnnxEmbedding,
PooledNormalizedEmbedding,
PooledEmbedding,
]

@classmethod
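For reference on how these registry entries take effect: TextEmbedding resolves a requested model name against each registered class's list_supported_models(), so moving a model between the pooled and pooled-normalized lists changes which subclass, and hence which post-processing, handles it. A condensed, hypothetical sketch of that dispatch (the real logic lives in TextEmbedding and may differ in detail):

from fastembed import TextEmbedding

def resolve_embedding_class(model_name: str):
    # Scan the registry in order; the first class that lists the model wins.
    for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY:
        if any(m["model"] == model_name for m in embedding_cls.list_supported_models()):
            return embedding_cls
    raise ValueError(f"Model {model_name} is not supported")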
8 changes: 4 additions & 4 deletions tests/test_text_onnx_embeddings.py
@@ -45,16 +45,16 @@
[-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]
),
"jinaai/jina-embeddings-v2-base-de": np.array(
[-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
[-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
),
"nomic-ai/nomic-embed-text-v1": np.array(
[0.0061, 0.0103, -0.0296, -0.0242, -0.0170]
[0.3708 , 0.2031, -0.3406, -0.2114, -0.3230]
),
"nomic-ai/nomic-embed-text-v1.5": np.array(
[-1.6531514e-02, 8.5380634e-05, -1.8171231e-01, -3.9333291e-03, 1.2763254e-02]
[-0.15407836, -0.03053198, -3.9138033, 0.1910364, 0.13224715]
),
"nomic-ai/nomic-embed-text-v1.5-Q": np.array(
[-0.01554983, 0.0129992, -0.17909265, -0.01062993, 0.00512859]
[-0.12525563, 0.38030425, -3.961622 , 0.04176439, -0.0758301]
),
"thenlper/gte-large": np.array(
[-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]
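The new canonical values for the nomic models are noticeably larger in magnitude, consistent with the switch from normalized CLS-token embeddings (the old normalize(embeddings[:, 0]) path in onnx_embedding.py) to unnormalized mean pooling. A sketch of how such a canonical-vector check typically runs (hypothetical input text and tolerance; the canonical slice is copied from this diff):

import numpy as np
from fastembed import TextEmbedding

canonical_vector = np.array([-0.12525563, 0.38030425, -3.961622, 0.04176439, -0.0758301])

model = TextEmbedding(model_name="nomic-ai/nomic-embed-text-v1.5-Q")
embedding = next(iter(model.embed(["hello world"])))
# Compare only the leading dimensions stored in the canonical slice.
assert np.allclose(embedding[: canonical_vector.size], canonical_vector, atol=1e-3)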
