Adapt autointerp to enable use with custom SAEs
adamkarvonen committed Nov 14, 2024
1 parent 7e2ac58 commit 7ea0e59
Showing 1 changed file with 113 additions and 14 deletions.
127 changes: 113 additions & 14 deletions evals/autointerp/main.py
@@ -484,6 +484,8 @@ def run_eval_single_sae(
).to(device)
torch.save(tokenized_dataset, tokens_path)

print(f"Loaded tokenized dataset of shape {tokenized_dataset.shape}")

if sae_sparsity is None:
sae_sparsity = activation_collection.get_feature_activation_sparsity(
tokenized_dataset,
@@ -510,18 +512,17 @@ def run_eval(

def run_eval(
config: AutoInterpEvalConfig,
selected_saes_dict: dict[str, list[str]],
selected_saes_dict: dict[str, list[str] | SAE],
device: str,
api_key: str,
output_path: str,
force_rerun: bool = False,
save_logs_path: Optional[str] = None,
) -> dict[str, Any]:
"""
Runs autointerp eval. Returns results as a dict with the following structure:
custom_eval_config - dict of config parameters used for this evaluation
custom_eval_results - nested dict of {sae_name: {"score": score}}
selected_saes_dict is a dict mapping either:
- Release name -> list of SAE IDs to load from that release
- Custom name -> Single SAE object
"""
eval_instance_id = get_eval_uuid()
sae_lens_version = get_sae_lens_version()
@@ -548,17 +549,32 @@ def run_eval(
f"Running evaluation for SAE release: {sae_release}, SAEs: {selected_saes_dict[sae_release]}"
)

# Wrap single SAE objects in a list to unify processing of both pretrained and custom SAEs
if not isinstance(selected_saes_dict[sae_release], list):
selected_saes_dict[sae_release] = [selected_saes_dict[sae_release]]

for sae_id in tqdm(
selected_saes_dict[sae_release],
desc="Running SAE evaluation on all selected SAEs",
):
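# Release memory held by the previous SAE before loading the next one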
gc.collect()
torch.cuda.empty_cache()
sae, _, sparsity = SAE.from_pretrained(sae_release, sae_id, device=str(device))

# Handle both pretrained SAEs (identified by string) and custom SAEs (passed as objects)
if isinstance(sae_id, str):
sae, _, sparsity = SAE.from_pretrained(
release=sae_release,
sae_id=sae_id,
device=device,
)
else:
sae = sae_id
sae_id = "custom_sae"
sparsity = None
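# Leaving sparsity as None is fine: run_eval_single_sae computes it itself via
# activation_collection.get_feature_activation_sparsity (see the first hunk above)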

sae = sae.to(device=device, dtype=llm_dtype)

artifacts_folder = os.path.join(artifacts_base_folder, EVAL_TYPE_ID_AUTOINTERP)
os.makedirs(artifacts_folder, exist_ok=True)

sae_result_file = f"{sae_release}_{sae_id}_eval_results.json"
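# SAE IDs can contain "/" (e.g. "layer_20/width_16k/average_l0_139"), so strip path separators from the filename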
sae_result_file = sae_result_file.replace("/", "_")
@@ -601,7 +617,6 @@ def run_eval(

# Put important results into the results dict
score = sum([r["score"] for r in sae_eval_result.values()]) / len(sae_eval_result)
eval_result_metrics = {"autointerp_metrics": {"autointerp_score": score}}

eval_output = AutoInterpEvalOutput(
eval_config=config,
@@ -696,19 +711,26 @@ def arg_parser():
--sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \
--model_name pythia-70m-deduped \
--api_key <API_KEY>
python evals/autointerp/main.py \
--sae_regex_pattern "gemma-scope-2b-pt-res" \
--sae_block_pattern "layer_20/width_16k/average_l0_139" \
--model_name gemma-2-2b \
--api_key <API_KEY>
"""
args = arg_parser().parse_args()
device = setup_environment()

start_time = time.time()

config, selected_saes_dict = create_config_and_selected_saes(args)

sae_regex_patterns = None
sae_block_pattern = None

# Uncomment these to select multiple SAEs based on multiple regex patterns
# This will override the sae_regex_pattern and sae_block_pattern arguments
sae_regex_patterns = [
r"(sae_bench_pythia70m_sweep_topk_ctx128_0730).*",
r"(sae_bench_pythia70m_sweep_standard_ctx128_0712).*",
@@ -718,10 +740,17 @@ def arg_parser():
r".*blocks\.([4])\.hook_resid_post__trainer_(2|6|10|14)$",
]

sae_regex_patterns = None
sae_block_pattern = None

config, selected_saes_dict = create_config_and_selected_saes(args)
# For Gemma-2-2b
sae_regex_patterns = [
r"sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824",
r"sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824",
r"(gemma-scope-2b-pt-res)",
]
sae_block_pattern = [
r".*blocks\.19(?!.*step).*",
r".*blocks\.19(?!.*step).*",
r".*layer_(19).*(16k).*",
]

if sae_regex_patterns is not None:
selected_saes_dict = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern)
@@ -749,3 +778,73 @@ def arg_parser():
end_time = time.time()

print(f"Finished evaluation in {end_time - start_time} seconds")


# Use this code snippet to run the eval with custom SAE objects
# if __name__ == "__main__":
# """
# python evals/autointerp/main.py
# NOTE: We don't use argparse here. This requires a file openai_api_key.txt to be present in the root directory.
# """

# import baselines.identity_sae as identity_sae
# import baselines.jumprelu_sae as jumprelu_sae

# device = setup_environment()

# start_time = time.time()

# random_seed = 42
# output_folder = "evals/autointerp/results"

# with open("openai_api_key.txt", "r") as f:
# api_key = f.read().strip()

# baseline_type = "identity_sae"
# # baseline_type = "jumprelu_sae"

# model_name = "pythia-70m-deduped"
# hook_layer = 4
# d_model = 512

# # model_name = "gemma-2-2b"
# # hook_layer = 19
# # d_model = 2304

# if baseline_type == "identity_sae":
# sae = identity_sae.IdentitySAE(model_name, d_model=d_model, hook_layer=hook_layer)
# selected_saes_dict = {f"{model_name}_layer_{hook_layer}_identity_sae": sae}
# elif baseline_type == "jumprelu_sae":
# repo_id = "google/gemma-scope-2b-pt-res"
# filename = "layer_20/width_16k/average_l0_71/params.npz"
# sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, 20)
# selected_saes_dict = {f"{repo_id}_{filename}_gemmascope_sae": sae}
# else:
# raise ValueError(f"Invalid baseline type: {baseline_type}")

# config = AutoInterpEvalConfig(
# random_seed=random_seed,
# model_name=model_name,
# )

# config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
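# # converts a torch dtype such as torch.float32 into the short string "float32"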
# config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[
# -1
# ]

# # create output folder
# os.makedirs(output_folder, exist_ok=True)

# # run the evaluation on all selected SAEs
# results_dict = run_eval(
# config,
# selected_saes_dict,
# device,
# api_key,
# output_folder,
# force_rerun=True,
# )

# end_time = time.time()

# print(f"Finished evaluation in {end_time - start_time} seconds")
