Skip to content

Commit

Permalink
Upgrade FastEmbed Version (#493)
Browse the repository at this point in the history
* Update fastembed to v0.2.1

* chore(qdrant_fastembed.py): update DEFAULT_EMBEDDING_MODEL

* fix(fastembed integration): upgrade to latest version

* Prefer black over ruff

* Prefer black over ruff

* Remove hardcoded directory structure from Qdrant Client checks

* new: deprecate current default model, deprecate max token length, update fastembed

* fix: make embedding_model_name method sync

* fix: update poetry lock

* refactor: use list_supported_models() (#501)

* fix: fix fastembed check

* fix: fix fastembed class var assignment

* fix: remove fastembed deprecation from qdrant client (#524)

---------

Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
Co-authored-by: Anush <anushshetty90@gmail.com>
  • Loading branch information
3 people committed Mar 5, 2024
1 parent 7365f83 commit cb0aa80
Show file tree
Hide file tree
Showing 8 changed files with 414 additions and 163 deletions.
363 changes: 303 additions & 60 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ grpcio-tools = ">=1.41.0"
urllib3 = ">=1.26.14,<3"
portalocker = "^2.7.0"
fastembed = [
{ version = "0.1.1", optional = true, python = "<3.12" }
{ version = "0.2.2", optional = true, python = "<3.13" }
]

[tool.poetry.group.dev.dependencies]
Expand Down
8 changes: 0 additions & 8 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,6 @@ def __init__(
grpc_options=grpc_options,
**kwargs,
)
self._is_fastembed_installed: Optional[bool] = None
if self._is_fastembed_installed is None:
try:
from fastembed.embedding import DefaultEmbedding

self._is_fastembed_installed = True
except ImportError:
self._is_fastembed_installed = False

async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None:
"""Closes the connection to Qdrant
Expand Down
85 changes: 48 additions & 37 deletions qdrant_client/async_qdrant_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# ****** WARNING: THIS FILE IS AUTOGENERATED ******

import uuid
import warnings
from itertools import tee
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

Expand All @@ -19,31 +20,44 @@
from qdrant_client.http import models

try:
from fastembed.embedding import DefaultEmbedding
from fastembed import TextEmbedding
except ImportError:
DefaultEmbedding = None
SUPPORTED_EMBEDDING_MODELS: Dict[str, Tuple[int, models.Distance]] = {
"BAAI/bge-base-en": (768, models.Distance.COSINE),
"sentence-transformers/all-MiniLM-L6-v2": (384, models.Distance.COSINE),
"BAAI/bge-small-en": (384, models.Distance.COSINE),
"BAAI/bge-small-en-v1.5": (384, models.Distance.COSINE),
"BAAI/bge-base-en-v1.5": (768, models.Distance.COSINE),
"intfloat/multilingual-e5-large": (1024, models.Distance.COSINE),
}
TextEmbedding = None
SUPPORTED_EMBEDDING_MODELS: Dict[str, Tuple[int, models.Distance]] = (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in TextEmbedding.list_supported_models()
}
if TextEmbedding
else {}
)


class AsyncQdrantFastembedMixin(AsyncQdrantBase):
DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en"
embedding_models: Dict[str, "DefaultEmbedding"] = {}
embedding_models: Dict[str, "TextEmbedding"] = {}
_FASTEMBED_INSTALLED: bool

def __init__(self, **kwargs: Any):
self.embedding_model_name = self.DEFAULT_EMBEDDING_MODEL
self._embedding_model_name: Optional[str] = None
try:
from fastembed import TextEmbedding

self.__class__._FASTEMBED_INSTALLED = True
except ImportError:
self.__class__._FASTEMBED_INSTALLED = False
super().__init__(**kwargs)

@property
def embedding_model_name(self) -> str:
if self._embedding_model_name is None:
self._embedding_model_name = self.DEFAULT_EMBEDDING_MODEL
return self._embedding_model_name

def set_model(
self,
embedding_model_name: str,
max_length: int = 512,
max_length: Optional[int] = None,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs: Any,
Expand All @@ -52,7 +66,7 @@ def set_model(
Set embedding model to use for encoding documents and queries.
Args:
embedding_model_name: One of the supported embedding models. See `SUPPORTED_EMBEDDING_MODELS` for details.
max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
max_length (int, optional): Deprecated and ignored; fastembed models do not use it. Defaults to None.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
Expand All @@ -64,26 +78,28 @@ def set_model(
Returns:
None
"""
if max_length is not None:
warnings.warn(
"max_length parameter is deprecated and will be removed in the future. It's not used by fastembed models.",
DeprecationWarning,
stacklevel=2,
)
self._get_or_init_model(
model_name=embedding_model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
**kwargs,
model_name=embedding_model_name, cache_dir=cache_dir, threads=threads, **kwargs
)
self.embedding_model_name = embedding_model_name
self._embedding_model_name = embedding_model_name

@staticmethod
def _import_fastembed() -> None:
try:
from fastembed.embedding import DefaultEmbedding
except ImportError:
raise ImportError(
"fastembed is not installed. Please install it to enable fast vector indexing with `pip install fastembed`."
)
@classmethod
def _import_fastembed(cls) -> None:
if cls._FASTEMBED_INSTALLED:
return
raise ImportError(
"fastembed is not installed. Please install it to enable fast vector indexing with `pip install fastembed`."
)

@classmethod
def _get_model_params(cls, model_name: str) -> Tuple[int, models.Distance]:
cls._import_fastembed()
if model_name not in SUPPORTED_EMBEDDING_MODELS:
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}"
Expand All @@ -94,24 +110,19 @@ def _get_model_params(cls, model_name: str) -> Tuple[int, models.Distance]:
def _get_or_init_model(
cls,
model_name: str,
max_length: int = 512,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs: Any,
) -> "DefaultEmbedding":
) -> "TextEmbedding":
if model_name in cls.embedding_models:
return cls.embedding_models[model_name]
cls._import_fastembed()
if model_name not in SUPPORTED_EMBEDDING_MODELS:
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}"
)
cls._import_fastembed()
cls.embedding_models[model_name] = DefaultEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
**kwargs,
cls.embedding_models[model_name] = TextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs
)
return cls.embedding_models[model_name]

