Skip to content

Commit

Permalink
refactor: Call one model
Browse files: browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Feb 28, 2025
1 parent 671b874 commit 44644e7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 89 deletions.
10 changes: 1 addition & 9 deletions tests/test_image_onnx_embeddings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,17 +27,9 @@
}

ALL_IMAGE_MODEL_DESC = ImageEmbedding._list_supported_models()
smallest_model = min(ALL_IMAGE_MODEL_DESC, key=lambda m: m.size_in_GB).model


@pytest.mark.parametrize(
"model_name",
[
smallest_model
if smallest_model in CANONICAL_VECTOR_VALUES
else "Qdrant/clip-ViT-B-32-vision"
],
)
@pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
def test_embedding(model_name: str) -> None:
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down
20 changes: 4 additions & 16 deletions tests/test_late_interaction_embeddings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -170,10 +170,7 @@ def test_batch_embedding(model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["answerdotai/answerai-colbert-small-v1"],
)
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
def test_single_embedding(model_name: str):
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down Expand Up @@ -203,10 +200,7 @@ def test_single_embedding(model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["answerdotai/answerai-colbert-small-v1"],
)
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
def test_single_embedding_query(model_name: str):
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down Expand Up @@ -236,10 +230,7 @@ def test_single_embedding_query(model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"token_dim,model_name",
[(96, "answerdotai/answerai-colbert-small-v1")],
)
@pytest.mark.parametrize("token_dim,model_name", [(96, "answerdotai/answerai-colbert-small-v1")])
def test_parallel_processing(token_dim: int, model_name: str):
is_ci = os.getenv("CI")
model = LateInteractionTextEmbedding(model_name=model_name)
Expand All @@ -262,10 +253,7 @@ def test_parallel_processing(token_dim: int, model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["answerdotai/answerai-colbert-small-v1"],
)
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
def test_lazy_load(model_name: str):
is_ci = os.getenv("CI")

Expand Down
14 changes: 2 additions & 12 deletions tests/test_sparse_embeddings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -66,14 +66,7 @@ def test_batch_embedding() -> None:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
[
min(ALL_SPARSE_MODEL_DESC, key=lambda m: m.size_in_GB).model
if CANONICAL_COLUMN_VALUES
else "prithivida/Splade_PP_en_v1"
],
)
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
def test_single_embedding(model_name: str) -> None:
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down Expand Up @@ -195,10 +188,7 @@ def test_disable_stemmer_behavior(disable_stemmer: bool) -> None:
assert result == expected, f"Expected {expected}, but got {result}"


@pytest.mark.parametrize(
"model_name",
["prithivida/Splade_PP_en_v1"],
)
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
def test_lazy_load(model_name: str) -> None:
is_ci = os.getenv("CI")
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
Expand Down
24 changes: 4 additions & 20 deletions tests/test_text_cross_encoder.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,14 +18,7 @@
ALL_RERANK_MODEL_DESC = TextCrossEncoder._list_supported_models()


@pytest.mark.parametrize(
"model_name",
[
min(ALL_RERANK_MODEL_DESC, key=lambda m: m.size_in_GB).model
if CANONICAL_SCORE_VALUES
else "Xenova/ms-marco-MiniLM-L-6-v2"
],
)
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
def test_rerank(model_name: str) -> None:
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down Expand Up @@ -61,10 +54,7 @@ def test_rerank(model_name: str) -> None:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["Xenova/ms-marco-MiniLM-L-6-v2"],
)
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
def test_batch_rerank(model_name: str) -> None:
is_ci = os.getenv("CI")

Expand All @@ -90,10 +80,7 @@ def test_batch_rerank(model_name: str) -> None:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["Xenova/ms-marco-MiniLM-L-6-v2"],
)
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
def test_lazy_load(model_name: str) -> None:
is_ci = os.getenv("CI")
model = TextCrossEncoder(model_name=model_name, lazy_load=True)
Expand All @@ -107,10 +94,7 @@ def test_lazy_load(model_name: str) -> None:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["Xenova/ms-marco-MiniLM-L-6-v2"],
)
@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
def test_rerank_pairs_parallel(model_name: str) -> None:
is_ci = os.getenv("CI")

Expand Down
30 changes: 6 additions & 24 deletions tests/test_text_multitask_embeddings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -60,10 +60,7 @@
docs = ["Hello World", "Follow the white rabbit."]


@pytest.mark.parametrize(
"dim,model_name",
[(1024, "jinaai/jina-embeddings-v3")],
)
@pytest.mark.parametrize("dim,model_name", [(1024, "jinaai/jina-embeddings-v3")])
def test_batch_embedding(dim: int, model_name: str):
is_ci = os.getenv("CI")
docs_to_embed = docs * 10
Expand All @@ -85,10 +82,7 @@ def test_batch_embedding(dim: int, model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["jinaai/jina-embeddings-v3"],
)
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
def test_single_embedding(model_name: str):
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down Expand Up @@ -128,10 +122,7 @@ def test_single_embedding(model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["jinaai/jina-embeddings-v3"],
)
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
def test_single_embedding_query(model_name: str):
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down Expand Up @@ -171,10 +162,7 @@ def test_single_embedding_query(model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["jinaai/jina-embeddings-v3"],
)
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
def test_single_embedding_passage(model_name: str):
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
Expand Down Expand Up @@ -214,10 +202,7 @@ def test_single_embedding_passage(model_name: str):
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"dim,model_name",
[(1024, "jinaai/jina-embeddings-v3")],
)
@pytest.mark.parametrize("dim,model_name", [(1024, "jinaai/jina-embeddings-v3")])
def test_parallel_processing(dim: int, model_name: str):
is_ci = os.getenv("CI")

Expand Down Expand Up @@ -263,10 +248,7 @@ def test_task_assignment():
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["jinaai/jina-embeddings-v3"],
)
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
def test_lazy_load(model_name: str):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name, lazy_load=True)
Expand Down
9 changes: 1 addition & 8 deletions tests/test_text_onnx_embeddings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -74,14 +74,7 @@
ALL_TEXT_MODEL_DESC = TextEmbedding._list_supported_models()


@pytest.mark.parametrize(
"model_name",
[
min(ALL_TEXT_MODEL_DESC, key=lambda m: m.size_in_GB).model
if CANONICAL_VECTOR_VALUES
else "BAAI/bge-small-en-v1.5"
],
)
@pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"])
def test_embedding(model_name: str) -> None:
is_ci = os.getenv("CI")
is_mac = platform.system() == "Darwin"
Expand Down

0 comments on commit 44644e7

Please sign in to comment.