Skip to content

Commit

Permalink
No longer aggregate over saes in a dict
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Oct 25, 2024
1 parent df72d30 commit 0e5bccb
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions sae_bench_utils/formatting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,31 +242,25 @@ def filter_by_l0_threshold(results: dict, l0_threshold: Optional[int]) -> dict:


def average_results_dictionaries(
    results_dict: dict[str, dict[str, float]], dataset_names: list[str]
) -> dict[str, float]:
    """Average each metric over the results of several datasets.

    Args:
        results_dict: Maps ``f"{dataset_name}_results"`` to a dict of
            ``metric_name -> float`` for that dataset. Entries whose key
            does not correspond to a name in ``dataset_names`` are ignored.
        dataset_names: Names of the datasets to average over. Each name is
            looked up in ``results_dict`` with the ``"_results"`` suffix
            appended.

    Returns:
        A dict of ``metric_name -> mean value``. A metric that is missing
        from some datasets is averaged only over the datasets that report
        it. An empty ``dataset_names`` yields an empty dict.

    Raises:
        KeyError: If ``results_dict`` has no ``f"{dataset_name}_results"``
            entry for one of the requested datasets.
    """
    # metric_name -> every value observed for it, in dataset order.
    aggregated_results: dict[str, list[float]] = {}

    for dataset_name in dataset_names:
        dataset_results = results_dict[f"{dataset_name}_results"]

        for metric_name, metric_value in dataset_results.items():
            aggregated_results.setdefault(metric_name, []).append(metric_value)

    # Every list holds at least one value (it is only created on append),
    # so the division below cannot hit a zero length.
    return {
        metric_name: sum(values) / len(values)
        for metric_name, values in aggregated_results.items()
    }

0 comments on commit 0e5bccb

Please sign in to comment.