Expand Down
9 changes: 0 additions & 9 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,6 @@ def __init__(
grpc_options=grpc_options,
**kwargs,
)
self._is_fastembed_installed: Optional[bool] = None
# if fastembed is installed, set to true else False
if self._is_fastembed_installed is None:
try:
from fastembed.embedding import DefaultEmbedding # noqa: F401

self._is_fastembed_installed = True
except ImportError:
self._is_fastembed_installed = False

def __del__(self) -> None:
self.close()
Expand Down
89 changes: 57 additions & 32 deletions qdrant_client/qdrant_fastembed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
import warnings
from itertools import tee
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

Expand All @@ -8,33 +9,49 @@
from qdrant_client.http import models

try:
from fastembed.embedding import DefaultEmbedding
from fastembed import TextEmbedding
except ImportError:
DefaultEmbedding = None
TextEmbedding = None

SUPPORTED_EMBEDDING_MODELS: Dict[str, Tuple[int, models.Distance]] = {
"BAAI/bge-base-en": (768, models.Distance.COSINE),
"sentence-transformers/all-MiniLM-L6-v2": (384, models.Distance.COSINE),
"BAAI/bge-small-en": (384, models.Distance.COSINE),
"BAAI/bge-small-en-v1.5": (384, models.Distance.COSINE),
"BAAI/bge-base-en-v1.5": (768, models.Distance.COSINE),
"intfloat/multilingual-e5-large": (1024, models.Distance.COSINE),
}

SUPPORTED_EMBEDDING_MODELS: Dict[str, Tuple[int, models.Distance]] = (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in TextEmbedding.list_supported_models()
}
if TextEmbedding
else {}
)


class QdrantFastembedMixin(QdrantBase):
DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en"

embedding_models: Dict[str, "DefaultEmbedding"] = {}
embedding_models: Dict[str, "TextEmbedding"] = {}

_FASTEMBED_INSTALLED: bool

def __init__(self, **kwargs: Any):
self.embedding_model_name = self.DEFAULT_EMBEDDING_MODEL
self._embedding_model_name: Optional[str] = None
try:
from fastembed import TextEmbedding # noqa: F401

self.__class__._FASTEMBED_INSTALLED = True
except ImportError:
self.__class__._FASTEMBED_INSTALLED = False

super().__init__(**kwargs)

@property
def embedding_model_name(self) -> str:
if self._embedding_model_name is None:
self._embedding_model_name = self.DEFAULT_EMBEDDING_MODEL
return self._embedding_model_name

def set_model(
self,
embedding_model_name: str,
max_length: int = 512,
max_length: Optional[int] = None,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs: Any,
Expand All @@ -43,7 +60,7 @@ def set_model(
Set embedding model to use for encoding documents and queries.
Args:
embedding_model_name: One of the supported embedding models. See `SUPPORTED_EMBEDDING_MODELS` for details.
max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
max_length (int, optional): Deprecated and ignored; fastembed models do not use it. Defaults to None.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
Expand All @@ -55,28 +72,38 @@ def set_model(
Returns:
None
"""

if max_length is not None:
warnings.warn(
"max_length parameter is deprecated and will be removed in the future. "
"It's not used by fastembed models.",
DeprecationWarning,
stacklevel=2,
)

self._get_or_init_model(
model_name=embedding_model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
**kwargs,
)
self.embedding_model_name = embedding_model_name
self._embedding_model_name = embedding_model_name

@staticmethod
def _import_fastembed() -> None:
try:
from fastembed.embedding import DefaultEmbedding
except ImportError:
# If it's not, ask the user to install it
raise ImportError(
"fastembed is not installed."
" Please install it to enable fast vector indexing with `pip install fastembed`."
)
@classmethod
def _import_fastembed(cls) -> None:
if cls._FASTEMBED_INSTALLED:
return

# If it's not, ask the user to install it
raise ImportError(
"fastembed is not installed."
" Please install it to enable fast vector indexing with `pip install fastembed`."
)

@classmethod
def _get_model_params(cls, model_name: str) -> Tuple[int, models.Distance]:
cls._import_fastembed()

if model_name not in SUPPORTED_EMBEDDING_MODELS:
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}"
Expand All @@ -88,24 +115,22 @@ def _get_model_params(cls, model_name: str) -> Tuple[int, models.Distance]:
def _get_or_init_model(
cls,
model_name: str,
max_length: int = 512,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs: Any,
) -> "DefaultEmbedding": # -> Embedding: # noqa: F821
) -> "TextEmbedding":
if model_name in cls.embedding_models:
return cls.embedding_models[model_name]

cls._import_fastembed()

if model_name not in SUPPORTED_EMBEDDING_MODELS:
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}"
)

cls._import_fastembed()

cls.embedding_models[model_name] = DefaultEmbedding(
cls.embedding_models[model_name] = TextEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
**kwargs,
Expand Down
Loading

0 comments on commit cb0aa80

Please sign in to comment.