Skip to content

Commit

Permalink
Add a flag for k sparse probing batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Nov 25, 2024
1 parent 6ae8235 commit 6f2e38f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 18 deletions.
1 change: 1 addition & 0 deletions evals/absorption/eval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@dataclass
class AbsorptionEvalConfig(BaseEvalConfig):
model_name: str = Field(
default="",
title="Model Name",
description="Model name. Must be set with a command line argument. For this eval, we currently recommend to only use models >= 2B parameters.",
)
Expand Down
29 changes: 12 additions & 17 deletions evals/absorption/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def run_eval(
"""

if "gemma" not in config.model_name:
print(
"\n\n\nWARNING: We recommend running this eval on LLMS >= 2B parameters\n\n\n"
)
print("\n\n\nWARNING: We recommend running this eval on LLMS >= 2B parameters\n\n\n")

eval_instance_id = get_eval_uuid()
sae_lens_version = get_sae_lens_version()
Expand Down Expand Up @@ -102,9 +100,7 @@ def run_eval(
k_sparse_probing_file = k_sparse_probing_file.replace("/", "_")
k_sparse_probing_path = os.path.join(artifacts_folder, k_sparse_probing_file)
os.makedirs(os.path.dirname(k_sparse_probing_path), exist_ok=True)
k_sparse_probing_results.to_json(
k_sparse_probing_path, orient="records", indent=4
)
k_sparse_probing_results.to_json(k_sparse_probing_path, orient="records", indent=4)

raw_df = run_feature_absortion_experiment(
model=model,
Expand Down Expand Up @@ -198,9 +194,7 @@ def _aggregate_results_df(
)
agg_df["num_split_feats"] = agg_df["split_feats"].apply(len)
agg_df["num_absorption"] = agg_df["is_absorption"]
agg_df["absorption_rate"] = (
agg_df["num_absorption"] / agg_df["num_probe_true_positives"]
)
agg_df["absorption_rate"] = agg_df["num_absorption"] / agg_df["num_probe_true_positives"]
return agg_df


Expand Down Expand Up @@ -276,11 +270,15 @@ def arg_parser():
default=default_config.k_sparse_probe_l1_decay,
help="L1 decay for k-sparse probes.",
)

parser.add_argument(
"--force_rerun", action="store_true", help="Force rerun of experiments"
"--k_sparse_probe_batch_size",
type=float,
default=default_config.k_sparse_probe_batch_size,
help="L1 decay for k-sparse probes.",
)

parser.add_argument("--force_rerun", action="store_true", help="Force rerun of experiments")

return parser


Expand All @@ -295,14 +293,13 @@ def create_config_and_selected_saes(
prompt_token_pos=args.prompt_token_pos,
model_name=args.model_name,
k_sparse_probe_l1_decay=args.k_sparse_probe_l1_decay,
k_sparse_probe_batch_size=args.k_sparse_probe_batch_size,
)

if args.llm_batch_size is not None:
config.llm_batch_size = args.llm_batch_size
else:
config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[
config.model_name
]
config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]

if args.llm_dtype is not None:
config.llm_dtype = args.llm_dtype
Expand Down Expand Up @@ -343,9 +340,7 @@ def create_config_and_selected_saes(
os.makedirs(args.output_folder, exist_ok=True)

# run the evaluation on all selected SAEs
results_dict = run_eval(
config, selected_saes, device, args.output_folder, args.force_rerun
)
results_dict = run_eval(config, selected_saes, device, args.output_folder, args.force_rerun)

end_time = time.time()

Expand Down
3 changes: 2 additions & 1 deletion shell_scripts/run_reduced_memory.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ for sae_block_pattern in "${sae_block_patterns[@]}"; do
python evals/absorption/main.py \
--sae_regex_pattern "${sae_regex_pattern}" \
--sae_block_pattern "${sae_block_pattern}" \
--model_name ${model_name} --llm_batch_size 4 || {
--model_name ${model_name} --llm_batch_size 4 \
--k_sparse_probe_batch_size 512 || {
echo "Pattern ${sae_block_pattern} failed, continuing to next pattern..."
continue
}
Expand Down
2 changes: 2 additions & 0 deletions shell_scripts/run_reduced_memory_1m_width.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def run_command(cmd, fail_message):
model_name,
"--llm_batch_size",
"4",
"--k_sparse_probe_batch_size",
"512",
]
if run_command(cmd, f"Absorption eval for pattern {sae_block_pattern} failed"):
print(f"Completed absorption eval for pattern {sae_block_pattern}")
Expand Down

0 comments on commit 6f2e38f

Please sign in to comment.