From 0e5bccb4fa57a34c6c92294312a8764257993f56 Mon Sep 17 00:00:00 2001
From: Adam Karvonen
Date: Fri, 25 Oct 2024 17:11:41 +0000
Subject: [PATCH] No longer aggregate over saes in a dict

---
Reviewer notes (free-form area; discarded by `git am`):

* Interface change: `results_dict` goes from
  dataset -> sae_name -> {metric_name: float} to
  dataset -> {metric_name: float}, and the return value goes from
  sae_name -> {metric_name: mean} to {metric_name: mean}.
  Callers that still pass the nested per-SAE shape need updating.
* After this patch `averaged_results = {}` is assigned twice (once at the
  top of the function, once right before the averaging loop); the first
  assignment is redundant.
* Line structure was restored from a whitespace-collapsed copy. Blank
  context lines were placed to match the hunk header (31 old / 25 new
  lines); verify against the upstream commit before applying.

 sae_bench_utils/formatting_utils.py | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/sae_bench_utils/formatting_utils.py b/sae_bench_utils/formatting_utils.py
index 72e02bf..b4ac364 100644
--- a/sae_bench_utils/formatting_utils.py
+++ b/sae_bench_utils/formatting_utils.py
@@ -242,31 +242,25 @@ def filter_by_l0_threshold(results: dict, l0_threshold: Optional[int]) -> dict:
 
 
 def average_results_dictionaries(
-    results_dict: dict[str, dict[str, dict[str, float]]], dataset_names: list[str]
-) -> dict[str, dict[str, float]]:
+    results_dict: dict[str, dict[str, float]], dataset_names: list[str]
+) -> dict[str, float]:
     """If we have multiple dicts of results from separate datasets, get an average performance over all datasets.
-    Results_dict is dataset -> sae_name -> dict of metric_name : float result"""
+    Results_dict is dataset -> dict of metric_name : float result"""
     averaged_results = {}
     aggregated_results = {}
 
     for dataset_name in dataset_names:
         dataset_results = results_dict[f"{dataset_name}_results"]
 
-        for sae_name, sae_metrics in dataset_results.items():
-            if sae_name not in aggregated_results:
-                aggregated_results[sae_name] = {}
-
-            for metric_name, metric_value in sae_metrics.items():
-                if metric_name not in aggregated_results[sae_name]:
-                    aggregated_results[sae_name][metric_name] = []
+        for metric_name, metric_value in dataset_results.items():
+            if metric_name not in aggregated_results:
+                aggregated_results[metric_name] = []
 
-                aggregated_results[sae_name][metric_name].append(metric_value)
+            aggregated_results[metric_name].append(metric_value)
 
-    # Compute averages
-    for sae_name in aggregated_results:
-        averaged_results[sae_name] = {}
-        for metric_name, values in aggregated_results[sae_name].items():
-            average_value = sum(values) / len(values)
-            averaged_results[sae_name][metric_name] = average_value
+    averaged_results = {}
+    for metric_name, values in aggregated_results.items():
+        average_value = sum(values) / len(values)
+        averaged_results[metric_name] = average_value
 
     return averaged_results