Skip to content

Commit

Permalink
Fixed qdrant migration on update of qdrant-client (#206)
Browse files Browse the repository at this point in the history
* Fixed qdrant migration on update of qdrant-client

* modified args parse

* modified logging statements
  • Loading branch information
S1LV3RJ1NX authored Jun 5, 2024
1 parent 3455cb6 commit f8f48f0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
18 changes: 12 additions & 6 deletions backend/migration/qdrant_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,19 @@ def migrate_collection(
fetched_destination_collection = get_collection(
destination_backend_url, destination_collection_name, type="destination"
)
if fetched_destination_collection and not overwrite and not same_qdrant_loc:
if fetched_destination_collection and same_qdrant_loc:
raise Exception(
f"Destination collection '{destination_collection_name}' already exists. Either delete it or set overwrite flag to True"
f"Source and destination qdrant locations are same. Destination collection '{destination_collection_name}' already exists."
)

elif fetched_destination_collection and not overwrite and not same_qdrant_loc:
raise Exception(
f"Source and destination qdrant locations are different. Destination collection '{destination_collection_name}' already exists in the destination qdrant. Add --overwrite to overwrite the collection."
)
else:
logger.debug(
f"Destination collection '{destination_collection_name}' not found at destination. Proceeding..."
)
logger.debug(
f"Destination collection '{destination_collection_name}' not found at destination. Proceeding..."
)
except Exception as e:
raise e

Expand Down Expand Up @@ -200,7 +206,7 @@ def main():
type=bool,
help="Overwrite destination collection if exists in separate qdrant",
required=False,
default=False,
action="store_true",
)

args = parser.parse_args()
Expand Down
27 changes: 15 additions & 12 deletions backend/migration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from qdrant_client._pydantic_compat import to_dict
from qdrant_client.client_base import QdrantBase
from qdrant_client.http import models
from tqdm import tqdm

from backend.logger import logger

Expand Down Expand Up @@ -94,7 +95,9 @@ def _recreate_collection(
src_config = src_collection_info.config
src_payload_schema = src_collection_info.payload_schema

dest_client.recreate_collection(
# Delete the destination collection from qdrant only; it was created while creating the metadatastore entry
dest_client.delete_collection(destination_collection_name, timeout=300)
dest_client.create_collection(
destination_collection_name,
vectors_config=src_config.params.vectors,
sparse_vectors_config=src_config.params.sparse_vectors,
Expand All @@ -108,6 +111,7 @@ def _recreate_collection(
),
wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
quantization_config=src_config.quantization_config,
timeout=300,
)

_recreate_payload_schema(
Expand Down Expand Up @@ -154,20 +158,19 @@ def _migrate_collection(
# upload_records has been deprecated due to the usage of models.Record; models.Record has been deprecated as a
# structure for uploading due to a `shard_key` field, and now is used only as a result structure.
# since shard_keys are not supported in migration, we can safely type ignore here and use Records for uploading

while next_offset is not None:
records, next_offset = source_client.scroll(
source_collection_name,
offset=next_offset,
limit=batch_size,
with_vectors=True,
records, next_offset = tqdm(
source_client.scroll(
source_collection_name,
offset=next_offset,
limit=batch_size,
with_vectors=True,
)
)
dest_client.upload_points(destination_collection_name, records, wait=True)
source_client_vectors_count = source_client.get_collection(
source_collection_name
).vectors_count
dest_client_vectors_count = dest_client.get_collection(
destination_collection_name
).vectors_count
source_client_vectors_count = source_client.count(source_collection_name).count
dest_client_vectors_count = dest_client.count(destination_collection_name).count

if source_client_vectors_count != dest_client_vectors_count:
warnings.warn(
Expand Down

0 comments on commit f8f48f0

Please sign in to comment.