Skip to content

Commit

Permalink
Add function to average results from multiple runs
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Oct 4, 2024
1 parent 9c42b2f commit 47af366
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import utils.formatting_utils as formatting_utils
import utils.testing_utils as testing_utils


def test_average_results():
# Prepare test data
results_dict = {
"dataset1_results": {
"sae1": {"accuracy": 0.8, "loss": 0.5},
"sae2": {"accuracy": 0.75, "loss": 0.6},
},
"dataset2_results": {
"sae1": {"accuracy": 0.85, "loss": 0.4},
"sae2": {"accuracy": 0.7, "loss": 0.65},
},
}
dataset_names = ["dataset1", "dataset2"]

# Expected output
expected_output = {
"sae1": {"accuracy": 0.825, "loss": 0.45},
"sae2": {"accuracy": 0.725, "loss": 0.625},
}

# Call the function
output = formatting_utils.average_results_dictionaries(results_dict, dataset_names)

testing_utils.compare_dicts_within_tolerance(output, expected_output, tolerance=1e-6)
31 changes: 31 additions & 0 deletions utils/formatting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,34 @@ def filter_by_l0_threshold(results: dict, l0_threshold: Optional[int]) -> dict:
# Replace the original results with the filtered results
results = filtered_results
return results


def average_results_dictionaries(
results_dict: dict[str, dict[str, dict[str, float]]], dataset_names: list[str]
) -> dict[str, dict[str, float]]:
"""If we have multiple dicts of results from separate datasets, get an average performance over all datasets.
Results_dict is dataset -> sae_name -> dict of metric_name : float result"""
averaged_results = {}
aggregated_results = {}

for dataset_name in dataset_names:
dataset_results = results_dict[f"{dataset_name}_results"]

for sae_name, sae_metrics in dataset_results.items():
if sae_name not in aggregated_results:
aggregated_results[sae_name] = {}

for metric_name, metric_value in sae_metrics.items():
if metric_name not in aggregated_results[sae_name]:
aggregated_results[sae_name][metric_name] = []

aggregated_results[sae_name][metric_name].append(metric_value)

# Compute averages
for sae_name in aggregated_results:
averaged_results[sae_name] = {}
for metric_name, values in aggregated_results[sae_name].items():
average_value = sum(values) / len(values)
averaged_results[sae_name][metric_name] = average_value

return averaged_results

0 comments on commit 47af366

Please sign in to comment.