Skip to content

Commit

Permalink
feat: Expose Setting for GRPC Channel-Level Compression at Client Side (
Browse files Browse the repository at this point in the history
#480)

* expose grpc channel-level compression settings in base functions

* expose grpc channel-level compression settings in remote classes

* expose grpc channel-level compression settings in client

* raise TypeError for compression

* added test cases for grcp channel-level compression

* move grpc_compression parameter from client's signature to **kwargs

* use grpc.Compression instead of creating new enum qdrant.grpc.Compression in qdrant/grpc/__init__.py

* refactor grpc_compression type hint

* fix: Compression instead of grpc.Compression in type hint

* tests: move and update tests

* chore: remove magic method

* fix: fix async client generator, update precommit dependencies

* fix: update isort options

* fix: update dev dependencies

---------

Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
  • Loading branch information
geetu040 and joein authored Mar 1, 2024
1 parent 09d51e6 commit 3ad05b3
Show file tree
Hide file tree
Showing 8 changed files with 597 additions and 569 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ repos:
- id: check-added-large-files

- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.12.1
hooks:
- id: black
name: "Black: The uncompromising Python code formatter"
args: ["--line-length", "99"]

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
name: "Sort Imports"
args: ["--profile", "black"]
args: ["--profile", "black", "--py", "310"]
1,089 changes: 534 additions & 555 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ coverage = "^6.3.3"
pytest-asyncio = "^0.21.0"
pytest-timeout = "^2.1.0"
autoflake = "^2.2.1"
isort = "^5.12.0"
black = "^23.9.1"
isort = "^5.13.0"
black = "^23.12.1"

[tool.poetry.group.docs.dependencies]
sphinx = "^4.5.0"
Expand Down Expand Up @@ -71,3 +71,6 @@ markers = [
"fastembed: marks tests that require the fastembed package (deselect with '-m \"not fastembed\"')",
"no_fastembed: marks tests that do not require the fastembed package (deselect with '-m \"not no_fastembed\"')"
]

[tool.isort]
known_third_party = "grpc"
12 changes: 12 additions & 0 deletions qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import httpx
import numpy as np
from grpc import Compression
from urllib3.util import Url, parse_url

from qdrant_client import grpc as grpc
Expand Down Expand Up @@ -111,6 +112,16 @@ def __init__(
warnings.warn("Api key is used with unsecure connection.")
self._rest_headers["api-key"] = api_key
self._grpc_headers.append(("api-key", api_key))
grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None)
if grpc_compression is not None and (not isinstance(grpc_compression, Compression)):
raise TypeError(
f"Expected 'grpc_compression' to be of type grpc.Compression or None, but got {type(grpc_compression)}"
)
if grpc_compression == Compression.Deflate:
raise ValueError(
"grpc.Compression.Deflate is not supported. Try grpc.Compression.Gzip or grpc.Compression.NoCompression"
)
self._grpc_compression = grpc_compression
address = f"{self._host}:{self._port}" if self._port is not None else self._host
self.rest_uri = f"{self._scheme}://{address}{self._prefix}"
self._rest_args = {"headers": self._rest_headers, "http2": http2, **kwargs}
Expand Down Expand Up @@ -170,6 +181,7 @@ def _init_grpc_channel(self) -> None:
ssl=self._https,
metadata=self._grpc_headers,
options=self._grpc_options,
compression=self._grpc_compression,
)

def _init_grpc_points_client(self) -> None:
Expand Down
14 changes: 8 additions & 6 deletions qdrant_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def get_channel(
ssl: bool,
metadata: Optional[List[Tuple[str, str]]] = None,
options: Optional[Dict[str, Any]] = None,
compression: Optional[grpc.Compression] = None,
) -> grpc.Channel:
# gRPC client options
_options = parse_channel_options(options)
Expand All @@ -223,14 +224,14 @@ def metadata_callback(context: Any, callback: Any) -> None:
creds = grpc.ssl_channel_credentials()

# finally pass in the combined credentials when creating a channel
return grpc.secure_channel(f"{host}:{port}", creds, _options)
return grpc.secure_channel(f"{host}:{port}", creds, _options, compression)
else:
if metadata:
metadata_interceptor = header_adder_interceptor(metadata)
channel = grpc.insecure_channel(f"{host}:{port}", _options)
channel = grpc.insecure_channel(f"{host}:{port}", _options, compression)
return grpc.intercept_channel(channel, metadata_interceptor)
else:
return grpc.insecure_channel(f"{host}:{port}", _options)
return grpc.insecure_channel(f"{host}:{port}", _options, compression)


