Skip to content

Commit

Permalink
bugfix sae_bench prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
canrager committed Sep 30, 2024
1 parent 54a6156 commit da9f95f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion formatting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def extract_saes_unique_info(sae_names: list[str], checkpoint_only: bool = False
return infos


def filter_sae_names(sae_names: Union[list, str], layers: list[int], trainer_ids: list[int], include_checkpoints: bool) -> list[str]:
def filter_sae_names(sae_names: Union[list, str], layers: list[int], trainer_ids: list[int], include_checkpoints: bool, drop_sae_bench_prefix: bool = False) -> list[str]:
'''Filter SAE names based on layer, trainer_id, and whether they are checkpoints
Args:
sae_names: List of SAE names or a string representing a release name
Expand All @@ -110,6 +110,8 @@ def filter_sae_names(sae_names: Union[list, str], layers: list[int], trainer_ids
if info['layer'] in layers \
and info['trainer_id'] in trainer_ids \
and info['is_checkpoint'] == include_checkpoints:
if drop_sae_bench_prefix:
sae_name = sae_name.replace("sae_bench_", "")
filtered_sae_names.append(sae_name)

return filtered_sae_names
Expand Down

0 comments on commit da9f95f

Please sign in to comment.