From 6f2e38f6481933249b70185f9d3b68737eac44a1 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 25 Nov 2024 23:12:07 +0000 Subject: [PATCH] Add a flag for k sparse probing batch size --- evals/absorption/eval_config.py | 1 + evals/absorption/main.py | 29 ++++++++------------ shell_scripts/run_reduced_memory.sh | 3 +- shell_scripts/run_reduced_memory_1m_width.py | 2 ++ 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/evals/absorption/eval_config.py b/evals/absorption/eval_config.py index 87377a9..25f9f1d 100644 --- a/evals/absorption/eval_config.py +++ b/evals/absorption/eval_config.py @@ -7,6 +7,7 @@ @dataclass class AbsorptionEvalConfig(BaseEvalConfig): model_name: str = Field( + default="", title="Model Name", description="Model name. Must be set with a command line argument. For this eval, we currently recommend to only use models >= 2B parameters.", ) diff --git a/evals/absorption/main.py b/evals/absorption/main.py index 81bafba..bce9d1b 100644 --- a/evals/absorption/main.py +++ b/evals/absorption/main.py @@ -44,9 +44,7 @@ def run_eval( """ if "gemma" not in config.model_name: - print( - "\n\n\nWARNING: We recommend running this eval on LLMS >= 2B parameters\n\n\n" - ) + print("\n\n\nWARNING: We recommend running this eval on LLMS >= 2B parameters\n\n\n") eval_instance_id = get_eval_uuid() sae_lens_version = get_sae_lens_version() @@ -102,9 +100,7 @@ def run_eval( k_sparse_probing_file = k_sparse_probing_file.replace("/", "_") k_sparse_probing_path = os.path.join(artifacts_folder, k_sparse_probing_file) os.makedirs(os.path.dirname(k_sparse_probing_path), exist_ok=True) - k_sparse_probing_results.to_json( - k_sparse_probing_path, orient="records", indent=4 - ) + k_sparse_probing_results.to_json(k_sparse_probing_path, orient="records", indent=4) raw_df = run_feature_absortion_experiment( model=model, @@ -198,9 +194,7 @@ def _aggregate_results_df( ) agg_df["num_split_feats"] = agg_df["split_feats"].apply(len) agg_df["num_absorption"] = agg_df["is_absorption"] - agg_df["absorption_rate"] = ( - agg_df["num_absorption"] / agg_df["num_probe_true_positives"] - ) + agg_df["absorption_rate"] = agg_df["num_absorption"] / agg_df["num_probe_true_positives"] return agg_df @@ -276,11 +270,15 @@ def arg_parser(): default=default_config.k_sparse_probe_l1_decay, help="L1 decay for k-sparse probes.", ) - parser.add_argument( - "--force_rerun", action="store_true", help="Force rerun of experiments" + "--k_sparse_probe_batch_size", + type=float, + default=default_config.k_sparse_probe_batch_size, + help="L1 decay for k-sparse probes.", ) + parser.add_argument("--force_rerun", action="store_true", help="Force rerun of experiments") + return parser @@ -295,14 +293,13 @@ def create_config_and_selected_saes( prompt_token_pos=args.prompt_token_pos, model_name=args.model_name, k_sparse_probe_l1_decay=args.k_sparse_probe_l1_decay, + k_sparse_probe_batch_size=args.k_sparse_probe_batch_size, ) if args.llm_batch_size is not None: config.llm_batch_size = args.llm_batch_size else: - config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[ - config.model_name - ] + config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] if args.llm_dtype is not None: config.llm_dtype = args.llm_dtype @@ -343,9 +340,7 @@ def create_config_and_selected_saes( os.makedirs(args.output_folder, exist_ok=True) # run the evaluation on all selected SAEs - results_dict = run_eval( - config, selected_saes, device, args.output_folder, args.force_rerun - ) + results_dict = run_eval(config, selected_saes, device, args.output_folder, args.force_rerun) end_time = time.time() diff --git a/shell_scripts/run_reduced_memory.sh b/shell_scripts/run_reduced_memory.sh index ace9119..9ff7071 100755 --- a/shell_scripts/run_reduced_memory.sh +++ b/shell_scripts/run_reduced_memory.sh @@ -15,7 +15,8 @@ for sae_block_pattern in "${sae_block_patterns[@]}"; do python evals/absorption/main.py \ --sae_regex_pattern "${sae_regex_pattern}" \ --sae_block_pattern "${sae_block_pattern}" \ - --model_name ${model_name} --llm_batch_size 4 || { + --model_name ${model_name} --llm_batch_size 4 \ + --k_sparse_probe_batch_size 512 || { echo "Pattern ${sae_block_pattern} failed, continuing to next pattern..." continue } diff --git a/shell_scripts/run_reduced_memory_1m_width.py b/shell_scripts/run_reduced_memory_1m_width.py index 4a109ba..2317fb6 100644 --- a/shell_scripts/run_reduced_memory_1m_width.py +++ b/shell_scripts/run_reduced_memory_1m_width.py @@ -149,6 +149,8 @@ def run_command(cmd, fail_message): model_name, "--llm_batch_size", "4", + "--k_sparse_probe_batch_size", + "512", ] if run_command(cmd, f"Absorption eval for pattern {sae_block_pattern} failed"): print(f"Completed absorption eval for pattern {sae_block_pattern}")