diff --git a/efaar_benchmarking/benchmarking.py b/efaar_benchmarking/benchmarking.py index 9df85b9..e12d612 100644 --- a/efaar_benchmarking/benchmarking.py +++ b/efaar_benchmarking/benchmarking.py @@ -1,11 +1,14 @@ +from dataclasses import dataclass +from enum import Enum from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple import numpy as np import pandas as pd from geomloss import SamplesLoss from joblib import Parallel, delayed from scipy.stats import hypergeom, ks_2samp -from sklearn.metrics import average_precision_score, precision_recall_curve +from sklearn.metrics import roc_auc_score from sklearn.metrics.pairwise import cosine_similarity from sklearn.utils import Bunch from torch import from_numpy @@ -13,6 +16,28 @@ import efaar_benchmarking.constants as cst +class AverageType(Enum): + MICRO = "micro" + MACRO = "macro" + + +class AggregateBy(Enum): + COMPOUND = "compound" + GENE = "gene" + + +@dataclass +class BenchmarkConfig: + """Configuration for benchmark computation.""" + + average_type: AverageType = AverageType.MACRO + aggregate_by: AggregateBy = AggregateBy.COMPOUND + min_negatives: int = 20 + n_baseline_sims: int = 100 + random_seed: int = 42 + quantiles: Optional[List[float]] = None # New parameter for quantiles + + def pert_signal_consistency_metric( arr: np.ndarray, sorted_null: np.ndarray = np.array([]) ) -> float | None | tuple[float | None, float | None]: @@ -547,56 +572,264 @@ def cosine_similarity_from_map( return compound_values.dot(gene_values) / (np.linalg.norm(compound_values) * np.linalg.norm(gene_values)) -def compound_gene_benchmark( - map_data: pd.DataFrame, - nM_activity_threshold: float = 1000, - pert_col: str = "perturbation", - benchmark_data_dir: str = cst.BENCHMARK_DATA_DIR, -) -> tuple[pd.DataFrame, dict]: - """Compute benchmarks for compound-gene pairs. +def load_truth_data(benchmark_data_dir: str) -> pd.DataFrame: + """Load the ground truth data from a CSV file.""" + truth_data_path = Path(benchmark_data_dir) / "compound_gene_interactions.csv" + return pd.read_csv(truth_data_path) - Args: - map_data (pd.DataFrame): DataFrame containing the embeddings and metadata. - nM_activity_threshold (float): Concentration threshold to use for the benchmark in nM. - Ground truth relationships below this threshold will be considered as positives. Default is 1000 nM. - pert_col (str): Column name containing the perturbation information. Default is "perturbation". - benchmark_data_dir (str): Directory to save the benchmark data. Default is cst.BENCHMARK_DATA_DIR. - Returns: - tuple[pd.DataFrame, dict]: A tuple containing: - - A DataFrame containing the average precision scores for different concentrations and the maximum value. - - A dictionary containing the precision-recall curves for different concentrations and the maximum value. 
+def compute_similarities( + truth: pd.DataFrame, + map_data: Bunch, + pert_col: str, + randomize: bool = False, +) -> pd.DataFrame: + """Compute cosine similarities between compounds and genes.""" + treatments = truth["treatment"].unique() + genes = truth["gene_symbol"].unique() - """ + compound_meta = map_data.metadata[map_data.metadata[pert_col].isin(treatments)].copy() + gene_meta = map_data.metadata[map_data.metadata[pert_col].isin(genes)].copy() - truth = pd.read_csv(Path(benchmark_data_dir).joinpath("compound_gene_interactions.csv")) - truth["active"] = truth["nM_value"] <= nM_activity_threshold + if compound_meta.empty or gene_meta.empty: + raise ValueError("No matching compounds or genes found in metadata.") - for conc in cst.COMPOUND_CONCENTRATIONS: - truth[f"cosine_similarity_{conc}"] = truth.apply( - lambda x: cosine_similarity_from_map(x["treatment"], x["gene_symbol"], conc, map_data, pert_col), axis=1 - ) - truth[f"cosine_similarity_{conc}"] = truth[f"cosine_similarity_{conc}"].apply( - lambda x: abs(x) if x is not None else x - ) - truth["cosine_similarity_max"] = truth[[col for col in truth.columns if col.startswith("cosine_similarity_")]].max( - axis=1, skipna=True + if randomize: + rng = np.random.default_rng(cst.RANDOM_SEED) + similarities = rng.uniform(0, 1, size=(len(compound_meta), len(gene_meta))) + else: + compound_features = map_data.features.loc[compound_meta.index] + gene_features = map_data.features.loc[gene_meta.index] + similarities = np.abs(cosine_similarity(compound_features, gene_features)) + + index = pd.MultiIndex.from_arrays( + [compound_meta[pert_col].values, compound_meta["concentration"].values], + names=[pert_col, "concentration"], ) - curves, aps = {}, {} - for conc_str in cst.COMPOUND_CONCENTRATIONS + ["max"]: - cos_sim = truth[f"cosine_similarity_{conc_str}"] - to_keep = ~cos_sim.isna() - cos_sim = cos_sim[to_keep] - labels = truth["active"][to_keep] + return pd.DataFrame(similarities, index=index, columns=gene_meta[pert_col].values) + + +def compute_baseline_predictions( + predictions: Dict[str, List[Tuple[np.ndarray, np.ndarray]]], + config: BenchmarkConfig, +) -> Dict[str, List[Tuple[np.ndarray, np.ndarray]]]: + """Generate baseline predictions using random scores for the same labels.""" + rng = np.random.default_rng(config.random_seed) + baseline_predictions = {} + for conc, preds in predictions.items(): + baseline_preds = [] + for _, labels in preds: + scores = rng.random(len(labels)) + baseline_preds.append((scores, labels)) + baseline_predictions[conc] = baseline_preds + return baseline_predictions + + +def aggregate_predictions( + predictions: Dict[str, List[Tuple[np.ndarray, np.ndarray]]], + config: BenchmarkConfig, +) -> Dict[str, Dict[str, float]]: + """Aggregate predictions across compounds or genes.""" + results = {} + for conc, preds in predictions.items(): + if not preds: + # Initialize result dictionary with zeros and default quantiles + result = {"average_precision": 0.0, "auc_roc": 0.5} + if config.quantiles: + for q in config.quantiles: + result[f"ap_quantile_{q}"] = 0.0 + results[conc] = result + continue + + if config.average_type == AverageType.MICRO: + scores = np.concatenate([p[0] for p in preds]) + labels = np.concatenate([p[1] for p in preds]) + ap, auc = compute_metrics(scores, labels) + result = {"average_precision": ap, "auc_roc": auc} + if config.quantiles: + # For micro averaging, quantiles are not applicable; set to overall AP + for q in config.quantiles: + result[f"ap_quantile_{q}"] = ap + else: # MACRO + aps = [] 
+ aucs = [] + for scores, labels in preds: + if len(scores) == 0 or not np.any(labels): + continue + ap, auc = compute_metrics(scores, labels) + aps.append(ap) + aucs.append(auc) + if aps: + mean_ap = np.mean(aps) + mean_auc = np.mean(aucs) + result = {"average_precision": mean_ap, "auc_roc": mean_auc} + if config.quantiles: + # Compute quantiles + for q in config.quantiles: + quantile_value = np.quantile(aps, q) + result[f"ap_quantile_{q}"] = quantile_value + else: + result = {"average_precision": 0.0, "auc_roc": 0.5} + if config.quantiles: + for q in config.quantiles: + result[f"ap_quantile_{q}"] = 0.0 + results[conc] = result + return results + + +def compute_metrics(scores: np.ndarray, labels: np.ndarray) -> Tuple[float, float]: + """Compute average precision and AUC-ROC.""" + if len(scores) == 0 or not np.any(labels): + return 0.0, 0.5 + + sorted_indices = np.argsort(scores)[::-1] + sorted_labels = labels[sorted_indices] + tp_cumsum = np.cumsum(sorted_labels) + precision = tp_cumsum / np.arange(1, len(sorted_labels) + 1) + ap = np.sum(precision * sorted_labels) / sorted_labels.sum() + auc = roc_auc_score(labels, scores) + return ap, auc + + +def sample_for_item( + item_data: pd.DataFrame, + pool: Set[str], + activity_threshold: float, + inactivity_threshold: float, + target_col: str, + min_negatives: int = 20, + random_seed: int = 42, +) -> Tuple[np.ndarray, np.ndarray]: + """Generic sampling function for both compounds and genes.""" + rng = np.random.default_rng(random_seed) + + actives = item_data.loc[item_data["nM_value"] <= activity_threshold, target_col].unique() + if len(actives) == 0: + return np.array([]), np.array([]) + + ineligibles = item_data.loc[ + (item_data["nM_value"] > activity_threshold) & (item_data["nM_value"] <= inactivity_threshold), target_col + ].unique() + + eligibles = pool - set(actives) - set(ineligibles) + n_negatives = max(2 * len(actives), min_negatives) + + if len(eligibles) < n_negatives: + return np.array([]), np.array([]) + + negatives = rng.choice(list(eligibles), n_negatives, replace=False) + items = np.concatenate([actives, negatives]) + labels = np.isin(items, actives).astype(int) + + return items, labels + + +def process_predictions( + data: pd.DataFrame, + similarities: pd.DataFrame, + config: BenchmarkConfig, + thresholds: Tuple[float, float], + pert_col: str = "perturbation", +) -> Dict[str, List[Tuple[np.ndarray, np.ndarray]]]: + """Process predictions for either compounds or genes.""" + activity_threshold, inactivity_threshold = thresholds + predictions = {conc: [] for conc in cst.COMPOUND_CONCENTRATIONS + ["max"]} + + if config.aggregate_by == AggregateBy.COMPOUND: + pool = set(data["gene_symbol"].unique()) + item_col, target_col = "treatment", "gene_symbol" + else: + pool = set(data["treatment"].unique()) + item_col, target_col = "gene_symbol", "treatment" + + for item in data[item_col].unique(): + if config.aggregate_by == AggregateBy.COMPOUND and item not in similarities.index.get_level_values(pert_col): + continue + + item_data = data[data[item_col] == item] + targets, labels = sample_for_item( + item_data, + pool, + activity_threshold, + inactivity_threshold, + target_col, + config.min_negatives, + config.random_seed, + ) + + if len(targets) == 0: + continue + + scores_by_conc = {} + if config.aggregate_by == AggregateBy.COMPOUND: + item_similarities = similarities.loc[item] + for conc in item_similarities.index.unique(): + scores = item_similarities.loc[conc, targets].values + if not np.all(np.isnan(scores)): + 
predictions[conc].append((scores, labels)) + scores_by_conc[conc] = scores + if scores_by_conc: + max_scores = np.nanmax(np.vstack(list(scores_by_conc.values())), axis=0) + if not np.all(np.isnan(max_scores)): + predictions["max"].append((max_scores, labels)) + else: + for conc in cst.COMPOUND_CONCENTRATIONS: + try: + sim_conc = similarities.xs(conc, level="concentration") + available = sim_conc.index.intersection(targets) + if not available.empty: + sim_conc_item = sim_conc.loc[available, item] + scores = sim_conc_item.values + labels_filtered = labels[np.isin(targets, available)] + if not np.all(np.isnan(scores)): + predictions[conc].append((scores, labels_filtered)) + scores_by_conc[conc] = pd.Series(scores, index=available) + except KeyError: + continue + if scores_by_conc: + available = set().union(*[scores_by_conc[conc].index for conc in scores_by_conc]) + if available: + available = list(available) + scores_df = pd.DataFrame({conc: scores_by_conc[conc] for conc in scores_by_conc}, index=available) + max_scores = scores_df.max(axis=1).values + labels_filtered = labels[np.isin(targets, available)] + if not np.all(np.isnan(max_scores)): + predictions["max"].append((max_scores, labels_filtered)) + + return predictions + + +def compound_gene_benchmark( + map_data: Bunch, + activity_threshold: float = 1000, + inactivity_threshold: float = 10000, + pert_col: str = "perturbation", + benchmark_data_dir: str = cst.BENCHMARK_DATA_DIR, + truth_data: Optional[pd.DataFrame] = None, + check_random: bool = False, + config: Optional[BenchmarkConfig] = None, +) -> pd.DataFrame: + """Main benchmark function.""" + config = config or BenchmarkConfig() + truth = truth_data if truth_data is not None else load_truth_data(benchmark_data_dir) + similarities = compute_similarities(truth, map_data, pert_col, randomize=check_random) + + thresholds = (activity_threshold, inactivity_threshold) + predictions = process_predictions(truth, similarities, config, thresholds, pert_col) + baseline_preds = compute_baseline_predictions(predictions, config) + + results_dict = aggregate_predictions(predictions, config) + baseline_dict = aggregate_predictions(baseline_preds, config) - try: - precision, recall, _ = precision_recall_curve(labels, cos_sim) - curves[f"{conc_str}"] = (precision, recall) - aps[f"{conc_str}"] = average_precision_score(labels, cos_sim) - except ValueError: - pass + # Convert results to DataFrame + results = pd.DataFrame.from_dict(results_dict, orient="index").reset_index() + results.rename(columns={"index": "concentration"}, inplace=True) - aps["baseline"] = truth["active"].mean() + # Convert baseline results to DataFrame + baseline = pd.DataFrame.from_dict(baseline_dict, orient="index").reset_index() + baseline.rename(columns={"index": "concentration"}, inplace=True) + # Merge baseline metrics with results + results = results.merge(baseline, on="concentration", suffixes=("", "_baseline")) - return pd.DataFrame(aps.items(), columns=["concentration", "average_precision"]), curves + return results diff --git a/tests/test_benchmarking.py b/tests/test_benchmarking.py index bae672e..53d4e0f 100644 --- a/tests/test_benchmarking.py +++ b/tests/test_benchmarking.py @@ -3,19 +3,50 @@ import numpy as np import pandas as pd import pytest +from sklearn.utils import Bunch from efaar_benchmarking import benchmarking, constants +from efaar_benchmarking.benchmarking import ( + AggregateBy, + AverageType, + BenchmarkConfig, + compound_gene_benchmark, + compute_metrics, + compute_similarities, + process_predictions, + 
sample_for_item, +) + + +@pytest.fixture +def sample_truth_data(): + """Create sample ground truth data with known properties.""" + return pd.DataFrame( + { + "treatment": ["compound1", "compound1", "compound2", "compound2"], + "gene_symbol": ["gene1", "gene2", "gene1", "gene2"], + "nM_value": [100, 5000, 15000, 500], # Mix of active, gray zone, and inactive + } + ) @pytest.fixture def sample_map_data(): - data = { - "perturbation": ["compound1", "gene1", "compound2", "gene2"], - "concentration": [10.0, np.nan, 1.0, np.nan], - "feature_1": [0.1, 0.2, 0.3, 0.4], - "feature_2": [0.5, 0.6, 0.7, 0.8], - } - return pd.DataFrame(data) + """Create sample embedding data with known similarities.""" + features = pd.DataFrame( + {"feat1": [1.0, 0.0, 0.0, 1.0], "feat2": [0.0, 1.0, 1.0, 0.0]}, + index=["compound1_id", "compound2_id", "gene1_id", "gene2_id"], + ) + + metadata = pd.DataFrame( + { + "perturbation": ["compound1", "compound2", "gene1", "gene2"], + "concentration": ["1.0", "1.0", np.nan, np.nan], + }, + index=["compound1_id", "compound2_id", "gene1_id", "gene2_id"], + ) + + return Bunch(features=features, metadata=metadata) @pytest.fixture @@ -114,3 +145,174 @@ def test_compound_gene_benchmark(mock_read_csv, sample_map_data): assert "1.0" in curves assert "max" in curves assert isinstance(curves["max"], tuple) + + +def test_compute_similarities_basic(sample_map_data): + """Test basic similarity computation.""" + truth = pd.DataFrame({"treatment": ["compound1", "compound2"], "gene_symbol": ["gene1", "gene2"]}) + + sims = compute_similarities(truth, sample_map_data, "perturbation") + + assert isinstance(sims, pd.DataFrame) + assert sims.shape == (2, 2) # 2 compounds x 2 genes + # compound1 should be perfectly similar to gene1 (same features) + assert np.isclose(sims.loc[("compound1", "1.0"), "gene1"], 1.0) + + +def test_compute_similarities_randomized(sample_map_data): + """Test randomized similarity computation.""" + truth = pd.DataFrame({"treatment": ["compound1"], "gene_symbol": ["gene1"]}) + + sims = compute_similarities(truth, sample_map_data, "perturbation", randomize=True) + + assert isinstance(sims, pd.DataFrame) + assert sims.shape == (1, 1) + assert 0 <= sims.iloc[0, 0] <= 1 + + +def test_sample_for_item(): + """Test negative sampling logic.""" + item_data = pd.DataFrame({"gene_symbol": ["gene1", "gene2", "gene3"], "nM_value": [100, 5000, 15000]}) + pool = {"gene1", "gene2", "gene3", "gene4", "gene5"} + + items, labels = sample_for_item( + item_data, + pool, + activity_threshold=1000, + inactivity_threshold=10000, + target_col="gene_symbol", + min_negatives=2, + random_seed=42, + ) + + assert len(items) > 0 + assert len(labels) == len(items) + assert sum(labels) == 1 # Only gene1 should be positive + assert "gene3" not in items # Should be excluded as it's in gray zone + + +def test_compute_metrics(): + """Test metric computation with known values.""" + scores = np.array([0.9, 0.8, 0.3, 0.2]) + labels = np.array([1, 0, 0, 1]) + + ap, auc = compute_metrics(scores, labels) + + assert 0 <= ap <= 1 + assert 0 <= auc <= 1 + # With these specific values, AP should be less than 0.75 + assert ap < 0.75 + + +def test_full_benchmark_macro_compound(sample_truth_data, sample_map_data): + """Test full benchmark with macro averaging by compound.""" + config = BenchmarkConfig( + average_type=AverageType.MACRO, + aggregate_by=AggregateBy.COMPOUND, + min_negatives=2, # Using min_negatives + random_seed=42, + quantiles=[0.25, 0.5, 0.75], # Include quantiles if applicable + ) + + results = 
compound_gene_benchmark( + map_data=sample_map_data, + activity_threshold=1000, + inactivity_threshold=10000, + truth_data=sample_truth_data, + config=config, + ) + + assert isinstance(results, pd.DataFrame) + assert "concentration" in results.columns + assert "average_precision" in results.columns + assert "auc_roc" in results.columns + assert "baseline_average_precision" in results.columns + assert "baseline_auc_roc" in results.columns + # If quantiles are included + if config.quantiles: + for q in config.quantiles: + assert f"ap_quantile_{q}" in results.columns + assert len(results) > 0 + assert "max" in results["concentration"].values + + +def test_full_benchmark_micro_gene(sample_truth_data, sample_map_data): + """Test full benchmark with micro averaging by gene.""" + config = BenchmarkConfig( + average_type=AverageType.MICRO, + aggregate_by=AggregateBy.GENE, + min_negatives=2, # Using min_negatives + random_seed=42, + ) + + results = compound_gene_benchmark( + map_data=sample_map_data, + activity_threshold=1000, + inactivity_threshold=10000, + truth_data=sample_truth_data, + config=config, + ) + + assert isinstance(results, pd.DataFrame) + assert all(results["baseline_auc_roc"] == 0.5) # Random baseline should have 0.5 AUC-ROC + + +def test_benchmark_edge_cases(sample_map_data): + """Test benchmark behavior with edge cases.""" + # Empty truth data + empty_truth = pd.DataFrame(columns=["treatment", "gene_symbol", "nM_value"]) + config = BenchmarkConfig(random_seed=42) + with pytest.raises(ValueError): + compound_gene_benchmark(map_data=sample_map_data, truth_data=empty_truth, config=config) + + # All inactive data + all_inactive = pd.DataFrame( + { + "treatment": ["compound1"], + "gene_symbol": ["gene1"], + "nM_value": [20000], # Above inactivity threshold + } + ) + results = compound_gene_benchmark(map_data=sample_map_data, truth_data=all_inactive, config=config) + assert len(results) > 0 + assert all(results["average_precision"] == 0.0) # No positives should give 0 AP + + +def test_process_predictions(sample_truth_data, sample_map_data): + """Test prediction processing logic.""" + config = BenchmarkConfig( + min_negatives=2, + random_seed=42, + ) + similarities = compute_similarities(sample_truth_data, sample_map_data, "perturbation") + + predictions = process_predictions( + sample_truth_data, similarities, config, thresholds=(1000, 10000), pert_col="perturbation" + ) + + assert isinstance(predictions, dict) + assert "max" in predictions + assert all(isinstance(p, list) for p in predictions.values()) + + # Check prediction format + for conc, preds in predictions.items(): + for scores, labels in preds: + assert isinstance(scores, np.ndarray) + assert isinstance(labels, np.ndarray) + assert len(scores) == len(labels) + assert set(labels).issubset({0, 1}) + + +def test_config_validation(): + """Test configuration validation.""" + # Invalid average type + with pytest.raises(ValueError): + BenchmarkConfig(average_type="invalid") + + # Invalid aggregate by + with pytest.raises(ValueError): + BenchmarkConfig(aggregate_by="invalid") + + # Invalid min_negatives + with pytest.raises(ValueError): + BenchmarkConfig(min_negatives=-1)
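
A note on BenchmarkConfig: the new test_config_validation expects BenchmarkConfig(average_type="invalid"), BenchmarkConfig(aggregate_by="invalid"), and BenchmarkConfig(min_negatives=-1) to raise ValueError, which a bare dataclass will not do. Below is a minimal __post_init__ sketch (not part of the diff) that would provide that behaviour; coercing raw strings into the enums is an assumption about the intended API.

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class AverageType(Enum):
    MICRO = "micro"
    MACRO = "macro"


class AggregateBy(Enum):
    COMPOUND = "compound"
    GENE = "gene"


@dataclass
class BenchmarkConfig:
    """Configuration for benchmark computation (validation sketch)."""

    average_type: AverageType = AverageType.MACRO
    aggregate_by: AggregateBy = AggregateBy.COMPOUND
    min_negatives: int = 20
    n_baseline_sims: int = 100
    random_seed: int = 42
    quantiles: Optional[List[float]] = None

    def __post_init__(self):
        # Enum(value) returns the member unchanged and raises ValueError for
        # unknown values such as "invalid", which is what the new
        # test_config_validation checks for average_type and aggregate_by.
        self.average_type = AverageType(self.average_type)
        self.aggregate_by = AggregateBy(self.aggregate_by)
        # Rejecting negative sampling sizes covers the min_negatives=-1 case.
        if self.min_negatives < 0:
            raise ValueError("min_negatives must be non-negative")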
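
On the baseline columns: compound_gene_benchmark merges the baseline frame with suffixes=("", "_baseline"), which yields average_precision_baseline and auc_roc_baseline, while the new tests look up baseline_average_precision and baseline_auc_roc. A small pandas sketch of one way to reconcile the two, prefixing the baseline metric columns before the merge; the prefix naming is an assumption about what the tests intend.

import pandas as pd

# Illustrative stand-ins for the per-concentration results and the random
# baseline built inside compound_gene_benchmark; the numbers are made up.
results = pd.DataFrame(
    {"concentration": ["1.0", "max"], "average_precision": [0.7, 0.8], "auc_roc": [0.65, 0.70]}
)
baseline = pd.DataFrame(
    {"concentration": ["1.0", "max"], "average_precision": [0.3, 0.3], "auc_roc": [0.5, 0.5]}
)

# Prefix every baseline metric column, then merge on concentration.
baseline = baseline.rename(
    columns={c: f"baseline_{c}" for c in baseline.columns if c != "concentration"}
)
merged = results.merge(baseline, on="concentration")
print(merged.columns.tolist())
# ['concentration', 'average_precision', 'auc_roc',
#  'baseline_average_precision', 'baseline_auc_roc']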
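
On compute_metrics: the hand-rolled average precision (precision at the rank of each positive, averaged over the positives) agrees with sklearn's average_precision_score as long as scores are untied; with tied scores the two can differ slightly because sklearn collapses ties into a single threshold. A quick sanity check with the same values used in test_compute_metrics; note that AP comes out to exactly 0.75 here, so that test's strict assertion ap < 0.75 would need to be non-strict.

import numpy as np
from sklearn.metrics import average_precision_score

from efaar_benchmarking.benchmarking import compute_metrics

scores = np.array([0.9, 0.8, 0.3, 0.2])
labels = np.array([1, 0, 0, 1])

ap, auc = compute_metrics(scores, labels)
# The positives sit at ranks 1 and 4, so AP = (1/1 + 2/4) / 2 = 0.75.
assert np.isclose(ap, 0.75)
assert np.isclose(ap, average_precision_score(labels, scores))
# One positive is ranked first and the other last, so AUC-ROC is 0.5.
assert np.isclose(auc, 0.5)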
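
Finally, an end-to-end usage sketch of the reworked compound_gene_benchmark. All compound and gene names, feature values, and nM values below are made up, and the sketch assumes the "1.0" concentration string appears in cst.COMPOUND_CONCENTRATIONS, as the new test fixtures also do. Passing truth_data skips the compound_gene_interactions.csv lookup under benchmark_data_dir, and check_random=True would replace the cosine similarities with uniform random scores.

import numpy as np
import pandas as pd
from sklearn.utils import Bunch

from efaar_benchmarking.benchmarking import (
    AggregateBy,
    AverageType,
    BenchmarkConfig,
    compound_gene_benchmark,
)

# Embeddings and metadata share an index; compounds carry a concentration,
# genes do not. compoundA is aligned with geneA and compoundB with geneB.
features = pd.DataFrame(
    {
        "feat1": [1.0, 0.0, 1.0, 0.0, 0.5, 0.3],
        "feat2": [0.0, 1.0, 0.0, 1.0, 0.5, 0.7],
    },
    index=["compoundA_id", "compoundB_id", "geneA_id", "geneB_id", "geneC_id", "geneD_id"],
)
metadata = pd.DataFrame(
    {
        "perturbation": ["compoundA", "compoundB", "geneA", "geneB", "geneC", "geneD"],
        "concentration": ["1.0", "1.0", np.nan, np.nan, np.nan, np.nan],
    },
    index=features.index,
)
map_data = Bunch(features=features, metadata=metadata)

# Ground truth: nM_value <= activity_threshold marks a positive pair; values
# between the two thresholds fall in a gray zone that is never sampled as a
# negative.
truth = pd.DataFrame(
    {
        "treatment": ["compoundA", "compoundA", "compoundB", "compoundB", "compoundB", "compoundB"],
        "gene_symbol": ["geneA", "geneB", "geneB", "geneA", "geneC", "geneD"],
        "nM_value": [100, 5000, 500, 15000, 50000, 50000],
    }
)

config = BenchmarkConfig(
    average_type=AverageType.MACRO,
    aggregate_by=AggregateBy.COMPOUND,
    min_negatives=2,
    quantiles=[0.5],
)
results = compound_gene_benchmark(
    map_data=map_data,
    truth_data=truth,
    activity_threshold=1000,
    inactivity_threshold=10000,
    config=config,
)
print(results[["concentration", "average_precision", "auc_roc"]])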