def get_async_channel(
Expand All @@ -239,6 +240,7 @@ def get_async_channel(
ssl: bool,
metadata: Optional[List[Tuple[str, str]]] = None,
options: Optional[Dict[str, Any]] = None,
compression: Optional[grpc.Compression] = None,
) -> grpc.aio.Channel:
# gRPC client options
_options = parse_channel_options(options)
Expand All @@ -263,12 +265,12 @@ def metadata_callback(context: Any, callback: Any) -> None:
creds = grpc.ssl_channel_credentials()

# finally pass in the combined credentials when creating a channel
return grpc.aio.secure_channel(f"{host}:{port}", creds, _options)
return grpc.aio.secure_channel(f"{host}:{port}", creds, _options, compression)
else:
if metadata:
metadata_interceptor = header_adder_async_interceptor(metadata)
return grpc.aio.insecure_channel(
f"{host}:{port}", _options, interceptors=[metadata_interceptor]
f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor]
)
else:
return grpc.aio.insecure_channel(f"{host}:{port}", _options)
return grpc.aio.insecure_channel(f"{host}:{port}", _options, compression)
16 changes: 16 additions & 0 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import httpx
import numpy as np
from grpc import Compression
from urllib3.util import Url, parse_url

from qdrant_client import grpc as grpc
Expand Down Expand Up @@ -126,6 +127,19 @@ def __init__(
self._rest_headers["api-key"] = api_key
self._grpc_headers.append(("api-key", api_key))

# GRPC Channel-Level Compression
grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None)
if grpc_compression is not None and not isinstance(grpc_compression, Compression):
raise TypeError(
f"Expected 'grpc_compression' to be of type "
f"grpc.Compression or None, but got {type(grpc_compression)}"
)
if grpc_compression == Compression.Deflate:
raise ValueError(
"grpc.Compression.Deflate is not supported. Try grpc.Compression.Gzip or grpc.Compression.NoCompression"
)
self._grpc_compression = grpc_compression

address = f"{self._host}:{self._port}" if self._port is not None else self._host
self.rest_uri = f"{self._scheme}://{address}{self._prefix}"

Expand Down Expand Up @@ -206,6 +220,7 @@ def _init_grpc_channel(self) -> None:
ssl=self._https,
metadata=self._grpc_headers,
options=self._grpc_options,
compression=self._grpc_compression,
)

def _init_async_grpc_channel(self) -> None:
Expand All @@ -219,6 +234,7 @@ def _init_async_grpc_channel(self) -> None:
ssl=self._https,
metadata=self._grpc_headers,
options=self._grpc_options,
compression=self._grpc_compression,
)

def _init_grpc_points_client(self) -> None:
Expand Down
18 changes: 17 additions & 1 deletion tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import pytest
from grpc import RpcError
from grpc import Compression, RpcError

from qdrant_client import QdrantClient, models
from qdrant_client._pydantic_compat import to_dict
Expand Down Expand Up @@ -48,6 +48,7 @@
)
from qdrant_client.qdrant_remote import QdrantRemote
from qdrant_client.uploader.grpc_uploader import payload_to_grpc
from tests.congruence_tests.test_common import generate_fixtures, init_client
from tests.fixtures.payload import (
one_random_payload_please,
random_payload,
Expand Down Expand Up @@ -1730,6 +1731,21 @@ def test_grpc_options():
)


def test_grpc_compression():
client = QdrantClient(prefer_grpc=True, grpc_compression=Compression.Gzip)
client.get_collections()

client = QdrantClient(prefer_grpc=True, grpc_compression=Compression.NoCompression)
client.get_collections()

with pytest.raises(ValueError):
# creates a grpc client with not supported Compression type
QdrantClient(prefer_grpc=True, grpc_compression=Compression.Deflate)

with pytest.raises(TypeError):
QdrantClient(prefer_grpc=True, grpc_compression="gzip")


if __name__ == "__main__":
test_qdrant_client_integration()
test_points_crud()
Expand Down
4 changes: 2 additions & 2 deletions tools/generate_async_client.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mv async_qdrant_local.py $ABSOLUTE_PROJECT_ROOT/qdrant_client/async_qdrant_local
cd $ABSOLUTE_PROJECT_ROOT/qdrant_client

ls -1 async*.py | autoflake --recursive --imports qdrant_client --remove-unused-variables --in-place async*.py
ls -1 async*.py | xargs -I {} isort --profile black --py 39 {}
ls -1 async*.py | xargs -I {} black -l 99 --target-version py39 {}
ls -1 async*.py | xargs -I {} isort --profile black --py 310 {}
ls -1 async*.py | xargs -I {} black -l 99 --target-version py310 {}

mv async_qdrant_local.py local/async_qdrant_local.py

0 comments on commit 3ad05b3

Please sign in to comment.