diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 5c33d231..dcbfcc02 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -1,9 +1,10 @@ name: Tests on: - push: - branches: [ master, main, gpu ] pull_request: + branches: [ master, main, gpu ] + workflow_dispatch: + env: CARGO_TERM_COLOR: always @@ -42,4 +43,4 @@ jobs: - name: Run pytest run: | - poetry run pytest + poetry run pytest \ No newline at end of file diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 0d562279..a9220f00 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -8,7 +8,7 @@ from fastembed import ImageEmbedding from tests.config import TEST_MISC_DIR -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_VECTOR_VALUES = { "Qdrant/clip-ViT-B-32-vision": np.array([-0.0098, 0.0128, -0.0274, 0.002, -0.0059]), @@ -27,11 +27,13 @@ } -def test_embedding() -> None: +@pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"]) +def test_embedding(model_name: str) -> None: is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" for model_desc in ImageEmbedding._list_supported_models(): - if not is_ci and model_desc.size_in_GB > 1: + if not should_test_model(model_desc, model_name, is_ci, is_manual): continue dim = model_desc.dim @@ -74,8 +76,12 @@ def test_batch_embedding(n_dims: int, model_name: str) -> None: embeddings = list(model.embed(images, batch_size=10)) embeddings = np.stack(embeddings, axis=0) + assert np.allclose(embeddings[1], embeddings[2]) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_name] assert embeddings.shape == (len(test_images) * n_images, n_dims) + assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3) if is_ci: delete_model_cache(model.model._model_dir) diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index 613895df..6f16fdc2 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -6,7 +6,7 @@ from fastembed.late_interaction.late_interaction_text_embedding import ( LateInteractionTextEmbedding, ) -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model # vectors are abridged and rounded for brevity CANONICAL_COLUMN_VALUES = { @@ -153,31 +153,37 @@ docs = ["Hello World"] -def test_batch_embedding(): +@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) +def test_batch_embedding(model_name: str): is_ci = os.getenv("CI") docs_to_embed = docs * 10 - for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): - print("evaluating", model_name) - model = LateInteractionTextEmbedding(model_name=model_name) - result = list(model.embed(docs_to_embed, batch_size=6)) + model = LateInteractionTextEmbedding(model_name=model_name) + result = list(model.embed(docs_to_embed, batch_size=6)) + expected_result = CANONICAL_COLUMN_VALUES[model_name] - for value in result: - token_num, abridged_dim = expected_result.shape - assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3) + for value in result: + token_num, abridged_dim = expected_result.shape + assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3) - if is_ci: - delete_model_cache(model.model._model_dir) + if is_ci: + 
delete_model_cache(model.model._model_dir) -def test_single_embedding(): +@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) +def test_single_embedding(model_name: str): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" docs_to_embed = docs - for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): + for model_desc in LateInteractionTextEmbedding._list_supported_models(): + if not should_test_model(model_desc, model_name, is_ci, is_manual): + continue + print("evaluating", model_name) model = LateInteractionTextEmbedding(model_name=model_name) result = next(iter(model.embed(docs_to_embed, batch_size=6))) + expected_result = CANONICAL_COLUMN_VALUES[model_name] token_num, abridged_dim = expected_result.shape assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3) @@ -185,14 +191,20 @@ def test_single_embedding(): delete_model_cache(model.model._model_dir) -def test_single_embedding_query(): +@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) +def test_single_embedding_query(model_name: str): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" queries_to_embed = docs - for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): + for model_desc in LateInteractionTextEmbedding._list_supported_models(): + if not should_test_model(model_desc, model_name, is_ci, is_manual): + continue + print("evaluating", model_name) model = LateInteractionTextEmbedding(model_name=model_name) result = next(iter(model.query_embed(queries_to_embed))) + expected_result = CANONICAL_QUERY_VALUES[model_name] token_num, abridged_dim = expected_result.shape assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3) @@ -200,10 +212,11 @@ def test_single_embedding_query(): delete_model_cache(model.model._model_dir) -def test_parallel_processing(): +@pytest.mark.parametrize("token_dim,model_name", [(96, "answerdotai/answerai-colbert-small-v1")]) +def test_parallel_processing(token_dim: int, model_name: str): is_ci = os.getenv("CI") - model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0") - token_dim = 128 + model = LateInteractionTextEmbedding(model_name=model_name) + docs = ["hello world", "flag embedding"] * 100 embeddings = list(model.embed(docs, batch_size=10, parallel=2)) embeddings = np.stack(embeddings, axis=0) @@ -222,10 +235,7 @@ def test_parallel_processing(): delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "model_name", - ["colbert-ir/colbertv2.0"], -) +@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"]) def test_lazy_load(model_name: str): is_ci = os.getenv("CI") diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py index 17139475..04b6f300 100644 --- a/tests/test_late_interaction_multimodal.py +++ b/tests/test_late_interaction_multimodal.py @@ -1,5 +1,6 @@ import os +import pytest from PIL import Image import numpy as np @@ -45,38 +46,38 @@ def test_batch_embedding(): - is_ci = os.getenv("CI") + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") - if not is_ci: - for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): - print("evaluating", model_name) - model = LateInteractionMultimodalEmbedding(model_name=model_name) - result = list(model.embed_image(images, batch_size=2)) + for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): + print("evaluating", 
model_name) + model = LateInteractionMultimodalEmbedding(model_name=model_name) + result = list(model.embed_image(images, batch_size=2)) - for value in result: - token_num, abridged_dim = expected_result.shape - assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=2e-3) + for value in result: + token_num, abridged_dim = expected_result.shape + assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=2e-3) def test_single_embedding(): - is_ci = os.getenv("CI") - if not is_ci: - for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): - print("evaluating", model_name) - model = LateInteractionMultimodalEmbedding(model_name=model_name) - result = next(iter(model.embed_image(images, batch_size=6))) - token_num, abridged_dim = expected_result.shape - assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionMultimodalEmbedding(model_name=model_name) + result = next(iter(model.embed_image(images, batch_size=6))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) def test_single_embedding_query(): - is_ci = os.getenv("CI") - if not is_ci: - queries_to_embed = queries - - for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): - print("evaluating", model_name) - model = LateInteractionMultimodalEmbedding(model_name=model_name) - result = next(iter(model.embed_text(queries_to_embed))) - token_num, abridged_dim = expected_result.shape - assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionMultimodalEmbedding(model_name=model_name) + result = next(iter(model.embed_text(queries))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 48cc14fd..71941baf 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -5,10 +5,10 @@ from fastembed.sparse.bm25 import Bm25 from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_COLUMN_VALUES = { - "prithvida/Splade_PP_en_v1": { + "prithivida/Splade_PP_en_v1": { "indices": [ 2040, 2047, @@ -49,28 +49,41 @@ docs = ["Hello World"] -def test_batch_embedding() -> None: +@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"]) +def test_batch_embedding(model_name: str) -> None: is_ci = os.getenv("CI") docs_to_embed = docs * 10 - for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): - model = SparseTextEmbedding(model_name=model_name) - result = next(iter(model.embed(docs_to_embed, batch_size=6))) - assert result.indices.tolist() == expected_result["indices"] + model = SparseTextEmbedding(model_name=model_name) + result = next(iter(model.embed(docs_to_embed, batch_size=6))) + expected_result = CANONICAL_COLUMN_VALUES[model_name] + assert result.indices.tolist() == expected_result["indices"] - for i, value in 
enumerate(result.values): - assert pytest.approx(value, abs=0.001) == expected_result["values"][i] - if is_ci: - delete_model_cache(model.model._model_dir) + for i, value in enumerate(result.values): + assert pytest.approx(value, abs=0.001) == expected_result["values"][i] + if is_ci: + delete_model_cache(model.model._model_dir) -def test_single_embedding() -> None: +@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"]) +def test_single_embedding(model_name: str) -> None: is_ci = os.getenv("CI") - for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + + for model_desc in SparseTextEmbedding._list_supported_models(): + if ( + model_desc.model not in CANONICAL_COLUMN_VALUES + ): # attention models and bm25 are also parts of + # SparseTextEmbedding, however, they have their own tests + continue + if not should_test_model(model_desc, model_name, is_ci, is_manual): + continue + model = SparseTextEmbedding(model_name=model_name) passage_result = next(iter(model.embed(docs, batch_size=6))) query_result = next(iter(model.query_embed(docs))) + expected_result = CANONICAL_COLUMN_VALUES[model_name] for result in [passage_result, query_result]: assert result.indices.tolist() == expected_result["indices"] @@ -80,9 +93,10 @@ def test_single_embedding() -> None: delete_model_cache(model.model._model_dir) -def test_parallel_processing() -> None: +@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"]) +def test_parallel_processing(model_name: str) -> None: is_ci = os.getenv("CI") - model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1") + model = SparseTextEmbedding(model_name=model_name) docs = ["hello world", "flag embedding"] * 30 sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2)) sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0)) @@ -172,10 +186,7 @@ def test_disable_stemmer_behavior(disable_stemmer: bool) -> None: assert result == expected, f"Expected {expected}, but got {result}" -@pytest.mark.parametrize( - "model_name", - ["prithivida/Splade_PP_en_v1"], -) +@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"]) def test_lazy_load(model_name: str) -> None: is_ci = os.getenv("CI") model = SparseTextEmbedding(model_name=model_name, lazy_load=True) diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index 680c1a09..76362fdc 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -4,7 +4,7 @@ import pytest from fastembed.rerank.cross_encoder import TextCrossEncoder -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_SCORE_VALUES = { "Xenova/ms-marco-MiniLM-L-6-v2": np.array([8.500708, -2.541011]), @@ -15,44 +15,37 @@ "jinaai/jina-reranker-v2-base-multilingual": np.array([1.6533, -1.6455]), } -SELECTED_MODELS = { - "Xenova": "Xenova/ms-marco-MiniLM-L-6-v2", - "BAAI": "BAAI/bge-reranker-base", - "jinaai": "jinaai/jina-reranker-v1-tiny-en", -} - -@pytest.mark.parametrize( - "model_name", - [model_name for model_name in CANONICAL_SCORE_VALUES], -) +@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) def test_rerank(model_name: str) -> None: is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" - model = TextCrossEncoder(model_name=model_name) + for model_desc in TextCrossEncoder._list_supported_models(): + if not 
should_test_model(model_desc, model_name, is_ci, is_manual): + continue - query = "What is the capital of France?" - documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] - scores = np.array(list(model.rerank(query, documents))) + model = TextCrossEncoder(model_name=model_name) - pairs = [(query, doc) for doc in documents] - scores2 = np.array(list(model.rerank_pairs(pairs))) - assert np.allclose( - scores, scores2, atol=1e-5 - ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}" + query = "What is the capital of France?" + documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] + scores = np.array(list(model.rerank(query, documents))) - canonical_scores = CANONICAL_SCORE_VALUES[model_name] - assert np.allclose( - scores, canonical_scores, atol=1e-3 - ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}" - if is_ci: - delete_model_cache(model.model._model_dir) + pairs = [(query, doc) for doc in documents] + scores2 = np.array(list(model.rerank_pairs(pairs))) + assert np.allclose( + scores, scores2, atol=1e-5 + ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}" + + canonical_scores = CANONICAL_SCORE_VALUES[model_name] + assert np.allclose( + scores, canonical_scores, atol=1e-3 + ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}" + if is_ci: + delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "model_name", - [model_name for model_name in SELECTED_MODELS.values()], -) +@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) def test_batch_rerank(model_name: str) -> None: is_ci = os.getenv("CI") @@ -78,10 +71,7 @@ def test_batch_rerank(model_name: str) -> None: delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "model_name", - ["Xenova/ms-marco-MiniLM-L-6-v2"], -) +@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) def test_lazy_load(model_name: str) -> None: is_ci = os.getenv("CI") model = TextCrossEncoder(model_name=model_name, lazy_load=True) @@ -95,10 +85,7 @@ def test_lazy_load(model_name: str) -> None: delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "model_name", - [model_name for model_name in SELECTED_MODELS.values()], -) +@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) def test_rerank_pairs_parallel(model_name: str) -> None: is_ci = os.getenv("CI") diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index b9bc89f1..874ffcec 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -4,7 +4,7 @@ import pytest from fastembed import TextEmbedding -from fastembed.text.multitask_embedding import Task +from fastembed.text.multitask_embedding import JinaEmbeddingV3, Task from tests.utils import delete_model_cache @@ -60,52 +60,43 @@ docs = ["Hello World", "Follow the white rabbit."] -def test_batch_embedding(): +@pytest.mark.parametrize("dim,model_name", [(1024, "jinaai/jina-embeddings-v3")]) +def test_batch_embedding(dim: int, model_name: str): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + if is_ci and not is_manual: + pytest.skip("Skipping multitask models in CI non-manual mode") + docs_to_embed = docs * 10 default_task = Task.RETRIEVAL_PASSAGE - for model_desc in TextEmbedding._list_supported_models(): - if not is_ci and model_desc.size_in_GB > 1: - continue - - model_name = model_desc.model - dim = 
model_desc.dim - - if model_name not in CANONICAL_VECTOR_VALUES.keys(): - continue - - model = TextEmbedding(model_name=model_name) - - print(f"evaluating {model_name} default task") + model = TextEmbedding(model_name=model_name) - embeddings = list(model.embed(documents=docs_to_embed, batch_size=6)) - embeddings = np.stack(embeddings, axis=0) + embeddings = list(model.embed(documents=docs_to_embed, batch_size=6)) + embeddings = np.stack(embeddings, axis=0) - assert embeddings.shape == (len(docs_to_embed), dim) + assert embeddings.shape == (len(docs_to_embed), dim) - canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"] - assert np.allclose( - embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 - ), model_desc.model + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_name - if is_ci: - delete_model_cache(model.model._model_dir) + if is_ci: + delete_model_cache(model.model._model_dir) def test_single_embedding(): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + if is_ci and not is_manual: + pytest.skip("Skipping multitask models in CI non-manual mode") - for model_desc in TextEmbedding._list_supported_models(): - if not is_ci and model_desc.size_in_GB > 1: - continue - + for model_desc in JinaEmbeddingV3._list_supported_models(): + # todo: once we add more models, we should not test models >1GB size locally model_name = model_desc.model dim = model_desc.dim - if model_name not in CANONICAL_VECTOR_VALUES.keys(): - continue - model = TextEmbedding(model_name=model_name) for task in CANONICAL_VECTOR_VALUES[model_name]: @@ -127,18 +118,17 @@ def test_single_embedding(): def test_single_embedding_query(): is_ci = os.getenv("CI") - task_id = Task.RETRIEVAL_QUERY + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + if is_ci and not is_manual: + pytest.skip("Skipping multitask models in CI non-manual mode") - for model_desc in TextEmbedding._list_supported_models(): - if not is_ci and model_desc.size_in_GB > 1: - continue + task_id = Task.RETRIEVAL_QUERY + for model_desc in JinaEmbeddingV3._list_supported_models(): + # todo: once we add more models, we should not test models >1GB size locally model_name = model_desc.model dim = model_desc.dim - if model_name not in CANONICAL_VECTOR_VALUES.keys(): - continue - model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} query_embed task_id: {task_id}") @@ -159,18 +149,18 @@ def test_single_embedding_query(): def test_single_embedding_passage(): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + if is_ci and not is_manual: + pytest.skip("Skipping multitask models in CI non-manual mode") + task_id = Task.RETRIEVAL_PASSAGE - for model_desc in TextEmbedding._list_supported_models(): - if not is_ci and model_desc.size_in_GB > 1: - continue + for model_desc in JinaEmbeddingV3._list_supported_models(): + # todo: once we add more models, we should not test models >1GB size locally model_name = model_desc.model dim = model_desc.dim - if model_name not in CANONICAL_VECTOR_VALUES.keys(): - continue - model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} passage_embed task_id: {task_id}") @@ -189,14 +179,15 @@ def test_single_embedding_passage(): delete_model_cache(model.model._model_dir) -def test_parallel_processing(): 
+@pytest.mark.parametrize("dim,model_name", [(1024, "jinaai/jina-embeddings-v3")]) +def test_parallel_processing(dim: int, model_name: str): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + if is_ci and not is_manual: + pytest.skip("Skipping in CI non-manual mode") docs = ["Hello World", "Follow the white rabbit."] * 10 - model_name = "jinaai/jina-embeddings-v3" - dim = 1024 - model = TextEmbedding(model_name=model_name) task_id = Task.SEPARATION @@ -218,14 +209,14 @@ def test_parallel_processing(): def test_task_assignment(): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" - for model_desc in TextEmbedding._list_supported_models(): - if not is_ci and model_desc.size_in_GB > 1: - continue + if is_ci and not is_manual: + pytest.skip("Skipping in CI non-manual mode") + for model_desc in JinaEmbeddingV3._list_supported_models(): + # todo: once we add more models, we should not test models >1GB size locally model_name = model_desc.model - if model_name not in CANONICAL_VECTOR_VALUES.keys(): - continue model = TextEmbedding(model_name=model_name) @@ -237,12 +228,14 @@ def test_task_assignment(): delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "model_name", - ["jinaai/jina-embeddings-v3"], -) +@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"]) def test_lazy_load(model_name: str): is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + + if is_ci and not is_manual: + pytest.skip("Skipping in CI non-manual mode") + model = TextEmbedding(model_name=model_name, lazy_load=True) assert not hasattr(model.model, "model") diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index cf39d7d1..927ef54e 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -5,7 +5,7 @@ import pytest from fastembed.text.text_embedding import TextEmbedding -from tests.utils import delete_model_cache +from tests.utils import delete_model_cache, should_test_model CANONICAL_VECTOR_VALUES = { "BAAI/bge-small-en": np.array([-0.0232, -0.0255, 0.0174, -0.0639, -0.0006]), @@ -72,17 +72,19 @@ MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"] -def test_embedding() -> None: +@pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"]) +def test_embedding(model_name: str) -> None: is_ci = os.getenv("CI") is_mac = platform.system() == "Darwin" + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" for model_desc in TextEmbedding._list_supported_models(): - if ( - (not is_ci and model_desc.size_in_GB > 1) - or model_desc.model in MULTI_TASK_MODELS - or (is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q") + if model_desc.model in MULTI_TASK_MODELS or ( + is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q" ): continue + if not should_test_model(model_desc, model_name, is_ci, is_manual): + continue dim = model_desc.dim @@ -95,15 +97,12 @@ def test_embedding() -> None: canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model] assert np.allclose( embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] + ), model_desc.model if is_ci: delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "n_dims,model_name", - [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], -) +@pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")]) def test_batch_embedding(n_dims: 
int, model_name: str) -> None:
     is_ci = os.getenv("CI")
     model = TextEmbedding(model_name=model_name)
@@ -112,15 +111,12 @@ def test_batch_embedding(n_dims: int, model_name: str) -> None:
     embeddings = list(model.embed(docs, batch_size=10))
     embeddings = np.stack(embeddings, axis=0)
 
-    assert embeddings.shape == (200, n_dims)
+    assert embeddings.shape == (len(docs), n_dims)
 
     if is_ci:
         delete_model_cache(model.model._model_dir)
 
 
-@pytest.mark.parametrize(
-    "n_dims,model_name",
-    [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
-)
+@pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")])
 def test_parallel_processing(n_dims: int, model_name: str) -> None:
     is_ci = os.getenv("CI")
     model = TextEmbedding(model_name=model_name)
@@ -135,7 +131,7 @@ def test_parallel_processing(n_dims: int, model_name: str) -> None:
     embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
     embeddings_3 = np.stack(embeddings_3, axis=0)
 
-    assert embeddings.shape == (200, n_dims)
+    assert embeddings.shape == (len(docs), n_dims)
     assert np.allclose(embeddings, embeddings_2, atol=1e-3)
     assert np.allclose(embeddings, embeddings_3, atol=1e-3)
 
@@ -143,10 +139,7 @@ def test_parallel_processing(n_dims: int, model_name: str) -> None:
         delete_model_cache(model.model._model_dir)
 
 
-@pytest.mark.parametrize(
-    "model_name",
-    ["BAAI/bge-small-en-v1.5"],
-)
+@pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"])
 def test_lazy_load(model_name: str) -> None:
     is_ci = os.getenv("CI")
     model = TextEmbedding(model_name=model_name, lazy_load=True)
diff --git a/tests/utils.py b/tests/utils.py
index cfd6ae8b..40a7febf 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -3,7 +3,9 @@
 from pathlib import Path
 from types import TracebackType
 
-from typing import Union, Callable, Any, Type
+from typing import Union, Callable, Any, Type, Optional
+
+from fastembed.common.model_description import BaseModelDescription
 
 
 def delete_model_cache(model_dir: Union[str, Path]) -> None:
@@ -35,3 +37,31 @@ def on_error(
 
     if model_dir.exists():
         # todo: PermissionDenied is raised on blobs removal in Windows, with blobs > 2GB
         shutil.rmtree(model_dir, onerror=on_error)
+
+
+def should_test_model(
+    model_desc: BaseModelDescription,
+    autotest_model_name: str,
+    is_ci: Optional[str],
+    is_manual: bool,
+):
+    """Determine whether a model should be tested based on the environment.
+
+    Tests can be run either in CI or locally.
+    Testing every model on each CI run takes too long.
+    The testing schemes in CI and on a local machine differ, so there are three possible scenarios:
+    1) Run lightweight tests in CI:
+        - test only one model that has been chosen as a representative for its model family
+    2) Run heavyweight (manual) tests in CI:
+        - test all models
+        Running these on every CI run is too expensive, but it is fine to run them once via manual dispatch
+    3) Run tests locally:
+        - test all models that are not too heavy, since network speed might be a bottleneck
+
+    """
+    if not is_ci:
+        if model_desc.size_in_GB > 1:
+            return False
+    elif not is_manual and model_desc.model != autotest_model_name:
+        return False
+    return True
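
Usage note (not part of the patch): the updated tests gate each model behind the new tests/utils.should_test_model helper, combining the CI flag with the workflow_dispatch event added to the workflow. A minimal sketch of that pattern, using only the fastembed API already shown in the diff; the test name below is hypothetical:

import os

import pytest

from fastembed import TextEmbedding
from tests.utils import delete_model_cache, should_test_model


@pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"])
def test_embedding_gated(model_name: str) -> None:
    # GitHub Actions sets CI; manual runs arrive via the workflow_dispatch trigger.
    is_ci = os.getenv("CI")
    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

    for model_desc in TextEmbedding._list_supported_models():
        # Locally: skip models larger than 1 GB. In CI: test only the representative
        # model, unless the run was dispatched manually, in which case test all models.
        if not should_test_model(model_desc, model_name, is_ci, is_manual):
            continue

        model = TextEmbedding(model_name=model_desc.model)
        embedding = next(iter(model.embed(["Hello World"])))
        assert embedding.shape == (model_desc.dim,)

        if is_ci:
            delete_model_cache(model.model._model_dir)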