
tests: Updated tests
hh-space-invader committed Feb 27, 2025
1 parent 340d46c commit 3b7d3d7
Showing 8 changed files with 183 additions and 149 deletions.
21 changes: 0 additions & 21 deletions tests/get_all_model_hash.py

This file was deleted.

15 changes: 13 additions & 2 deletions tests/test_image_onnx_embeddings.py
@@ -29,9 +29,16 @@
 def test_embedding() -> None:
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
 
-    for model_desc in ImageEmbedding._list_supported_models():
-        if not is_ci and model_desc.size_in_GB > 1:
+    all_models = ImageEmbedding._list_supported_models()
+
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_VECTOR_VALUES:
             continue
 
         dim = model_desc.dim
@@ -74,8 +81,12 @@ def test_batch_embedding(n_dims: int, model_name: str) -> None:
 
     embeddings = list(model.embed(images, batch_size=10))
     embeddings = np.stack(embeddings, axis=0)
+    assert np.allclose(embeddings[1], embeddings[2])
+
+    canonical_vector = CANONICAL_VECTOR_VALUES[model_name]
 
     assert embeddings.shape == (len(test_images) * n_images, n_dims)
+    assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)
     if is_ci:
         delete_model_cache(model.model._model_dir)
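The same gating pattern recurs in every updated test file: a full sweep over the supported-model list runs only on manually dispatched workflows, ordinary runs exercise just the first supported model, and models over 1 GB are skipped outside CI. Below is a minimal runnable sketch of that selection logic; the ModelDescription class and model entries are illustrative stand-ins, where the real tests call _list_supported_models() on the embedding class.

    import os
    from dataclasses import dataclass


    @dataclass
    class ModelDescription:  # stand-in for fastembed's model description objects
        model: str
        size_in_GB: float


    # Illustrative values; the real list comes from e.g. ImageEmbedding._list_supported_models()
    all_models = [
        ModelDescription("example/small-model", 0.4),
        ModelDescription("example/large-model", 2.3),
    ]
    CANONICAL_VECTOR_VALUES = {"example/small-model": ...}  # placeholder fixture

    is_ci = os.getenv("CI")
    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

    # Manual workflow_dispatch runs sweep every model; any other trigger
    # tests only the first one, keeping routine CI runs fast.
    models_to_test = [all_models[0]] if not is_manual else all_models

    for model_desc in models_to_test:
        # Skip models over 1 GB outside CI, and models without canonical fixtures.
        if (
            not is_ci and model_desc.size_in_GB > 1
        ) or model_desc.model not in CANONICAL_VECTOR_VALUES:
            continue
        print("would test", model_desc.model)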
60 changes: 42 additions & 18 deletions tests/test_late_interaction_embeddings.py
@@ -153,57 +153,81 @@
 docs = ["Hello World"]
 
 
-def test_batch_embedding():
+@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
+def test_batch_embedding(model_name: str):
     is_ci = os.getenv("CI")
     docs_to_embed = docs * 10
 
-    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
-        print("evaluating", model_name)
-        model = LateInteractionTextEmbedding(model_name=model_name)
-        result = list(model.embed(docs_to_embed, batch_size=6))
+    model = LateInteractionTextEmbedding(model_name=model_name)
+    result = list(model.embed(docs_to_embed, batch_size=6))
+    expected_result = CANONICAL_COLUMN_VALUES[model_name]
 
-        for value in result:
-            token_num, abridged_dim = expected_result.shape
-            assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)
+    for value in result:
+        token_num, abridged_dim = expected_result.shape
+        assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)
 
-        if is_ci:
-            delete_model_cache(model.model._model_dir)
+    if is_ci:
+        delete_model_cache(model.model._model_dir)
 
 
 def test_single_embedding():
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     docs_to_embed = docs
 
-    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
+    all_models = LateInteractionTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_COLUMN_VALUES:
+            continue
         print("evaluating", model_name)
         model = LateInteractionTextEmbedding(model_name=model_name)
         result = next(iter(model.embed(docs_to_embed, batch_size=6)))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]
         token_num, abridged_dim = expected_result.shape
         assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
 
         if is_ci:
             delete_model_cache(model.model._model_dir)
 
 
 def test_single_embedding_query():
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     queries_to_embed = docs
 
-    for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
+    all_models = LateInteractionTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_QUERY_VALUES:
+            continue
         print("evaluating", model_name)
         model = LateInteractionTextEmbedding(model_name=model_name)
         result = next(iter(model.query_embed(queries_to_embed)))
+        expected_result = CANONICAL_QUERY_VALUES[model_name]
         token_num, abridged_dim = expected_result.shape
         assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
 
         if is_ci:
             delete_model_cache(model.model._model_dir)
 
 
-def test_parallel_processing():
+@pytest.mark.parametrize(
+    "token_dim,model_name",
+    [(96, "answerdotai/answerai-colbert-small-v1")],
+)
+def test_parallel_processing(token_dim: int, model_name: str):
     is_ci = os.getenv("CI")
-    model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
-    token_dim = 128
+    model = LateInteractionTextEmbedding(model_name=model_name)
 
     docs = ["hello world", "flag embedding"] * 100
     embeddings = list(model.embed(docs, batch_size=10, parallel=2))
     embeddings = np.stack(embeddings, axis=0)
@@ -224,7 +248,7 @@ def test_parallel_processing():
 
 @pytest.mark.parametrize(
     "model_name",
-    ["colbert-ir/colbertv2.0"],
+    ["answerdotai/answerai-colbert-small-v1"],
 )
 def test_lazy_load(model_name: str):
     is_ci = os.getenv("CI")
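A note on the assertions in these hunks: the canonical fixtures keep only the first few columns of each late-interaction embedding matrix, so every comparison first slices the fresh output down to the fixture's width. A self-contained illustration follows; the shapes and values are made up.

    import numpy as np

    # Pretend the model produced a (token_num, 96) matrix for one document.
    rng = np.random.default_rng(0)
    value = rng.normal(size=(7, 96)).astype(np.float32)

    # The canonical fixture stores only the first 4 columns per token.
    expected_result = value[:, :4].copy()

    token_num, abridged_dim = expected_result.shape
    # Same comparison the tests perform: abridge, then compare within tolerance.
    assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)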
12 changes: 11 additions & 1 deletion tests/test_sparse_embeddings.py
@@ -66,11 +66,21 @@ def test_batch_embedding() -> None:
 
 def test_single_embedding() -> None:
     is_ci = os.getenv("CI")
-    for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
+
+    all_models = SparseTextEmbedding._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
+
+    for model_desc in models_to_test:
+        model_name = model_desc.model
+        if (not is_ci and model_desc.size_in_GB > 1) or model_name not in CANONICAL_COLUMN_VALUES:
+            continue
+
         model = SparseTextEmbedding(model_name=model_name)
 
         passage_result = next(iter(model.embed(docs, batch_size=6)))
         query_result = next(iter(model.query_embed(docs)))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]
         for result in [passage_result, query_result]:
             assert result.indices.tolist() == expected_result["indices"]
54 changes: 27 additions & 27 deletions tests/test_text_cross_encoder.py
@@ -15,43 +15,43 @@
     "jinaai/jina-reranker-v2-base-multilingual": np.array([1.6533, -1.6455]),
 }
 
-SELECTED_MODELS = {
-    "Xenova": "Xenova/ms-marco-MiniLM-L-6-v2",
-    "BAAI": "BAAI/bge-reranker-base",
-    "jinaai": "jinaai/jina-reranker-v1-tiny-en",
-}
 
 
-@pytest.mark.parametrize(
-    "model_name",
-    [model_name for model_name in CANONICAL_SCORE_VALUES],
-)
-def test_rerank(model_name: str) -> None:
+def test_rerank() -> None:
     is_ci = os.getenv("CI")
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
 
-    model = TextCrossEncoder(model_name=model_name)
+    all_models = TextCrossEncoder._list_supported_models()
+    models_to_test = [all_models[0]] if not is_manual else all_models
 
-    query = "What is the capital of France?"
-    documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
-    scores = np.array(list(model.rerank(query, documents)))
+    for model_desc in models_to_test:
+        if (
+            not is_ci and model_desc.size_in_GB > 1
+        ) or model_desc.model not in CANONICAL_SCORE_VALUES:
+            continue
 
-    pairs = [(query, doc) for doc in documents]
-    scores2 = np.array(list(model.rerank_pairs(pairs)))
-    assert np.allclose(
-        scores, scores2, atol=1e-5
-    ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}"
+        model_name = model_desc.model
+        model = TextCrossEncoder(model_name=model_name)
 
-    canonical_scores = CANONICAL_SCORE_VALUES[model_name]
-    assert np.allclose(
-        scores, canonical_scores, atol=1e-3
-    ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}"
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        query = "What is the capital of France?"
+        documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
+        scores = np.array(list(model.rerank(query, documents)))
+
+        pairs = [(query, doc) for doc in documents]
+        scores2 = np.array(list(model.rerank_pairs(pairs)))
+        assert np.allclose(
+            scores, scores2, atol=1e-5
+        ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}"
+
+        canonical_scores = CANONICAL_SCORE_VALUES[model_name]
+        assert np.allclose(
+            scores, canonical_scores, atol=1e-3
+        ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}"
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
 
 
 @pytest.mark.parametrize(
     "model_name",
-    [model_name for model_name in SELECTED_MODELS.values()],
+    ["Xenova/ms-marco-MiniLM-L-6-v2"],
 )
 def test_batch_rerank(model_name: str) -> None:
     is_ci = os.getenv("CI")
@@ -97,7 +97,7 @@ def test_lazy_load(model_name: str) -> None:
 
 @pytest.mark.parametrize(
     "model_name",
-    [model_name for model_name in SELECTED_MODELS.values()],
+    ["Xenova/ms-marco-MiniLM-L-6-v2"],
 )
 def test_rerank_pairs_parallel(model_name: str) -> None:
     is_ci = os.getenv("CI")
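To reproduce the full all-models sweep locally, the manual trigger can be emulated by setting the same environment variable GitHub Actions would set. The snippet below is an illustrative sketch, not the repository's actual CI entry point; only GITHUB_EVENT_NAME and the test path are taken from the diff.

    import os

    import pytest

    # GITHUB_EVENT_NAME is set by GitHub Actions; "workflow_dispatch" marks
    # a manual run, which these tests use to enable the all-models sweep.
    os.environ["GITHUB_EVENT_NAME"] = "workflow_dispatch"

    # Run one of the updated test modules in-process.
    pytest.main(["tests/test_text_cross_encoder.py", "-q"])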