From ec8cd872060ef158ca0ec0ef5e8873546fadeff4 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 25 Oct 2024 22:04:59 +0000 Subject: [PATCH] Improve arg parsing and probe file name --- evals/shift_and_tpp/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/evals/shift_and_tpp/main.py b/evals/shift_and_tpp/main.py index ee24e73..5cd5e61 100644 --- a/evals/shift_and_tpp/main.py +++ b/evals/shift_and_tpp/main.py @@ -476,14 +476,14 @@ def run_eval_single_dataset( if not config.spurious_corr: chosen_classes = dataset_info.chosen_classes_per_dataset[dataset_name] activations_filename = f"{dataset_name}_activations.pt".replace("/", "_") - probes_filename = f"{dataset_name}_activations.pkl".replace("/", "_") + probes_filename = f"{dataset_name}_probes.pkl".replace("/", "_") else: chosen_classes = list(dataset_info.PAIRED_CLASS_KEYS.keys()) activations_filename = ( f"{dataset_name}_{column1_vals[0]}_{column1_vals[1]}_activations.pt".replace("/", "_") ) - probes_filename = ( - f"{dataset_name}_{column1_vals[0]}_{column1_vals[1]}_activations.pkl".replace("/", "_") + probes_filename = f"{dataset_name}_{column1_vals[0]}_{column1_vals[1]}_probes.pkl".replace( + "/", "_" ) activations_path = os.path.join(artifacts_folder, activations_filename) @@ -812,9 +812,15 @@ def arg_parser(): parser.add_argument( "--clean_up_activations", action="store_true", help="Clean up activations after evaluation" ) + + def str_to_bool(value): + if value.lower() in ("true", "false"): + return value.lower() == "true" + raise argparse.ArgumentTypeError("Boolean value expected.") + parser.add_argument( "--spurious_corr", - action="store_true", + type=str_to_bool, required=True, help="If true, do Spurious Correlation Removal. If false, do TPP.", ) @@ -828,7 +834,7 @@ def arg_parser(): --sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \ --sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \ --model_name pythia-70m-deduped \ - --spurious_corr + --spurious_corr true """