Skip to content

Commit

Permalink
fix: fix sparse vector sort
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Jan 19, 2024
1 parent fa11d35 commit af7101c
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 30 deletions.
2 changes: 1 addition & 1 deletion qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def search(
scores = calculate_context_scores(query_vector, vectors[: len(self.payload)], distance)
elif isinstance(query_vector, SparseVector):
# sparse vector query must be sorted by indices for dot product to work with persisted vectors
sort_sparse_vector(query_vector)
query_vector = sort_sparse_vector(query_vector)
sparse_scoring = True
sparse_vectors = self.sparse_vectors[name]
scores = calculate_distance_sparse(query_vector, sparse_vectors[: len(self.payload)])
Expand Down
22 changes: 0 additions & 22 deletions qdrant_client/local/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,3 @@ def sparse_dot_product(vector1: SparseVector, vector2: SparseVector) -> Optional
return np.float32(result)
else:
return None


# Build a random sparse vector over a space of `size` dimensions.
# `density` is the fraction of dimensions that receive an entry: exactly
# int(size * density) distinct indices are drawn, and they are returned in
# ascending order.
def generate_random_sparse_vector(size: int, density: float) -> SparseVector:
    entry_count = int(size * density)
    # Indices are sampled before values so the sequence of RNG calls is fixed.
    picked: List[int] = sorted(random.sample(range(size), entry_count))
    weights: List[float] = [round(random.random(), 6) for _ in range(entry_count)]
    result = SparseVector(indices=picked, values=weights)
    # Check the freshly built vector before handing it to the caller.
    validate_sparse_vector(result)
    return result


def generate_random_sparse_vector_list(
    num_vectors: int, vector_size: int, vector_density: float
) -> List[SparseVector]:
    """Return a list of `num_vectors` random sparse vectors.

    Every element is produced by a separate call to
    `generate_random_sparse_vector(vector_size, vector_density)`.
    """
    return [
        generate_random_sparse_vector(vector_size, vector_density)
        for _ in range(num_vectors)
    ]
12 changes: 8 additions & 4 deletions tests/congruence_tests/test_sparse_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
from qdrant_client.client_base import QdrantBase
from qdrant_client.conversions.common_types import NamedSparseVector
from qdrant_client.http.models import models
from qdrant_client.local.sparse import generate_random_sparse_vector
from tests.congruence_tests.test_common import (
COLLECTION_NAME,
compare_client_results,
generate_sparse_fixtures,
init_client,
init_local,
init_remote, sparse_text_vector_size, sparse_image_vector_size, sparse_code_vector_size, sparse_vectors_config,
init_remote,
sparse_code_vector_size,
sparse_image_vector_size,
sparse_text_vector_size,
sparse_vectors_config,
)
from tests.fixtures.filters import one_random_filter_please
from tests.fixtures.points import generate_random_sparse_vector


class TestSimpleSparseSearcher:
Expand Down Expand Up @@ -112,7 +116,7 @@ def simple_search_image_select_vector(self, client: QdrantBase) -> List[models.S
)

def filter_search_text(
self, client: QdrantBase, query_filter: models.Filter
self, client: QdrantBase, query_filter: models.Filter
) -> List[models.ScoredPoint]:
return client.search(
collection_name=COLLECTION_NAME,
Expand All @@ -123,7 +127,7 @@ def filter_search_text(
)

def filter_search_text_single(
self, client: QdrantBase, query_filter: models.Filter
self, client: QdrantBase, query_filter: models.Filter
) -> List[models.ScoredPoint]:
return client.search(
collection_name=COLLECTION_NAME,
Expand Down
28 changes: 26 additions & 2 deletions tests/fixtures/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from qdrant_client._pydantic_compat import construct
from qdrant_client.http import models
from qdrant_client.local.sparse import generate_random_sparse_vector
from qdrant_client.http.models import SparseVector
from qdrant_client.local.sparse import validate_sparse_vector
from tests.fixtures.payload import one_random_payload_please


Expand All @@ -24,7 +25,30 @@ def random_vectors(
raise ValueError("vector_sizes must be int or dict")


def random_sparse_vectors(vector_sizes: Union[Dict[str, int], int],) -> models.VectorStruct:
# Build a random sparse vector over a space of `size` dimensions.
# `density` is the fraction of dimensions that receive an entry: exactly
# int(size * density) distinct indices are drawn. The indices are kept in the
# order random.sample returns them, i.e. they are NOT sorted here.
def generate_random_sparse_vector(size: int, density: float) -> SparseVector:
    entry_count = int(size * density)
    # Indices are sampled before values so the sequence of RNG calls is fixed.
    positions: List[int] = random.sample(range(size), entry_count)
    weights: List[float] = [round(random.random(), 6) for _ in range(entry_count)]
    vector = SparseVector(indices=positions, values=weights)
    # Check the freshly built vector before handing it to the caller.
    validate_sparse_vector(vector)
    return vector


def generate_random_sparse_vector_list(
    num_vectors: int, vector_size: int, vector_density: float
) -> List[SparseVector]:
    """Return a list of `num_vectors` random sparse vectors.

    Every element is produced by a separate call to
    `generate_random_sparse_vector(vector_size, vector_density)`.
    """
    return [
        generate_random_sparse_vector(vector_size, vector_density)
        for _ in range(num_vectors)
    ]


def random_sparse_vectors(
vector_sizes: Union[Dict[str, int], int],
) -> models.VectorStruct:
vectors = {}
for vector_name, vector_size in vector_sizes.items():
# use sparse vectors with 20% density
Expand Down
2 changes: 1 addition & 1 deletion tests/test_local_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import qdrant_client
import qdrant_client.http.models as rest
from qdrant_client._pydantic_compat import construct
from qdrant_client.local.sparse import generate_random_sparse_vector_list
from tests.fixtures.points import generate_random_sparse_vector_list

default_collection_name = "example"

Expand Down

0 comments on commit af7101c

Please sign in to comment.