Implement alternative compound-gene benchmarking (#81)
* slow v1

* sampling v2

* fixes

* flexibility, v3

* fix

* tests

* move tests properly

* fix baseline

* add quantiles

* update tests

* newline

* format
johnurbanik authored Nov 8, 2024
1 parent 532efe7 commit 67870a2
Showing 2 changed files with 486 additions and 51 deletions.
321 changes: 277 additions & 44 deletions efaar_benchmarking/benchmarking.py
@@ -1,18 +1,43 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
from geomloss import SamplesLoss
from joblib import Parallel, delayed
from scipy.stats import hypergeom, ks_2samp
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.utils import Bunch
from torch import from_numpy

import efaar_benchmarking.constants as cst


class AverageType(Enum):
    MICRO = "micro"
    MACRO = "macro"


class AggregateBy(Enum):
    COMPOUND = "compound"
    GENE = "gene"


@dataclass
class BenchmarkConfig:
    """Configuration for benchmark computation."""

    average_type: AverageType = AverageType.MACRO
    aggregate_by: AggregateBy = AggregateBy.COMPOUND
    min_negatives: int = 20
    n_baseline_sims: int = 100
    random_seed: int = 42
    quantiles: Optional[List[float]] = None  # New parameter for quantiles


def pert_signal_consistency_metric(
    arr: np.ndarray, sorted_null: np.ndarray = np.array([])
) -> float | None | tuple[float | None, float | None]:
@@ -547,56 +572,264 @@ def cosine_similarity_from_map(
    return compound_values.dot(gene_values) / (np.linalg.norm(compound_values) * np.linalg.norm(gene_values))


-def compound_gene_benchmark(
-    map_data: pd.DataFrame,
-    nM_activity_threshold: float = 1000,
-    pert_col: str = "perturbation",
-    benchmark_data_dir: str = cst.BENCHMARK_DATA_DIR,
-) -> tuple[pd.DataFrame, dict]:
-    """Compute benchmarks for compound-gene pairs.
def load_truth_data(benchmark_data_dir: str) -> pd.DataFrame:
    """Load the ground truth data from a CSV file."""
    truth_data_path = Path(benchmark_data_dir) / "compound_gene_interactions.csv"
    return pd.read_csv(truth_data_path)

-    Args:
-        map_data (pd.DataFrame): DataFrame containing the embeddings and metadata.
-        nM_activity_threshold (float): Concentration threshold to use for the benchmark in nM.
-            Ground truth relationships below this threshold will be considered as positives. Default is 1000 nM.
-        pert_col (str): Column name containing the perturbation information. Default is "perturbation".
-        benchmark_data_dir (str): Directory to save the benchmark data. Default is cst.BENCHMARK_DATA_DIR.
-
-    Returns:
-        tuple[pd.DataFrame, dict]: A tuple containing:
-            - A DataFrame containing the average precision scores for different concentrations and the maximum value.
-            - A dictionary containing the precision-recall curves for different concentrations and the maximum value.
def compute_similarities(
    truth: pd.DataFrame,
    map_data: Bunch,
    pert_col: str,
    randomize: bool = False,
) -> pd.DataFrame:
    """Compute cosine similarities between compounds and genes."""
    treatments = truth["treatment"].unique()
    genes = truth["gene_symbol"].unique()

-    """
    compound_meta = map_data.metadata[map_data.metadata[pert_col].isin(treatments)].copy()
    gene_meta = map_data.metadata[map_data.metadata[pert_col].isin(genes)].copy()

-    truth = pd.read_csv(Path(benchmark_data_dir).joinpath("compound_gene_interactions.csv"))
-    truth["active"] = truth["nM_value"] <= nM_activity_threshold
    if compound_meta.empty or gene_meta.empty:
        raise ValueError("No matching compounds or genes found in metadata.")

-    for conc in cst.COMPOUND_CONCENTRATIONS:
-        truth[f"cosine_similarity_{conc}"] = truth.apply(
-            lambda x: cosine_similarity_from_map(x["treatment"], x["gene_symbol"], conc, map_data, pert_col), axis=1
-        )
-        truth[f"cosine_similarity_{conc}"] = truth[f"cosine_similarity_{conc}"].apply(
-            lambda x: abs(x) if x is not None else x
-        )
-    truth["cosine_similarity_max"] = truth[[col for col in truth.columns if col.startswith("cosine_similarity_")]].max(
-        axis=1, skipna=True
    if randomize:
        rng = np.random.default_rng(cst.RANDOM_SEED)
        similarities = rng.uniform(0, 1, size=(len(compound_meta), len(gene_meta)))
    else:
        compound_features = map_data.features.loc[compound_meta.index]
        gene_features = map_data.features.loc[gene_meta.index]
        similarities = np.abs(cosine_similarity(compound_features, gene_features))

    index = pd.MultiIndex.from_arrays(
        [compound_meta[pert_col].values, compound_meta["concentration"].values],
        names=[pert_col, "concentration"],
    )

-    curves, aps = {}, {}
-    for conc_str in cst.COMPOUND_CONCENTRATIONS + ["max"]:
-        cos_sim = truth[f"cosine_similarity_{conc_str}"]
-        to_keep = ~cos_sim.isna()
-        cos_sim = cos_sim[to_keep]
-        labels = truth["active"][to_keep]
    return pd.DataFrame(similarities, index=index, columns=gene_meta[pert_col].values)


def compute_baseline_predictions(
    predictions: Dict[str, List[Tuple[np.ndarray, np.ndarray]]],
    config: BenchmarkConfig,
) -> Dict[str, List[Tuple[np.ndarray, np.ndarray]]]:
    """Generate baseline predictions using random scores for the same labels."""
    rng = np.random.default_rng(config.random_seed)
    baseline_predictions = {}
    for conc, preds in predictions.items():
        baseline_preds = []
        for _, labels in preds:
            scores = rng.random(len(labels))
            baseline_preds.append((scores, labels))
        baseline_predictions[conc] = baseline_preds
    return baseline_predictions


def aggregate_predictions(
    predictions: Dict[str, List[Tuple[np.ndarray, np.ndarray]]],
    config: BenchmarkConfig,
) -> Dict[str, Dict[str, float]]:
    """Aggregate predictions across compounds or genes."""
    results = {}
    for conc, preds in predictions.items():
        if not preds:
            # Initialize result dictionary with zeros and default quantiles
            result = {"average_precision": 0.0, "auc_roc": 0.5}
            if config.quantiles:
                for q in config.quantiles:
                    result[f"ap_quantile_{q}"] = 0.0
            results[conc] = result
            continue

        if config.average_type == AverageType.MICRO:
            scores = np.concatenate([p[0] for p in preds])
            labels = np.concatenate([p[1] for p in preds])
            ap, auc = compute_metrics(scores, labels)
            result = {"average_precision": ap, "auc_roc": auc}
            if config.quantiles:
                # For micro averaging, quantiles are not applicable; set to overall AP
                for q in config.quantiles:
                    result[f"ap_quantile_{q}"] = ap
        else:  # MACRO
            aps = []
            aucs = []
            for scores, labels in preds:
                if len(scores) == 0 or not np.any(labels):
                    continue
                ap, auc = compute_metrics(scores, labels)
                aps.append(ap)
                aucs.append(auc)
            if aps:
                mean_ap = np.mean(aps)
                mean_auc = np.mean(aucs)
                result = {"average_precision": mean_ap, "auc_roc": mean_auc}
                if config.quantiles:
                    # Compute quantiles
                    for q in config.quantiles:
                        quantile_value = np.quantile(aps, q)
                        result[f"ap_quantile_{q}"] = quantile_value
            else:
                result = {"average_precision": 0.0, "auc_roc": 0.5}
                if config.quantiles:
                    for q in config.quantiles:
                        result[f"ap_quantile_{q}"] = 0.0
        results[conc] = result
    return results


def compute_metrics(scores: np.ndarray, labels: np.ndarray) -> Tuple[float, float]:
    """Compute average precision and AUC-ROC."""
    if len(scores) == 0 or not np.any(labels):
        return 0.0, 0.5

    sorted_indices = np.argsort(scores)[::-1]
    sorted_labels = labels[sorted_indices]
    tp_cumsum = np.cumsum(sorted_labels)
    precision = tp_cumsum / np.arange(1, len(sorted_labels) + 1)
    ap = np.sum(precision * sorted_labels) / sorted_labels.sum()
    auc = roc_auc_score(labels, scores)
    return ap, auc
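# Worked example for compute_metrics (illustrative only): with scores = [0.9, 0.8, 0.7, 0.6]
# and labels = [1, 0, 1, 0], the positives land at ranks 1 and 3, so the precisions at the hits
# are 1/1 and 2/3 and ap = (1 + 2/3) / 2 ≈ 0.833, while roc_auc_score(labels, scores) is 0.75
# because three of the four positive/negative score pairs are ranked correctly.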


def sample_for_item(
    item_data: pd.DataFrame,
    pool: Set[str],
    activity_threshold: float,
    inactivity_threshold: float,
    target_col: str,
    min_negatives: int = 20,
    random_seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generic sampling function for both compounds and genes."""
    rng = np.random.default_rng(random_seed)

    actives = item_data.loc[item_data["nM_value"] <= activity_threshold, target_col].unique()
    if len(actives) == 0:
        return np.array([]), np.array([])

    ineligibles = item_data.loc[
        (item_data["nM_value"] > activity_threshold) & (item_data["nM_value"] <= inactivity_threshold), target_col
    ].unique()

    eligibles = pool - set(actives) - set(ineligibles)
    n_negatives = max(2 * len(actives), min_negatives)

    if len(eligibles) < n_negatives:
        return np.array([]), np.array([])

    negatives = rng.choice(list(eligibles), n_negatives, replace=False)
    items = np.concatenate([actives, negatives])
    labels = np.isin(items, actives).astype(int)

    return items, labels
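# Worked example for sample_for_item (illustrative only), using the benchmark defaults of
# activity_threshold=1000 and inactivity_threshold=10000: a target with nM_value=100 is an active,
# one with nM_value=5000 falls between the thresholds and is dropped from both actives and the
# negative pool, and one with nM_value=50000 stays eligible as a negative; max(2 * n_actives,
# min_negatives) negatives are then drawn at random from the remaining pool.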


def process_predictions(
    data: pd.DataFrame,
    similarities: pd.DataFrame,
    config: BenchmarkConfig,
    thresholds: Tuple[float, float],
    pert_col: str = "perturbation",
) -> Dict[str, List[Tuple[np.ndarray, np.ndarray]]]:
    """Process predictions for either compounds or genes."""
    activity_threshold, inactivity_threshold = thresholds
    predictions = {conc: [] for conc in cst.COMPOUND_CONCENTRATIONS + ["max"]}

    if config.aggregate_by == AggregateBy.COMPOUND:
        pool = set(data["gene_symbol"].unique())
        item_col, target_col = "treatment", "gene_symbol"
    else:
        pool = set(data["treatment"].unique())
        item_col, target_col = "gene_symbol", "treatment"

    for item in data[item_col].unique():
        if config.aggregate_by == AggregateBy.COMPOUND and item not in similarities.index.get_level_values(pert_col):
            continue

        item_data = data[data[item_col] == item]
        targets, labels = sample_for_item(
            item_data,
            pool,
            activity_threshold,
            inactivity_threshold,
            target_col,
            config.min_negatives,
            config.random_seed,
        )

        if len(targets) == 0:
            continue

        scores_by_conc = {}
        if config.aggregate_by == AggregateBy.COMPOUND:
            item_similarities = similarities.loc[item]
            for conc in item_similarities.index.unique():
                scores = item_similarities.loc[conc, targets].values
                if not np.all(np.isnan(scores)):
                    predictions[conc].append((scores, labels))
                    scores_by_conc[conc] = scores
            if scores_by_conc:
                max_scores = np.nanmax(np.vstack(list(scores_by_conc.values())), axis=0)
                if not np.all(np.isnan(max_scores)):
                    predictions["max"].append((max_scores, labels))
        else:
            for conc in cst.COMPOUND_CONCENTRATIONS:
                try:
                    sim_conc = similarities.xs(conc, level="concentration")
                    available = sim_conc.index.intersection(targets)
                    if not available.empty:
                        sim_conc_item = sim_conc.loc[available, item]
                        scores = sim_conc_item.values
                        labels_filtered = labels[np.isin(targets, available)]
                        if not np.all(np.isnan(scores)):
                            predictions[conc].append((scores, labels_filtered))
                            scores_by_conc[conc] = pd.Series(scores, index=available)
                except KeyError:
                    continue
            if scores_by_conc:
                available = set().union(*[scores_by_conc[conc].index for conc in scores_by_conc])
                if available:
                    available = list(available)
                    scores_df = pd.DataFrame({conc: scores_by_conc[conc] for conc in scores_by_conc}, index=available)
                    max_scores = scores_df.max(axis=1).values
                    labels_filtered = labels[np.isin(targets, available)]
                    if not np.all(np.isnan(max_scores)):
                        predictions["max"].append((max_scores, labels_filtered))

    return predictions


def compound_gene_benchmark(
    map_data: Bunch,
    activity_threshold: float = 1000,
    inactivity_threshold: float = 10000,
    pert_col: str = "perturbation",
    benchmark_data_dir: str = cst.BENCHMARK_DATA_DIR,
    truth_data: Optional[pd.DataFrame] = None,
    check_random: bool = False,
    config: Optional[BenchmarkConfig] = None,
) -> pd.DataFrame:
    """Main benchmark function."""
    config = config or BenchmarkConfig()
    truth = truth_data if truth_data is not None else load_truth_data(benchmark_data_dir)
    similarities = compute_similarities(truth, map_data, pert_col, randomize=check_random)

    thresholds = (activity_threshold, inactivity_threshold)
    predictions = process_predictions(truth, similarities, config, thresholds, pert_col)
    baseline_preds = compute_baseline_predictions(predictions, config)

    results_dict = aggregate_predictions(predictions, config)
    baseline_dict = aggregate_predictions(baseline_preds, config)

-        try:
-            precision, recall, _ = precision_recall_curve(labels, cos_sim)
-            curves[f"{conc_str}"] = (precision, recall)
-            aps[f"{conc_str}"] = average_precision_score(labels, cos_sim)
-        except ValueError:
-            pass
    # Convert results to DataFrame
    results = pd.DataFrame.from_dict(results_dict, orient="index").reset_index()
    results.rename(columns={"index": "concentration"}, inplace=True)

-    aps["baseline"] = truth["active"].mean()
    # Convert baseline results to DataFrame
    baseline = pd.DataFrame.from_dict(baseline_dict, orient="index").reset_index()
    baseline.rename(columns={"index": "concentration"}, inplace=True)
    # Merge baseline metrics with results
    results = results.merge(baseline, on="concentration", suffixes=("", "_baseline"))

-    return pd.DataFrame(aps.items(), columns=["concentration", "average_precision"]), curves
    return results
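For orientation, here is a minimal usage sketch of the new entry point. It is not part of the commit: the toy map, perturbation names, and truth rows are invented, the compound rows reuse the first label in cst.COMPOUND_CONCENTRATIONS so process_predictions can collect them, and min_negatives is lowered so the tiny gene pool still yields samples.

import numpy as np
import pandas as pd
from sklearn.utils import Bunch

import efaar_benchmarking.constants as cst
from efaar_benchmarking.benchmarking import (
    AggregateBy,
    AverageType,
    BenchmarkConfig,
    compound_gene_benchmark,
)

# Hypothetical embedding map: metadata carries the perturbation and concentration columns,
# with features aligned on the same index.
rng = np.random.default_rng(0)
conc = cst.COMPOUND_CONCENTRATIONS[0]  # compound rows must carry a known concentration label
metadata = pd.DataFrame(
    {
        "perturbation": ["cmpd_A", "cmpd_B", "GENE1", "GENE2", "GENE3"],
        "concentration": [conc, conc, "gene", "gene", "gene"],
    }
)
features = pd.DataFrame(rng.normal(size=(len(metadata), 16)), index=metadata.index)
map_data = Bunch(features=features, metadata=metadata)

# Hypothetical ground truth with the columns load_truth_data() expects.
truth = pd.DataFrame(
    {
        "treatment": ["cmpd_A", "cmpd_A", "cmpd_A", "cmpd_B"],
        "gene_symbol": ["GENE1", "GENE2", "GENE3", "GENE1"],
        "nM_value": [100.0, 50000.0, 200000.0, 900.0],
    }
)

config = BenchmarkConfig(
    average_type=AverageType.MACRO,
    aggregate_by=AggregateBy.COMPOUND,
    min_negatives=2,  # lowered so this toy gene pool can supply enough negatives
    quantiles=[0.25, 0.5, 0.75],
)
results = compound_gene_benchmark(
    map_data,
    activity_threshold=1000,
    inactivity_threshold=10000,
    truth_data=truth,  # bypasses the CSV in cst.BENCHMARK_DATA_DIR
    config=config,
)
print(results)  # one row per concentration plus "max", with "_baseline" columns merged in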