Skip to content

Commit

Permalink
Improve arg parsing and probe file name
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Oct 25, 2024
1 parent 929cdc0 commit ec8cd87
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions evals/shift_and_tpp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,14 +476,14 @@ def run_eval_single_dataset(
if not config.spurious_corr:
chosen_classes = dataset_info.chosen_classes_per_dataset[dataset_name]
activations_filename = f"{dataset_name}_activations.pt".replace("/", "_")
probes_filename = f"{dataset_name}_activations.pkl".replace("/", "_")
probes_filename = f"{dataset_name}_probes.pkl".replace("/", "_")
else:
chosen_classes = list(dataset_info.PAIRED_CLASS_KEYS.keys())
activations_filename = (
f"{dataset_name}_{column1_vals[0]}_{column1_vals[1]}_activations.pt".replace("/", "_")
)
probes_filename = (
f"{dataset_name}_{column1_vals[0]}_{column1_vals[1]}_activations.pkl".replace("/", "_")
probes_filename = f"{dataset_name}_{column1_vals[0]}_{column1_vals[1]}_probes.pkl".replace(
"/", "_"
)

activations_path = os.path.join(artifacts_folder, activations_filename)
Expand Down Expand Up @@ -812,9 +812,15 @@ def arg_parser():
parser.add_argument(
"--clean_up_activations", action="store_true", help="Clean up activations after evaluation"
)

def str_to_bool(value):
if value.lower() in ("true", "false"):
return value.lower() == "true"
raise argparse.ArgumentTypeError("Boolean value expected.")

parser.add_argument(
"--spurious_corr",
action="store_true",
type=str_to_bool,
required=True,
help="If true, do Spurious Correlation Removal. If false, do TPP.",
)
Expand All @@ -828,7 +834,7 @@ def arg_parser():
--sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \
--sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \
--model_name pythia-70m-deduped \
--spurious_corr
--spurious_corr true
"""
Expand Down

0 comments on commit ec8cd87

Please sign in to comment.