diff --git a/README.md b/README.md index 325c641..ef8c639 100644 --- a/README.md +++ b/README.md @@ -1,60 +1,79 @@ -## SAE Bench: Template for custom evals +# SAE Bench -This repo contains the template we would like to use for the SAE Bench project. The `template.ipynb` explains the input to your custom eval (SAEs hosted on SAELens) and the output (a standardized results file). +## Table of Contents +- [Installation](#installation) +- [Overview](#overview) +- [Running Evaluations](#running-evaluations) +- [Custom SAE Usage](#custom-sae-usage) +- [Training Your Own SAEs](#training-your-own-saes) +- [Graphing Results](#graphing-results) ### Installation -Set up a virtual environment running on `python 3.11.9`. +Set up a virtual environment with python >= 3.10. ``` git clone https://github.com/adamkarvonen/SAE_Bench_Template.git cd SAE_Bench_Template pip install -e . ``` -### Quick start: - -1. Browse through `template.ipynb` to learn about input and output of your metric. -2. Execute `main.py` in `evals/sparse_probing` to see an example implementation. -3. Inspect sparse probing results with `graphing.ipynb`. +All evals can be run with current batch sizes on Gemma-2-2B on a 24GB VRAM GPU (e.g. an RTX 3090). ## Overview -The `evals/sparse_probing` folder contains a full example implementation of a custom eval. In `main.py`, we have a function that takes a list of SAELens SAE names (defined in `eval_config.py`) and an sae release and returns a dictionary of results in a standard format. This folder contains some helper functions, like pre-computing model activations in `utils/activation_collection.py`, that might be useful for you, too! We try to reuse functions as much as possible across evals to reduce bugs. Let Adam and Can know if you've implemented a helper function that might be useful for other evals as well (like autointerp, feature scoring). +SAE Bench is a comprehensive suite of 8 evaluations for Sparse Autoencoder (SAE) models: +- **Feature Absorption** +- **AutoInterp** +- **L0 / Loss Recovered** +- **RAVEL** +- **SHIFT** +- **TPP** +- **Sparse Probing** +- **Unlearning** -`python3 main.py` (after `cd evals/sparse_probing/`) should run as is and demonstrate how to use our SAE Bench SAEs with Transformer Lens and SAE Lens. It will also generate a results file which can be graphed using `graphing.ipynb`. +### Supported Models and SAEs +- **SAE Lens Pretrained Models**: Supports evaluations on any SAE Lens pretrained model. +- **Custom SAEs**: Supports any general SAE object with `encode()` and `decode()` methods (see [Custom SAE Usage](#custom-sae-usage)). -Here is what we would like to see from each eval: +## Running Evaluations +Each evaluation has an example command located in its respective `main.py` file. Here's how to run a sparse probing evaluation on a single SAE Bench Pythia-70M SAE: -- Making sure we are returned both the results any config required for reproducibility (eg: eval config / function args). -- Ensuring the code meets some minimum bar (isn't missing anything, isn't abysmally slow etc). -- Ensuring we have example output to validate against. +``` +python evals/sparse_probing/main.py \ + --sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \ + --sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \ + --model_name pythia-70m-deduped +``` -the results dictionary of you custom eval can be loaded in to `graphing.ipynb`, to create a wide variety of plots. 
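For readers who prefer to call the eval from Python rather than the CLI, here is a minimal sketch of the equivalent flow, assembled from the `run_eval()` signature and the commented-out `__main__` blocks visible later in this diff. Treat it as an outline: the `SparseProbingEvalConfig` import path is assumed by analogy with `evals/absorption/eval_config.py`, and the repo's `setup_environment()` helper is replaced by a plain device check.

```
# Sketch only: runs the same sparse probing eval as the CLI example above.
import os

import torch

import sae_bench_utils.activation_collection as activation_collection
from evals.sparse_probing.eval_config import SparseProbingEvalConfig  # assumed module path
from evals.sparse_probing.main import run_eval

device = "cuda" if torch.cuda.is_available() else "cpu"

config = SparseProbingEvalConfig(random_seed=42, model_name="pythia-70m-deduped")
# Pick batch size and dtype for the LLM, mirroring the __main__ blocks in this diff.
config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[-1]

# SAE Lens SAEs are selected as {release_name: [sae_ids]}; this literal dict is what
# the --sae_regex_pattern / --sae_block_pattern flags resolve to for this example.
selected_saes_dict = {
    "sae_bench_pythia70m_sweep_standard_ctx128_0712": [
        "blocks.4.hook_resid_post__trainer_10",
    ],
}

output_folder = "evals/sparse_probing/results"
os.makedirs(output_folder, exist_ok=True)

run_eval(
    config,
    selected_saes_dict,
    device,
    output_folder,
    force_rerun=False,
    clean_up_activations=False,
)
```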
We also already have the basic `L0 / Loss Recovered` metrics for every SAE, as specified in `template.ipynb`. +The results will be saved to the `evals/sparse_probing/results` directory. -For the purpose of validating evaluation outputs, we have `compare_run_results.ipynb`. Using this, you can run the same eval twice with the same input, and verify within a tolerance that it returns the same outputs. If the eval is fully deterministic, the results should be identical. +We use regex patterns to select SAE Lens SAEs. For more examples of regex patterns, refer to `sae_regex_selection.ipynb`. -Once evaluations have been completed, please submit them as pull requests to this repo. +Every eval folder contains an `eval_config.py`, which defines all relevant hyperparameters for that evaluation. The values are currently set to the default recommended values. -## Eval Format +For a tutorial on using SAE Lens SAEs, including calculating L0 and Loss Recovered and getting a set of tokens from The Pile, refer to this notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb -Ideally, we would like to see something like `evals.sparse_probing.main.py`, which contains a `run_eval()` function. This function takes in hyperparameters and a `selected_saes_dict`, which is a dict of `sae_release` : `list[sae_names]` (as shown in template.ipynb). +## Custom SAE Usage -All evals and submodules will share the same dependencies, which are set in pyproject.toml. +Our goal is to have first-class support for custom SAEs, as the field is rapidly evolving. Our evaluations can run on any SAE object with encode(), decode(), and a few config values. For examples of custom SAEs, refer to the `baselines/` folder. -For a tutorial of using SAE Lens SAEs, including calculating L0 and Loss Recovered and getting a set of tokens from The Pile, refer to this notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb +There are two ways to evaluate custom SAEs: -## Custom SAE Usage +1. **Using Evaluation Templates**: + - Use the secondary `if __name__ == "__main__"` block in each `main.py` + - Results are saved in SAE Bench format for easy visualization + - Compatible with provided plotting tools -For the sparse probing and SHIFT / TPP evals, we support evaluating any SAE object that has the following implemented, with inputs / outputs matching the SAELens SAE format: +2. **Direct Function Calls**: + - Use `run_eval_single_sae()` in each `main.py` + - Simpler interface requiring only model, SAE, and config values + - Graphing will require manual formatting -``` -sae.encode() -sae.decode() -sae.forward() -sae.W_dec # nn.Parameter(d_sae, d_in), required for SHIFT, TPP, and Feature Absorption -sae.device -sae.dtype -``` +We currently have a suite of SAE Bench SAEs on layers 3 and 4 of Pythia-70M and layers 5, 12, and 19 of Gemma-2-2B, each trained on 200M tokens. These SAEs can serve as baselines for any new custom SAEs. We also have baseline eval results, saved at TODO. + +## Training Your Own SAEs + +You can replicate the training of our SAEs using the scripts provided [here](https://github.com/canrager/dictionary_training/), implement your own SAE, or modify one of our SAE implementations. Once you train your new version, you can benchmark against our existing SAEs for a true apples-to-apples comparison. -Just pass the appropriate inputs to `run_eval_single_sae()`, referring to individual eval READMEs as needed. 
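As a concrete illustration of route (2) above, the sketch below pushes a single custom SAE through `run_eval_single_sae()` for sparse probing, using the `IdentitySAE` baseline defined further down in this diff. The function signature matches the updated `evals/sparse_probing/main.py`; the `SparseProbingEvalConfig` import path and the returned results structure are assumptions, so adapt as needed.

```
# Sketch of the direct-call route for a custom SAE (route 2 above).
import torch
from transformer_lens import HookedTransformer

import baselines.identity_sae as identity_sae
from evals.sparse_probing.eval_config import SparseProbingEvalConfig  # assumed module path
from evals.sparse_probing.main import run_eval_single_sae

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "pythia-70m-deduped"

config = SparseProbingEvalConfig(random_seed=42, model_name=model_name)
# Depending on the config defaults, you may also want to set config.llm_batch_size
# and config.llm_dtype, as the __main__ blocks in this diff do.

# The eval reads sae.cfg.hook_layer / sae.cfg.hook_name plus sae.device and sae.dtype,
# so a custom SAE must populate them (IdentitySAE does this via its SAEConfig).
sae = identity_sae.IdentitySAE(model_name, d_model=512, hook_layer=4).to(device=device)

model = HookedTransformer.from_pretrained_no_processing(
    model_name, device=device, dtype=sae.dtype
)

results = run_eval_single_sae(
    config,
    sae,
    model,
    device,
    artifacts_folder="artifacts/sparse_probing",  # created automatically if missing
)
```

Route (1) is the lower-friction option when you want results in the standard SAE Bench output format; the direct call is mainly useful for quick experiments, and its output needs manual formatting before it can be fed to the plotting notebook.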
If you match our output format you can reuse our graphing notebook. +## Graphing Results -To run our baselines in pythia-70m and gemma-2-2b, refer to `if __name__ == "__main__":` in `shift_and_tpp/main.py`. \ No newline at end of file +To graph the results, refer to `graphing.ipynb`, which can graph the generated SAE Bench data. Note that many graphs plot SAEs by L0 and / or Loss Recovered. To obtain these scores, run `evals/core/main.py`. \ No newline at end of file diff --git a/baselines/identity_sae.py b/baselines/identity_sae.py new file mode 100644 index 0000000..30428a8 --- /dev/null +++ b/baselines/identity_sae.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class SAEConfig: + model_name: str + d_in: int + d_sae: int + hook_layer: int + hook_name: str + context_size: int = 128 # Can be used for auto-interp + hook_head_index: Optional[int] = None + + +class IdentitySAE(nn.Module): + def __init__(self, model_name: str, d_model: int, hook_layer: int): + super().__init__() + + # Initialize W_enc and W_dec as identity matrices + self.W_enc = nn.Parameter(torch.eye(d_model)) + self.W_dec = nn.Parameter(torch.eye(d_model)) + self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.dtype: torch.dtype = torch.float32 + + hook_name = f"blocks.{hook_layer}.hook_resid_post" + + # Initialize the configuration dataclass + self.cfg = SAEConfig( + model_name, d_in=d_model, d_sae=d_model, hook_name=hook_name, hook_layer=hook_layer + ) + + def encode(self, input_acts: torch.Tensor): + acts = input_acts @ self.W_enc + return acts + + def decode(self, acts: torch.Tensor): + return acts @ self.W_dec + + def forward(self, acts): + acts = self.encode(acts) + recon = self.decode(acts) + return recon + + # required as we have device and dtype class attributes + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + # Update the device and dtype attributes based on the first parameter + device = kwargs.get("device", None) + dtype = kwargs.get("dtype", None) + + # Update device and dtype if they were provided + if device: + self.device = device + if dtype: + self.dtype = dtype + return self diff --git a/baselines/jumprelu_sae.py b/baselines/jumprelu_sae.py new file mode 100644 index 0000000..74a8676 --- /dev/null +++ b/baselines/jumprelu_sae.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +import numpy as np +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class SAEConfig: + model_name: str + d_in: int + d_sae: int + hook_layer: int + hook_name: str + context_size: int = 128 # Can be used for auto-interp + hook_head_index: Optional[int] = None + + +class JumpReLUSAE(nn.Module): + def __init__(self, d_model: int, d_sae: int, hook_layer: int, model_name: str = "gemma-2-2b"): + super().__init__() + self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae)) + self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model)) + self.threshold = nn.Parameter(torch.zeros(d_sae)) + self.b_enc = nn.Parameter(torch.zeros(d_sae)) + self.b_dec = nn.Parameter(torch.zeros(d_model)) + self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.dtype: torch.dtype = torch.float32 + + hook_name = f"blocks.{hook_layer}.hook_resid_post" + + self.cfg = SAEConfig( + model_name, d_in=d_model, d_sae=d_model, hook_name=hook_name, hook_layer=hook_layer + ) + + def encode(self, input_acts): + 
pre_acts = input_acts @ self.W_enc + self.b_enc + mask = pre_acts > self.threshold + acts = mask * torch.nn.functional.relu(pre_acts) + return acts + + def decode(self, acts): + return acts @ self.W_dec + self.b_dec + + def forward(self, acts): + acts = self.encode(acts) + recon = self.decode(acts) + return recon + + # required as we have device and dtype class attributes + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + # Update the device and dtype attributes based on the first parameter + device = kwargs.get("device", None) + dtype = kwargs.get("dtype", None) + + # Update device and dtype if they were provided + if device: + self.device = device + if dtype: + self.dtype = dtype + return self + + +def load_jumprelu_sae(repo_id: str, filename: str, layer: int) -> JumpReLUSAE: + path_to_params = hf_hub_download( + repo_id=repo_id, + filename=filename, + force_download=False, + ) + + params = np.load(path_to_params) + pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()} + + sae = JumpReLUSAE(params["W_enc"].shape[0], params["W_enc"].shape[1], layer) + sae.load_state_dict(pt_params) + + return sae + + +if __name__ == "__main__": + repo_id = "google/gemma-scope-2b-pt-res" + filename = "layer_20/width_16k/average_l0_71/params.npz" + + sae = load_jumprelu_sae(repo_id, filename, 20) diff --git a/evals/absorption/eval_config.py b/evals/absorption/eval_config.py index 08c11a2..268afdf 100644 --- a/evals/absorption/eval_config.py +++ b/evals/absorption/eval_config.py @@ -39,3 +39,13 @@ class AbsorptionEvalConfig(BaseEvalConfig): title="Model Name", description="Model name", ) + llm_batch_size: int = Field( + default=32, + title="LLM Batch Size", + description="LLM batch size, inference only", + ) + llm_dtype: str = Field( + default="bfloat16", + title="LLM Data Type", + description="LLM data type", + ) diff --git a/evals/absorption/k_sparse_probing.py b/evals/absorption/k_sparse_probing.py index 28c2494..4d5d88d 100644 --- a/evals/absorption/k_sparse_probing.py +++ b/evals/absorption/k_sparse_probing.py @@ -34,9 +34,7 @@ class KSparseProbe(nn.Module): bias: torch.Tensor # scalar feature_ids: torch.Tensor # shape (k) - def __init__( - self, weight: torch.Tensor, bias: torch.Tensor, feature_ids: torch.Tensor - ): + def __init__(self, weight: torch.Tensor, bias: torch.Tensor, feature_ids: torch.Tensor): super().__init__() self.weight = weight self.bias = bias @@ -47,9 +45,7 @@ def k(self) -> int: return self.weight.shape[0] def forward(self, x: torch.Tensor) -> torch.Tensor: - filtered_acts = ( - x[:, self.feature_ids] if len(x.shape) == 2 else x[self.feature_ids] - ) + filtered_acts = x[:, self.feature_ids] if len(x.shape) == 2 else x[self.feature_ids] return filtered_acts @ self.weight + self.bias @@ -82,21 +78,18 @@ def train_sparse_multi_probe( show_progress=show_progress, verbose=verbose, device=device, - extra_loss_fn=lambda probe, _x, _y: l1_decay - * probe.weights.abs().sum(dim=-1).mean(), + extra_loss_fn=lambda probe, _x, _y: l1_decay * probe.weights.abs().sum(dim=-1).mean(), ) def _get_sae_acts( sae: SAE, input_activations: torch.Tensor, - sae_post_act: bool, # whether to train the probe before or after the SAE Relu activation batch_size: int = 4096, ) -> torch.Tensor: - hook_point = "hook_sae_acts_post" if sae_post_act else "hook_sae_acts_pre" batch_acts = [] for batch in batchify(input_activations, batch_size): - acts = sae.run_with_cache(batch.to(sae.device))[1][hook_point].cpu() + acts = sae.encode(batch.to(device=sae.device, dtype=sae.dtype)).cpu() 
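        # Only sae.encode() is needed here, so duck-typed custom SAEs work as well;
        # per-batch activations are moved to the CPU to limit GPU memory use.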
batch_acts.append(acts) return torch.cat(batch_acts) @@ -114,10 +107,8 @@ def train_k_sparse_probes( results: dict[int, dict[int, KSparseProbe]] = defaultdict(dict) with torch.no_grad(): labels = {label for _, label in train_labels} - sparse_train_y = torch.nn.functional.one_hot( - torch.tensor([idx for _, idx in train_labels]) - ) - sae_feat_acts = _get_sae_acts(sae, train_activations, sae_post_act=True) + sparse_train_y = torch.nn.functional.one_hot(torch.tensor([idx for _, idx in train_labels])) + sae_feat_acts = _get_sae_acts(sae, train_activations) l1_probe = ( train_sparse_multi_probe( sae_feat_acts.to(sae.device), @@ -138,9 +129,9 @@ def train_k_sparse_probes( sparse_feat_ids = l1_probe.weights[label].topk(k).indices.numpy() train_k_x = sae_feat_acts[:, sparse_feat_ids].float().numpy() # Use SKLearn here because it's much faster than torch if the data is small - sk_probe = LogisticRegression( - max_iter=500, class_weight="balanced" - ).fit(train_k_x, (train_k_y == label).astype(np.int64)) + sk_probe = LogisticRegression(max_iter=500, class_weight="balanced").fit( + train_k_x, (train_k_y == label).astype(np.int64) + ) probe = KSparseProbe( weight=torch.tensor(sk_probe.coef_[0]).float(), bias=torch.tensor(sk_probe.intercept_[0]).float(), @@ -163,18 +154,12 @@ def sae_k_sparse_metadata( norm_W_enc = sae.W_enc / torch.norm(sae.W_enc, dim=0, keepdim=True) norm_W_dec = sae.W_dec / torch.norm(sae.W_dec, dim=-1, keepdim=True) probe_dec_cos = ( - ( - norm_probe_weights.to(dtype=norm_W_dec.dtype, device=norm_W_dec.device) - @ norm_W_dec.T - ) + (norm_probe_weights.to(dtype=norm_W_dec.dtype, device=norm_W_dec.device) @ norm_W_dec.T) .cpu() .float() ) probe_enc_cos = ( - ( - norm_probe_weights.to(dtype=norm_W_enc.dtype, device=norm_W_enc.device) - @ norm_W_enc - ) + (norm_probe_weights.to(dtype=norm_W_enc.dtype, device=norm_W_enc.device) @ norm_W_enc) .cpu() .float() ) @@ -191,12 +176,8 @@ def sae_k_sparse_metadata( row["letter"] = letter row["k"] = k row["feats"] = k_probe.feature_ids.numpy() - row["cos_probe_sae_enc"] = probe_enc_cos[ - letter_i, k_probe.feature_ids - ].numpy() - row["cos_probe_sae_dec"] = probe_dec_cos[ - letter_i, k_probe.feature_ids - ].numpy() + row["cos_probe_sae_enc"] = probe_enc_cos[letter_i, k_probe.feature_ids].numpy() + row["cos_probe_sae_dec"] = probe_dec_cos[letter_i, k_probe.feature_ids].numpy() row["weights"] = k_probe.weight.float().numpy() row["bias"] = k_probe.bias.item() rows.append(row) @@ -224,15 +205,9 @@ def row_generator(): "answer_letter": LETTERS[answer_idx], } sae_acts = ( - _get_sae_acts( - sae, token_act.unsqueeze(0).to(sae.device), sae_post_act=True - ) - .float() - .cpu() + _get_sae_acts(sae, token_act.unsqueeze(0).to(sae.device)).float().cpu() ).squeeze() - for letter_i, (letter, probe_score) in enumerate( - zip(LETTERS, probe_scores) - ): + for letter_i, (letter, probe_score) in enumerate(zip(LETTERS, probe_scores)): row[f"score_probe_{letter}"] = probe_score for k, k_probes in k_sparse_probes.items(): k_probe = k_probes[letter_i] @@ -358,9 +333,7 @@ def build_metrics_df(results_df, metadata_df, max_k_value: int): auc_info[f"recall_sum_sparse_sae_{k}"] = recall_sum_sae auc_info[f"precision_sum_sparse_sae_{k}"] = precision_sum_sae - meta_row = metadata_df[ - (metadata_df["letter"] == letter) & (metadata_df["k"] == k) - ] + meta_row = metadata_df[(metadata_df["letter"] == letter) & (metadata_df["k"] == k)] auc_info[f"sparse_sae_k_{k}_feats"] = meta_row["feats"].iloc[0] auc_info[f"cos_probe_sae_enc_k_{k}"] = meta_row["cos_probe_sae_enc"].iloc[0] 
auc_info[f"cos_probe_sae_dec_k_{k}"] = meta_row["cos_probe_sae_dec"].iloc[0] @@ -392,9 +365,7 @@ def add_feature_splits_to_metrics_df( split_feats_by_letter[letter] = k_feats else: break - df["split_feats"] = df["letter"].apply( - lambda letter: split_feats_by_letter.get(letter, []) - ) + df["split_feats"] = df["letter"].apply(lambda letter: split_feats_by_letter.get(letter, [])) df["num_split_features"] = df["split_feats"].apply(len) - 1 @@ -426,15 +397,9 @@ def run_k_sparse_probing_experiment( verbose: bool = True, ) -> pd.DataFrame: task_output_dir = get_or_make_dir(experiment_dir) / sae_name - raw_results_path = task_output_dir / get_sparse_probing_raw_results_filename( - sae_name, layer - ) - metadata_results_path = task_output_dir / get_sparse_probing_metadata_filename( - sae_name, layer - ) - metrics_results_path = task_output_dir / get_sparse_probing_metrics_filename( - sae_name, layer - ) + raw_results_path = task_output_dir / get_sparse_probing_raw_results_filename(sae_name, layer) + metadata_results_path = task_output_dir / get_sparse_probing_metadata_filename(sae_name, layer) + metrics_results_path = task_output_dir / get_sparse_probing_metrics_filename(sae_name, layer) def get_raw_results_df(): return load_dfs_or_run( diff --git a/evals/absorption/main.py b/evals/absorption/main.py index 0a5e6bf..a83adb6 100644 --- a/evals/absorption/main.py +++ b/evals/absorption/main.py @@ -5,7 +5,6 @@ import torch from tqdm import tqdm import pandas as pd -from sae_lens.sae import TopK from evals.absorption.eval_config import AbsorptionEvalConfig from evals.absorption.eval_output import ( @@ -34,19 +33,28 @@ def run_eval( config: AbsorptionEvalConfig, - selected_saes_dict: dict[str, list[str]], + selected_saes_dict: dict[str, list[str] | SAE], device: str, output_path: str, force_rerun: bool = False, ): + """ + selected_saes_dict is a dict mapping either: + - Release name -> list of SAE IDs to load from that release + - Custom name -> Single SAE object + """ eval_instance_id = get_eval_uuid() sae_lens_version = get_sae_lens_version() sae_bench_commit_hash = get_sae_bench_version() results_dict = {} - llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] - llm_dtype = activation_collection.LLM_NAME_TO_DTYPE[config.model_name] + if config.llm_dtype == "bfloat16": + llm_dtype = torch.bfloat16 + elif config.llm_dtype == "float32": + llm_dtype = torch.float32 + else: + raise ValueError(f"Invalid dtype: {config.llm_dtype}") model = HookedTransformer.from_pretrained_no_processing( config.model_name, device=device, dtype=llm_dtype @@ -57,6 +65,10 @@ def run_eval( f"Running evaluation for SAE release: {sae_release}, SAEs: {selected_saes_dict[sae_release]}" ) + # Wrap single SAE objects in a list to unify processing of both pretrained and custom SAEs + if not isinstance(selected_saes_dict[sae_release], list): + selected_saes_dict[sae_release] = [selected_saes_dict[sae_release]] + for sae_id in tqdm( selected_saes_dict[sae_release], desc="Running SAE evaluation on all selected SAEs", @@ -64,13 +76,18 @@ def run_eval( gc.collect() torch.cuda.empty_cache() - sae = SAE.from_pretrained( - release=sae_release, - sae_id=sae_id, - device=device, - )[0] + # Handle both pretrained SAEs (identified by string) and custom SAEs (passed as objects) + if isinstance(sae_id, str): + sae = SAE.from_pretrained( + release=sae_release, + sae_id=sae_id, + device=device, + )[0] + else: + sae = sae_id + sae_id = "custom_sae" + sae = sae.to(device=device, dtype=llm_dtype) - sae = _fix_topk(sae, 
sae_id, sae_release) k_sparse_probing_results = run_k_sparse_probing_experiment( model=model, @@ -90,13 +107,9 @@ def run_eval( os.makedirs(artifacts_folder, exist_ok=True) k_sparse_probing_file = f"{sae_release}_{sae_id}_k_sparse_probing.json" k_sparse_probing_file = k_sparse_probing_file.replace("/", "_") - k_sparse_probing_path = os.path.join( - artifacts_folder, k_sparse_probing_file - ) + k_sparse_probing_path = os.path.join(artifacts_folder, k_sparse_probing_file) os.makedirs(os.path.dirname(k_sparse_probing_path), exist_ok=True) - k_sparse_probing_results.to_json( - k_sparse_probing_path, orient="records", indent=4 - ) + k_sparse_probing_results.to_json(k_sparse_probing_path, orient="records", indent=4) raw_df = run_feature_absortion_experiment( model=model, @@ -108,7 +121,7 @@ def run_eval( feature_split_f1_jump_threshold=config.f1_jump_threshold, prompt_template=config.prompt_template, prompt_token_pos=config.prompt_token_pos, - batch_size=llm_batch_size, + batch_size=config.llm_batch_size, device=device, ) agg_df = _aggregate_results_df(raw_df) @@ -185,33 +198,14 @@ def _aggregate_results_df( ) agg_df["num_split_feats"] = agg_df["split_feats"].apply(len) agg_df["num_absorption"] = agg_df["is_absorption"] - agg_df["absorption_rate"] = ( - agg_df["num_absorption"] / agg_df["num_probe_true_positives"] - ) + agg_df["absorption_rate"] = agg_df["num_absorption"] / agg_df["num_probe_true_positives"] return agg_df -def _fix_topk( - sae: SAE, - sae_name: str, - sae_release: str, -): - if "topk" in sae_name: - if isinstance(sae.activation_fn, TopK): - return sae - - sae = formatting_utils.fix_topk_saes(sae, sae_release, sae_name, data_dir="../") - - assert isinstance(sae.activation_fn, TopK) - return sae - - def arg_parser(): parser = argparse.ArgumentParser(description="Run absorption evaluation") parser.add_argument("--random_seed", type=int, default=42, help="Random seed") - parser.add_argument( - "--f1_jump_threshold", type=float, default=0.03, help="F1 jump threshold" - ) + parser.add_argument("--f1_jump_threshold", type=float, default=0.03, help="F1 jump threshold") parser.add_argument("--max_k_value", type=int, default=10, help="Maximum k value") parser.add_argument( "--prompt_template", @@ -219,12 +213,8 @@ def arg_parser(): default="{word} has the first letter:", help="Prompt template", ) - parser.add_argument( - "--prompt_token_pos", type=int, default=-6, help="Prompt token position" - ) - parser.add_argument( - "--model_name", type=str, default="pythia-70m-deduped", help="Model name" - ) + parser.add_argument("--prompt_token_pos", type=int, default=-6, help="Prompt token position") + parser.add_argument("--model_name", type=str, default="pythia-70m-deduped", help="Model name") parser.add_argument( "--sae_regex_pattern", type=str, @@ -243,9 +233,7 @@ def arg_parser(): default="evals/absorption/results", help="Output folder", ) - parser.add_argument( - "--force_rerun", action="store_true", help="Force rerun of experiments" - ) + parser.add_argument("--force_rerun", action="store_true", help="Force rerun of experiments") return parser @@ -272,9 +260,7 @@ def create_config_and_selected_saes(args): model_name=args.model_name, ) - selected_saes_dict = get_saes_from_regex( - args.sae_regex_pattern, args.sae_block_pattern - ) + selected_saes_dict = get_saes_from_regex(args.sae_regex_pattern, args.sae_block_pattern) assert len(selected_saes_dict) > 0, "No SAEs selected" for release, saes in selected_saes_dict.items(): @@ -289,8 +275,7 @@ def create_config_and_selected_saes(args): python 
evals/absorption/main.py \ --sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \ --sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \ - --model_name pythia-70m-deduped \ - --output_folder results + --model_name pythia-70m-deduped """ args = arg_parser().parse_args() device = setup_environment() @@ -299,6 +284,11 @@ def create_config_and_selected_saes(args): config, selected_saes_dict = create_config_and_selected_saes(args) + config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] + config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[ + -1 + ] + # create output folder os.makedirs(args.output_folder, exist_ok=True) @@ -310,3 +300,67 @@ def create_config_and_selected_saes(args): end_time = time.time() print(f"Finished evaluation in {end_time - start_time:.2f} seconds") + + +# # Use this code snippet to use custom SAE objects +# if __name__ == "__main__": +# import baselines.identity_sae as identity_sae +# import baselines.jumprelu_sae as jumprelu_sae + +# """ +# python evals/absorption/main.py +# """ +# device = setup_environment() + +# start_time = time.time() + +# random_seed = 42 +# output_folder = "evals/absorption/results" + +# baseline_type = "identity_sae" +# # baseline_type = "jumprelu_sae" + +# model_name = "pythia-70m-deduped" +# hook_layer = 4 +# d_model = 512 + +# # model_name = "gemma-2-2b" +# # hook_layer = 19 +# # d_model = 2304 + +# if baseline_type == "identity_sae": +# sae = identity_sae.IdentitySAE(model_name, d_model=d_model, hook_layer=hook_layer) +# selected_saes_dict = {f"{model_name}_layer_{hook_layer}_identity_sae": sae} +# elif baseline_type == "jumprelu_sae": +# repo_id = "google/gemma-scope-2b-pt-res" +# filename = "layer_20/width_16k/average_l0_71/params.npz" +# sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, 20) +# selected_saes_dict = {f"{repo_id}_{filename}_gemmascope_sae": sae} +# else: +# raise ValueError(f"Invalid baseline type: {baseline_type}") + +# config = AbsorptionEvalConfig( +# random_seed=random_seed, +# model_name=model_name, +# ) + +# config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] +# config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[ +# -1 +# ] + +# # create output folder +# os.makedirs(output_folder, exist_ok=True) + +# # run the evaluation on all selected SAEs +# results_dict = run_eval( +# config, +# selected_saes_dict, +# device, +# output_folder, +# force_rerun=True, +# ) + +# end_time = time.time() + +# print(f"Finished evaluation in {end_time - start_time} seconds") diff --git a/evals/autointerp/eval_config.py b/evals/autointerp/eval_config.py index 0487d20..34a3029 100644 --- a/evals/autointerp/eval_config.py +++ b/evals/autointerp/eval_config.py @@ -65,6 +65,11 @@ class AutoInterpEvalConfig: title="Dataset Name", description="The name of the dataset to use", ) + llm_context_size: int = Field( + default=128, + title="LLM Context Size", + description="The context size to use for the LLM", + ) llm_batch_size: int = Field( default=512, title="LLM Batch Size", diff --git a/evals/autointerp/main.py b/evals/autointerp/main.py index 61057e0..58dccdc 100644 --- a/evals/autointerp/main.py +++ b/evals/autointerp/main.py @@ -380,7 +380,7 @@ def gather_data(self) -> tuple[dict[int, Examples], dict[int, Examples]]: generation_examples = {} scoring_examples = {} - for i, latent in enumerate(self.latents): + for i, latent in tqdm(enumerate(self.latents), 
desc="Collecting examples for LLM judge"): # (1/3) Get random examples (we don't need their values) rand_indices = torch.stack( [ @@ -471,17 +471,21 @@ def run_eval_single_sae( torch.manual_seed(config.random_seed) torch.set_grad_enabled(False) - tokens_filename = f"{config.total_tokens}_tokens_{sae.cfg.context_size}_ctx.pt" + os.makedirs(artifacts_folder, exist_ok=True) + + tokens_filename = f"{config.total_tokens}_tokens_{config.llm_context_size}_ctx.pt" tokens_path = os.path.join(artifacts_folder, tokens_filename) if os.path.exists(tokens_path): tokenized_dataset = torch.load(tokens_path).to(device) else: tokenized_dataset = dataset_utils.load_and_tokenize_dataset( - config.dataset_name, sae.cfg.context_size, config.total_tokens, model.tokenizer + config.dataset_name, config.llm_context_size, config.total_tokens, model.tokenizer ).to(device) torch.save(tokenized_dataset, tokens_path) + print(f"Loaded tokenized dataset of shape {tokenized_dataset.shape}") + if sae_sparsity is None: sae_sparsity = activation_collection.get_feature_activation_sparsity( tokenized_dataset, @@ -508,7 +512,7 @@ def run_eval_single_sae( def run_eval( config: AutoInterpEvalConfig, - selected_saes_dict: dict[str, list[str]], + selected_saes_dict: dict[str, list[str] | SAE], device: str, api_key: str, output_path: str, @@ -516,10 +520,9 @@ def run_eval( save_logs_path: Optional[str] = None, ) -> dict[str, Any]: """ - Runs autointerp eval. Returns results as a dict with the following structure: - - custom_eval_config - dict of config parameters used for this evaluation - custom_eval_results - nested dict of {sae_name: {"score": score}} + selected_saes_dict is a dict mapping either: + - Release name -> list of SAE IDs to load from that release + - Custom name -> Single SAE object """ eval_instance_id = get_eval_uuid() sae_lens_version = get_sae_lens_version() @@ -546,17 +549,32 @@ def run_eval( f"Running evaluation for SAE release: {sae_release}, SAEs: {selected_saes_dict[sae_release]}" ) + # Wrap single SAE objects in a list to unify processing of both pretrained and custom SAEs + if not isinstance(selected_saes_dict[sae_release], list): + selected_saes_dict[sae_release] = [selected_saes_dict[sae_release]] + for sae_id in tqdm( selected_saes_dict[sae_release], desc="Running SAE evaluation on all selected SAEs", ): gc.collect() torch.cuda.empty_cache() - sae, _, sparsity = SAE.from_pretrained(sae_release, sae_id, device=str(device)) + + # Handle both pretrained SAEs (identified by string) and custom SAEs (passed as objects) + if isinstance(sae_id, str): + sae, _, sparsity = SAE.from_pretrained( + release=sae_release, + sae_id=sae_id, + device=device, + ) + else: + sae = sae_id + sae_id = "custom_sae" + sparsity = None + sae = sae.to(device=device, dtype=llm_dtype) artifacts_folder = os.path.join(artifacts_base_folder, EVAL_TYPE_ID_AUTOINTERP) - os.makedirs(artifacts_folder, exist_ok=True) sae_result_file = f"{sae_release}_{sae_id}_eval_results.json" sae_result_file = sae_result_file.replace("/", "_") @@ -599,7 +617,6 @@ def run_eval( # Put important results into the results dict score = sum([r["score"] for r in sae_eval_result.values()]) / len(sae_eval_result) - eval_result_metrics = {"autointerp_metrics": {"autointerp_score": score}} eval_output = AutoInterpEvalOutput( eval_config=config, @@ -694,36 +711,21 @@ def arg_parser(): --sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \ --model_name pythia-70m-deduped \ --api_key - + python evals/autointerp/main.py \ --sae_regex_pattern "gemma-scope-2b-pt-res" \ 
--sae_block_pattern "layer_20/width_16k/average_l0_139" \ --model_name gemma-2-2b \ --api_key - + """ args = arg_parser().parse_args() device = setup_environment() start_time = time.time() - sae_regex_patterns = [ - r"(sae_bench_pythia70m_sweep_topk_ctx128_0730).*", - r"(sae_bench_pythia70m_sweep_standard_ctx128_0712).*", - ] - sae_block_pattern = [ - r".*blocks\.([4])\.hook_resid_post__trainer_(2|6|10|14)$", - r".*blocks\.([4])\.hook_resid_post__trainer_(2|6|10|14)$", - ] - - sae_regex_patterns = None - sae_block_pattern = None - config, selected_saes_dict = create_config_and_selected_saes(args) - if sae_regex_patterns is not None: - selected_saes_dict = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern) - print(selected_saes_dict) config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] @@ -747,3 +749,73 @@ def arg_parser(): end_time = time.time() print(f"Finished evaluation in {end_time - start_time} seconds") + + +# Use this code snippet to use custom SAE objects +# if __name__ == "__main__": +# """ +# python evals/autointerp/main.py +# NOTE: We don't use argparse here. This requires a file openai_api_key.txt to be present in the root directory. +# """ + +# import baselines.identity_sae as identity_sae +# import baselines.jumprelu_sae as jumprelu_sae + +# device = setup_environment() + +# start_time = time.time() + +# random_seed = 42 +# output_folder = "evals/autointerp/results" + +# with open("openai_api_key.txt", "r") as f: +# api_key = f.read().strip() + +# baseline_type = "identity_sae" +# # baseline_type = "jumprelu_sae" + +# model_name = "pythia-70m-deduped" +# hook_layer = 4 +# d_model = 512 + +# # model_name = "gemma-2-2b" +# # hook_layer = 19 +# # d_model = 2304 + +# if baseline_type == "identity_sae": +# sae = identity_sae.IdentitySAE(model_name, d_model=d_model, hook_layer=hook_layer) +# selected_saes_dict = {f"{model_name}_layer_{hook_layer}_identity_sae": sae} +# elif baseline_type == "jumprelu_sae": +# repo_id = "google/gemma-scope-2b-pt-res" +# filename = "layer_20/width_16k/average_l0_71/params.npz" +# sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, 20) +# selected_saes_dict = {f"{repo_id}_{filename}_gemmascope_sae": sae} +# else: +# raise ValueError(f"Invalid baseline type: {baseline_type}") + +# config = AutoInterpEvalConfig( +# random_seed=random_seed, +# model_name=model_name, +# ) + +# config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] +# config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[ +# -1 +# ] + +# # create output folder +# os.makedirs(output_folder, exist_ok=True) + +# # run the evaluation on all selected SAEs +# results_dict = run_eval( +# config, +# selected_saes_dict, +# device, +# api_key, +# output_folder, +# force_rerun=True, +# ) + +# end_time = time.time() + +# print(f"Finished evaluation in {end_time - start_time} seconds") diff --git a/evals/shift_and_tpp/main.py b/evals/shift_and_tpp/main.py index 5568edb..6978868 100644 --- a/evals/shift_and_tpp/main.py +++ b/evals/shift_and_tpp/main.py @@ -611,8 +611,6 @@ def run_eval_single_sae( config: ShiftAndTppEvalConfig, sae: SAE, model: HookedTransformer, - layer: int, - hook_point: str, device: str, artifacts_folder: str, save_activations: bool = True, @@ -625,6 +623,8 @@ def run_eval_single_sae( random.seed(config.random_seed) torch.manual_seed(config.random_seed) + os.makedirs(artifacts_folder, exist_ok=True) + dataset_results = {} averaging_names = [] @@ -639,8 
+639,8 @@ def run_eval_single_sae( config, sae, model, - layer, - hook_point, + sae.cfg.hook_layer, + sae.cfg.hook_name, device, artifacts_folder, save_activations, @@ -660,8 +660,8 @@ def run_eval_single_sae( config, sae, model, - layer, - hook_point, + sae.cfg.hook_layer, + sae.cfg.hook_name, device, artifacts_folder, save_activations, @@ -680,14 +680,19 @@ def run_eval_single_sae( def run_eval( config: ShiftAndTppEvalConfig, - selected_saes_dict: dict[str, list[str]], + selected_saes_dict: dict[str, list[str] | SAE], device: str, output_path: str, force_rerun: bool = False, clean_up_activations: bool = False, ): - """By default, clean_up_activations is True, which means that the activations are deleted after the evaluation is done. - This is because activations for all datasets can easily be 10s of GBs. + """ + selected_saes_dict is a dict mapping either: + - Release name -> list of SAE IDs to load from that release + - Custom name -> Single SAE object + + If clean_up_activations is True, which means that the activations are deleted after the evaluation is done. + You may want to use this because activations for all datasets can easily be 10s of GBs. Return dict is a dict of SAE name: evaluation results for that SAE.""" eval_instance_id = get_eval_uuid() sae_lens_version = get_sae_lens_version() @@ -720,6 +725,10 @@ def run_eval( f"Running evaluation for SAE release: {sae_release}, SAEs: {selected_saes_dict[sae_release]}" ) + # Wrap single SAE objects in a list to unify processing of both pretrained and custom SAEs + if not isinstance(selected_saes_dict[sae_release], list): + selected_saes_dict[sae_release] = [selected_saes_dict[sae_release]] + for sae_id in tqdm( selected_saes_dict[sae_release], desc="Running SAE evaluation on all selected SAEs", @@ -727,17 +736,22 @@ def run_eval( gc.collect() torch.cuda.empty_cache() - sae = SAE.from_pretrained( - release=sae_release, - sae_id=sae_id, - device=device, - )[0] + # Handle both pretrained SAEs (identified by string) and custom SAEs (passed as objects) + if isinstance(sae_id, str): + sae = SAE.from_pretrained( + release=sae_release, + sae_id=sae_id, + device=device, + )[0] + else: + sae = sae_id + sae_id = "custom_sae" + sae = sae.to(device=device, dtype=llm_dtype) artifacts_folder = os.path.join( artifacts_base_folder, eval_type, config.model_name, sae.cfg.hook_name ) - os.makedirs(artifacts_folder, exist_ok=True) sae_result_file = f"{sae_release}_{sae_id}_eval_results.json" sae_result_file = sae_result_file.replace("/", "_") @@ -757,8 +771,6 @@ def run_eval( config, sae, model, - sae.cfg.hook_layer, - sae.cfg.hook_name, device, artifacts_folder, ) @@ -934,35 +946,8 @@ def str_to_bool(value): start_time = time.time() - sae_regex_patterns = [ - r"(sae_bench_pythia70m_sweep_topk_ctx128_0730).*", - r"(sae_bench_pythia70m_sweep_standard_ctx128_0712).*", - ] - sae_block_pattern = [ - r".*blocks\.([4])\.hook_resid_post__trainer_(2|6|10|14)$", - r".*blocks\.([4])\.hook_resid_post__trainer_(2|6|10|14)$", - ] - - # For Gemma-2-2b - sae_regex_patterns = [ - r"sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824", - r"sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824", - r"(gemma-scope-2b-pt-res)", - ] - sae_block_pattern = [ - r".*blocks\.19(?!.*step).*", - r".*blocks\.19(?!.*step).*", - r".*layer_(19).*(16k).*", - ] - - sae_regex_patterns = None - sae_block_pattern = None - config, selected_saes_dict = create_config_and_selected_saes(args) - if sae_regex_patterns is not None: - selected_saes_dict = select_saes_multiple_patterns(sae_regex_patterns, 
sae_block_pattern) - print(selected_saes_dict) config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] @@ -986,3 +971,70 @@ def str_to_bool(value): end_time = time.time() print(f"Finished evaluation in {end_time - start_time} seconds") + + +# Use this code snippet to use custom SAE objects +# if __name__ == "__main__": +# import baselines.identity_sae as identity_sae +# import baselines.jumprelu_sae as jumprelu_sae + +# """ +# python evals/shift_and_tpp/main.py +# """ +# device = setup_environment() + +# start_time = time.time() + +# random_seed = 42 +# output_folder = "evals/shift_and_tpp/results" +# perform_scr = True + +# baseline_type = "identity_sae" +# # baseline_type = "jumprelu_sae" + +# model_name = "pythia-70m-deduped" +# hook_layer = 4 +# d_model = 512 + +# # model_name = "gemma-2-2b" +# # hook_layer = 19 +# # d_model = 2304 + +# if baseline_type == "identity_sae": +# sae = identity_sae.IdentitySAE(model_name, d_model=d_model, hook_layer=hook_layer) +# selected_saes_dict = {f"{model_name}_layer_{hook_layer}_identity_sae": sae} +# elif baseline_type == "jumprelu_sae": +# repo_id = "google/gemma-scope-2b-pt-res" +# filename = "layer_20/width_16k/average_l0_71/params.npz" +# sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, 20) +# selected_saes_dict = {f"{repo_id}_{filename}_gemmascope_sae": sae} +# else: +# raise ValueError(f"Invalid baseline type: {baseline_type}") + +# config = ShiftAndTppEvalConfig( +# random_seed=random_seed, +# model_name=model_name, +# perform_scr=perform_scr, +# ) + +# config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] +# config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[ +# -1 +# ] + +# # create output folder +# os.makedirs(output_folder, exist_ok=True) + +# # run the evaluation on all selected SAEs +# results_dict = run_eval( +# config, +# selected_saes_dict, +# device, +# output_folder, +# force_rerun=True, +# clean_up_activations=False, +# ) + +# end_time = time.time() + +# print(f"Finished evaluation in {end_time - start_time} seconds") diff --git a/evals/sparse_probing/main.py b/evals/sparse_probing/main.py index fc962d7..7c4f838 100644 --- a/evals/sparse_probing/main.py +++ b/evals/sparse_probing/main.py @@ -194,8 +194,6 @@ def run_eval_single_sae( config: SparseProbingEvalConfig, sae: SAE, model: HookedTransformer, - layer: int, - hook_point: str, device: str, artifacts_folder: str, save_activations: bool = True, @@ -207,6 +205,7 @@ def run_eval_single_sae( random.seed(config.random_seed) torch.manual_seed(config.random_seed) + os.makedirs(artifacts_folder, exist_ok=True) results_dict = {} @@ -217,8 +216,8 @@ def run_eval_single_sae( config, sae, model, - layer, - hook_point, + sae.cfg.hook_layer, + sae.cfg.hook_name, device, artifacts_folder, save_activations, @@ -236,14 +235,19 @@ def run_eval_single_sae( def run_eval( config: SparseProbingEvalConfig, - selected_saes_dict: dict[str, list[str]], + selected_saes_dict: dict[str, list[str] | SAE], device: str, output_path: str, force_rerun: bool = False, clean_up_activations: bool = False, ): - """By default, clean_up_activations is True, which means that the activations are deleted after the evaluation is done. - This is because activations for all datasets can easily be 10s of GBs. 
+ """ + selected_saes_dict is a dict mapping either: + - Release name -> list of SAE IDs to load from that release + - Custom name -> Single SAE object + + If clean_up_activations is True, which means that the activations are deleted after the evaluation is done. + You may want to use this because activations for all datasets can easily be 10s of GBs. Return dict is a dict of SAE name: evaluation results for that SAE.""" eval_instance_id = get_eval_uuid() sae_lens_version = get_sae_lens_version() @@ -270,6 +274,10 @@ def run_eval( f"Running evaluation for SAE release: {sae_release}, SAEs: {selected_saes_dict[sae_release]}" ) + # Wrap single SAE objects in a list to unify processing of both pretrained and custom SAEs + if not isinstance(selected_saes_dict[sae_release], list): + selected_saes_dict[sae_release] = [selected_saes_dict[sae_release]] + for sae_id in tqdm( selected_saes_dict[sae_release], desc="Running SAE evaluation on all selected SAEs", @@ -277,11 +285,17 @@ def run_eval( gc.collect() torch.cuda.empty_cache() - sae = SAE.from_pretrained( - release=sae_release, - sae_id=sae_id, - device=device, - )[0] + # Handle both pretrained SAEs (identified by string) and custom SAEs (passed as objects) + if isinstance(sae_id, str): + sae = SAE.from_pretrained( + release=sae_release, + sae_id=sae_id, + device=device, + )[0] + else: + sae = sae_id + sae_id = "custom_sae" + sae = sae.to(device=device, dtype=llm_dtype) artifacts_folder = os.path.join( @@ -290,7 +304,6 @@ def run_eval( config.model_name, sae.cfg.hook_name, ) - os.makedirs(artifacts_folder, exist_ok=True) sae_result_file = f"{sae_release}_{sae_id}_eval_results.json" sae_result_file = sae_result_file.replace("/", "_") @@ -305,8 +318,6 @@ def run_eval( config, sae, model, - sae.cfg.hook_layer, - sae.cfg.hook_name, device, artifacts_folder, ) @@ -387,7 +398,7 @@ def create_config_and_selected_saes( def arg_parser(): parser = argparse.ArgumentParser(description="Run sparse probing evaluation") parser.add_argument("--random_seed", type=int, default=42, help="Random seed") - parser.add_argument("--model_name", type=str, default="pythia-70m-deduped", help="Model name") + parser.add_argument("--model_name", type=str, help="Model name") parser.add_argument( "--sae_regex_pattern", type=str, @@ -421,32 +432,17 @@ def arg_parser(): python evals/sparse_probing/main.py \ --sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \ --sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \ - --model_name pythia-70m-deduped - - + --model_name pythia-70m-deduped + + """ args = arg_parser().parse_args() device = setup_environment() start_time = time.time() - sae_regex_patterns = [ - r"(sae_bench_pythia70m_sweep_topk_ctx128_0730).*", - r"(sae_bench_pythia70m_sweep_standard_ctx128_0712).*", - ] - sae_block_pattern = [ - r".*blocks\.([4])\.hook_resid_post__trainer_(2|6|10|14)$", - r".*blocks\.([4])\.hook_resid_post__trainer_(2|6|10|14)$", - ] - - sae_regex_patterns = None - sae_block_pattern = None - config, selected_saes_dict = create_config_and_selected_saes(args) - if sae_regex_patterns is not None: - selected_saes_dict = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern) - print(selected_saes_dict) config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] @@ -470,3 +466,67 @@ def arg_parser(): end_time = time.time() print(f"Finished evaluation in {end_time - start_time} seconds") + + +# Use this code snippet to use custom SAE objects +# if __name__ == "__main__": +# import 
baselines.identity_sae as identity_sae +# import baselines.jumprelu_sae as jumprelu_sae +# """ +# python evals/sparse_probing/main.py +# """ +# device = setup_environment() + +# start_time = time.time() + +# random_seed = 42 +# output_folder = "evals/sparse_probing/results" + +# baseline_type = "identity_sae" +# baseline_type = "jumprelu_sae" + +# model_name = "pythia-70m-deduped" +# hook_layer = 4 +# d_model = 512 + +# model_name = "gemma-2-2b" +# hook_layer = 19 +# d_model = 2304 + +# if baseline_type == "identity_sae": +# sae = identity_sae.IdentitySAE(model_name, d_model=d_model, hook_layer=hook_layer) +# selected_saes_dict = {f"{model_name}_layer_{hook_layer}_identity_sae": sae} +# elif baseline_type == "jumprelu_sae": +# repo_id = "google/gemma-scope-2b-pt-res" +# filename = "layer_20/width_16k/average_l0_71/params.npz" +# sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, 20) +# selected_saes_dict = {f"{repo_id}_{filename}_gemmascope_sae": sae} +# else: +# raise ValueError(f"Invalid baseline type: {baseline_type}") + +# config = SparseProbingEvalConfig( +# random_seed=random_seed, +# model_name=model_name, +# ) + +# config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] +# config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[ +# -1 +# ] + +# # create output folder +# os.makedirs(output_folder, exist_ok=True) + +# # run the evaluation on all selected SAEs +# results_dict = run_eval( +# config, +# selected_saes_dict, +# device, +# output_folder, +# force_rerun=True, +# clean_up_activations=False, +# ) + +# end_time = time.time() + +# print(f"Finished evaluation in {end_time - start_time} seconds") diff --git a/evals/unlearning/main.py b/evals/unlearning/main.py index 3a586a4..3ce0ecc 100644 --- a/evals/unlearning/main.py +++ b/evals/unlearning/main.py @@ -16,7 +16,11 @@ from datetime import datetime from transformer_lens import HookedTransformer from sae_lens import SAE -from evals.unlearning.eval_output import UnlearningEvalOutput, UnlearningMetricCategories, UnlearningMetrics +from evals.unlearning.eval_output import ( + UnlearningEvalOutput, + UnlearningMetricCategories, + UnlearningMetrics, +) from evals.unlearning.utils.eval import run_eval_single_sae import sae_bench_utils.activation_collection as activation_collection from evals.unlearning.eval_config import UnlearningEvalConfig @@ -104,12 +108,17 @@ def convert_ndarrays_to_lists(obj): def run_eval( config: UnlearningEvalConfig, - selected_saes_dict: dict[str, list[str]], + selected_saes_dict: dict[str, list[str] | SAE], device: str, output_path: str, force_rerun: bool = False, clean_up_artifacts: bool = False, ): + """ + selected_saes_dict is a dict mapping either: + - Release name -> list of SAE IDs to load from that release + - Custom name -> Single SAE object + """ eval_instance_id = get_eval_uuid() sae_lens_version = get_sae_lens_version() sae_bench_commit_hash = get_sae_bench_version() @@ -139,6 +148,10 @@ def run_eval( f"Running evaluation for SAE release: {sae_release}, SAEs: {selected_saes_dict[sae_release]}" ) + # Wrap single SAE objects in a list to unify processing of both pretrained and custom SAEs + if not isinstance(selected_saes_dict[sae_release], list): + selected_saes_dict[sae_release] = [selected_saes_dict[sae_release]] + for sae_id in tqdm( selected_saes_dict[sae_release], desc="Running SAE evaluation on all selected SAEs", @@ -146,11 +159,17 @@ def run_eval( gc.collect() torch.cuda.empty_cache() - sae, cfg_dict, sparsity = 
SAE.from_pretrained( - release=sae_release, - sae_id=sae_id, - device=device, - ) + # Handle both pretrained SAEs (identified by string) and custom SAEs (passed as objects) + if isinstance(sae_id, str): + sae = SAE.from_pretrained( + release=sae_release, + sae_id=sae_id, + device=device, + )[0] + else: + sae = sae_id + sae_id = "custom_sae" + sae = sae.to(device=device, dtype=llm_dtype) sae_release_and_id = f"{sae_release}_{sae_id}" @@ -158,7 +177,6 @@ def run_eval( sae_results_folder = os.path.join( artifacts_folder, sae_release_and_id, "results/metrics" ) - os.makedirs(artifacts_folder, exist_ok=True) sae_result_file = f"{sae_release}_{sae_id}_eval_results.json" sae_result_file = sae_result_file.replace("/", "_") @@ -181,7 +199,9 @@ def run_eval( eval_config=config, eval_id=eval_instance_id, datetime_epoch_millis=int(datetime.now().timestamp() * 1000), - eval_result_metrics=UnlearningMetricCategories(unlearning=UnlearningMetrics(unlearning_score=unlearning_score)), + eval_result_metrics=UnlearningMetricCategories( + unlearning=UnlearningMetrics(unlearning_score=unlearning_score) + ), eval_result_details=[], sae_bench_commit_hash=sae_bench_commit_hash, sae_lens_id=sae_id, @@ -283,26 +303,8 @@ def arg_parser(): start_time = time.time() - # For Gemma-2-2b - sae_regex_patterns = [ - r"sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824", - r"sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824", - r"(gemma-scope-2b-pt-res)", - ] - sae_block_pattern = [ - r".*blocks\.3(?!.*step).*", - r".*blocks\.3(?!.*step).*", - r".*layer_(3).*(16k).*", - ] - - sae_regex_patterns = None - sae_block_pattern = None - config, selected_saes_dict = create_config_and_selected_saes(args) - if sae_regex_patterns is not None: - selected_saes_dict = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern) - print(selected_saes_dict) config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[ @@ -325,3 +327,62 @@ def arg_parser(): end_time = time.time() print(f"Finished evaluation in {end_time - start_time} seconds") + +# Use this code snippet to use custom SAE objects +# if __name__ == "__main__": +# import baselines.identity_sae as identity_sae +# import baselines.jumprelu_sae as jumprelu_sae +# """ +# python evals/unlearning/main.py +# """ +# device = setup_environment() + +# start_time = time.time() + +# random_seed = 42 +# output_folder = "evals/unlearning/results" + +# baseline_type = "identity_sae" +# baseline_type = "jumprelu_sae" + +# model_name = "gemma-2-2b" +# hook_layer = 19 +# d_model = 2304 + +# if baseline_type == "identity_sae": +# sae = identity_sae.IdentitySAE(model_name, d_model=d_model, hook_layer=hook_layer) +# selected_saes_dict = {f"{model_name}_layer_{hook_layer}_identity_sae": sae} +# elif baseline_type == "jumprelu_sae": +# repo_id = "google/gemma-scope-2b-pt-res" +# filename = "layer_20/width_16k/average_l0_71/params.npz" +# sae = jumprelu_sae.load_jumprelu_sae(repo_id, filename, 20) +# selected_saes_dict = {f"{repo_id}_{filename}_gemmascope_sae": sae} +# else: +# raise ValueError(f"Invalid baseline type: {baseline_type}") + +# config = UnlearningEvalConfig( +# random_seed=random_seed, +# model_name=model_name, +# ) + +# config.llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name] +# config.llm_dtype = str(activation_collection.LLM_NAME_TO_DTYPE[config.model_name]).split(".")[ +# -1 +# ] + +# # create output folder +# os.makedirs(output_folder, exist_ok=True) + +# # run the evaluation on all selected SAEs +# results_dict = 
run_eval( +# config, +# selected_saes_dict, +# device, +# output_folder, +# force_rerun=True, +# clean_up_activations=False, +# ) + +# end_time = time.time() + +# print(f"Finished evaluation in {end_time - start_time} seconds") diff --git a/evals/unlearning/utils/eval.py b/evals/unlearning/utils/eval.py index 0e29dd7..64a83e2 100644 --- a/evals/unlearning/utils/eval.py +++ b/evals/unlearning/utils/eval.py @@ -1,5 +1,6 @@ import os import numpy as np +import torch from transformer_lens import HookedTransformer from sae_lens import SAE from evals.unlearning.utils.feature_activation import ( @@ -72,6 +73,11 @@ def run_eval_single_sae( force_rerun: bool, ): """sae_release_and_id: str is the name used when saving data for this SAE. This data will be reused at various points in the evaluation.""" + + os.makedirs(artifacts_folder, exist_ok=True) + + torch.set_grad_enabled(False) + # calculate feature sparsity save_feature_sparsity( model, diff --git a/sae_bench_utils/activation_collection.py b/sae_bench_utils/activation_collection.py index 0e36195..d2799ab 100644 --- a/sae_bench_utils/activation_collection.py +++ b/sae_bench_utils/activation_collection.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from tqdm import tqdm -from typing import Callable, Optional +from typing import Callable, Optional, Any from jaxtyping import Bool, Int, Float, jaxtyped from beartype import beartype import einops @@ -23,10 +23,11 @@ } -# @jaxtyped(typechecker=beartype) # TODO: jaxtyped struggles with the tokenizer +# beartype struggles with the tokenizer +@jaxtyped(typechecker=beartype) @torch.no_grad def get_bos_pad_eos_mask( - tokens: Int[torch.Tensor, "dataset_size seq_len"], tokenizer: AutoTokenizer + tokens: Int[torch.Tensor, "dataset_size seq_len"], tokenizer: AutoTokenizer | Any ) -> Bool[torch.Tensor, "dataset_size seq_len"]: mask = ( (tokens == tokenizer.pad_token_id) @@ -109,7 +110,7 @@ def get_all_llm_activations( def collect_sae_activations( tokens: Int[torch.Tensor, "dataset_size seq_len"], model: HookedTransformer, - sae: SAE, + sae: SAE | Any, batch_size: int, layer: int, hook_name: str, @@ -154,7 +155,7 @@ def collect_sae_activations( def get_feature_activation_sparsity( tokens: Int[torch.Tensor, "dataset_size seq_len"], model: HookedTransformer, - sae: SAE, + sae: SAE | Any, batch_size: int, layer: int, hook_name: str, @@ -221,7 +222,7 @@ def create_meaned_model_activations( @torch.no_grad def get_sae_meaned_activations( all_llm_activations_BLD: dict[str, Float[torch.Tensor, "batch_size seq_len d_model"]], - sae: SAE, + sae: SAE | Any, sae_batch_size: int, ) -> dict[str, Float[torch.Tensor, "batch_size d_sae"]]: """Encode LLM activations with an SAE and mean across the sequence length dimension for each class while ignoring padding tokens. 
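The `sae: SAE | Any` annotations in these helpers widen the beartype/jaxtyping checks so that duck-typed custom SAEs (such as the `baselines/` classes earlier in this diff) pass through. A minimal sketch of the interface the helpers actually rely on, written as an illustrative `typing.Protocol` (the repo itself simply uses `SAE | Any`; this Protocol is not part of the codebase):

```
from typing import Protocol

import torch


class SAELike(Protocol):
    """Illustrative only: the duck-typed surface used by the activation helpers."""

    device: torch.device
    dtype: torch.dtype

    def encode(self, acts: torch.Tensor) -> torch.Tensor:
        """Map (..., d_in) model activations to (..., d_sae) feature activations."""
        ...

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """Map (..., d_sae) feature activations back to (..., d_in)."""
        ...
```

Both `IdentitySAE` and `JumpReLUSAE` from `baselines/` satisfy this shape; evals that inspect decoder directions additionally read `W_dec`, and the eval entry points read the `cfg` fields (`hook_name`, `hook_layer`).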
@@ -302,7 +303,7 @@ def save_activations( @jaxtyped(typechecker=beartype) @torch.no_grad() def encode_precomputed_activations( - sae: SAE, + sae: SAE | Any, sae_batch_size: int, num_chunks: int, activation_dir: str, diff --git a/sae_bench_utils/dataset_utils.py b/sae_bench_utils/dataset_utils.py index 6cfecc6..7f2b3b1 100644 --- a/sae_bench_utils/dataset_utils.py +++ b/sae_bench_utils/dataset_utils.py @@ -24,6 +24,7 @@ def load_and_tokenize_dataset( total_token_count = 0 # Tokenize rows and accumulate tokens + pbar = tqdm(total=num_tokens, desc="Tokenizing dataset") for row in dataset: tokens = tokenizer(row["text"], truncation=True, max_length=ctx_len, return_tensors="pt")[ "input_ids" @@ -31,10 +32,12 @@ def load_and_tokenize_dataset( all_tokens.append(tokens) total_token_count += tokens.shape[0] + pbar.update(tokens.shape[0]) # Stop once we reach the target token count if total_token_count >= num_tokens: break + pbar.close() # Concatenate tokens into a single tensor concatenated_tensor = torch.cat(all_tokens) diff --git a/proposed_sae_selection_strategy.ipynb b/sae_regex_selection.ipynb similarity index 63% rename from proposed_sae_selection_strategy.ipynb rename to sae_regex_selection.ipynb index 76d4a65..a2c5700 100644 --- a/proposed_sae_selection_strategy.ipynb +++ b/sae_regex_selection.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -52,15 +52,21 @@ "│ gpt2-small │ gpt2-small-res_sle-ajt │ neuronpedia/gpt2-small__res_sle-ajt │ 3 │\n", "│ gpt2-small │ gpt2-small-res_sce-ajt │ neuronpedia/gpt2-small__res_sce-ajt │ 3 │\n", "│ gpt2-small │ gpt2-small-res_scefr-ajt │ neuronpedia/gpt2-small__res_scefr-ajt │ 3 │\n", + "│ meta-llama/Llama-3.1-8B │ llama_scope_lxa_32x │ fnlp/Llama3_1-8B-Base-LXA-32x │ 32 │\n", + "│ meta-llama/Llama-3.1-8B │ llama_scope_lxa_8x │ fnlp/Llama3_1-8B-Base-LXA-8x │ 32 │\n", + "│ meta-llama/Llama-3.1-8B │ llama_scope_lxm_32x │ fnlp/Llama3_1-8B-Base-LXM-32x │ 32 │\n", + "│ meta-llama/Llama-3.1-8B │ llama_scope_lxm_8x │ fnlp/Llama3_1-8B-Base-LXM-8x │ 32 │\n", + "│ meta-llama/Llama-3.1-8B │ llama_scope_lxr_32x │ fnlp/Llama3_1-8B-Base-LXR-32x │ 32 │\n", + "│ meta-llama/Llama-3.1-8B │ llama_scope_lxr_8x │ fnlp/Llama3_1-8B-Base-LXR-8x │ 32 │\n", "│ meta-llama/Meta-Llama-3-8B-Instruct │ llama-3-8b-it-res-jh │ Juliushanhanhan/llama-3-8b-it-res │ 1 │\n", "│ mistral-7b │ mistral-7b-res-wg │ JoshEngels/Mistral-7B-Residual-Stream-SAEs │ 3 │\n", - "│ pythia-70m │ sae_bench_pythia70m_sweep_gated_ctx128_0730 │ canrager/lm_sae │ 40 │\n", - "│ pythia-70m │ sae_bench_pythia70m_sweep_panneal_ctx128_0730 │ canrager/lm_sae │ 56 │\n", - "│ pythia-70m │ sae_bench_pythia70m_sweep_standard_ctx128_0712 │ canrager/lm_sae │ 44 │\n", - "│ pythia-70m │ sae_bench_pythia70m_sweep_topk_ctx128_0730 │ canrager/lm_sae │ 48 │\n", "│ pythia-70m-deduped │ pythia-70m-deduped-res-sm │ ctigges/pythia-70m-deduped__res-sm_processed │ 7 │\n", "│ pythia-70m-deduped │ pythia-70m-deduped-mlp-sm │ ctigges/pythia-70m-deduped__mlp-sm_processed │ 6 │\n", "│ pythia-70m-deduped │ pythia-70m-deduped-att-sm │ ctigges/pythia-70m-deduped__att-sm_processed │ 6 │\n", + "│ pythia-70m-deduped │ sae_bench_pythia70m_sweep_gated_ctx128_0730 │ canrager/lm_sae │ 40 │\n", + "│ pythia-70m-deduped │ sae_bench_pythia70m_sweep_panneal_ctx128_0730 │ canrager/lm_sae │ 56 │\n", + "│ pythia-70m-deduped │ sae_bench_pythia70m_sweep_standard_ctx128_0712 │ canrager/lm_sae │ 44 │\n", + "│ pythia-70m-deduped │ sae_bench_pythia70m_sweep_topk_ctx128_0730 │ 
canrager/lm_sae │ 48 │\n", "└─────────────────────────────────────┴─────────────────────────────────────────────────────┴────────────────────────────────────────────────────────┴──────────┘\n", "┌────────────────────────┬─────────────────────────────────────────────────────────────────────────┐\n", "│ Field │ Value │\n", @@ -100,46 +106,31 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 49/49 [00:00<00:00, 17211.36it/s]" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7b706b194dd44b55aea631b7363aca44", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/55 [00:00 0 test_preds = probe(x_test)[:, 0] > 0 @@ -113,13 +113,15 @@ def test_train_binary_probe_scores_highly_on_fully_separable_datasets(seed): assert test_acc > 0.98 sk_probe = LogisticRegression(max_iter=100, class_weight="balanced").fit( - x_train.numpy(), y_train.numpy() + x_train.cpu().numpy(), y_train.cpu().numpy() ) # since this is a synthetic dataset, we know the correct direction we should learn correct_dir = (pos_center - neg_center).unsqueeze(0) # just verify that sklearn does get the right answer - sk_cos_sim = cosine_similarity(correct_dir, torch.tensor(sk_probe.coef_), dim=1) + sk_cos_sim = cosine_similarity( + correct_dir.cpu(), torch.tensor(sk_probe.coef_, device=device), dim=1 + ) assert sk_cos_sim.min().item() > 0.98 cos_sim = cosine_similarity(correct_dir, probe.weights, dim=1) assert cos_sim.min().item() > 0.98 @@ -127,6 +129,7 @@ def test_train_binary_probe_scores_highly_on_fully_separable_datasets(seed): @pytest.mark.parametrize("seed", range(5)) def test_train_binary_probe_scores_highly_on_noisy_datasets(seed): + device = torch.device("cpu") torch.manual_seed(seed) neg_center = 1.0 * torch.ones(64) pos_center = -1.0 * torch.ones(64) @@ -156,6 +159,7 @@ def test_train_binary_probe_scores_highly_on_noisy_datasets(seed): weight_decay=1e-7, lr=1.0, show_progress=False, + device=device, ) train_preds = probe(x_train)[:, 0] > 0 @@ -182,6 +186,7 @@ def test_train_binary_probe_scores_highly_on_noisy_datasets(seed): @pytest.mark.parametrize("seed", range(5)) def test_train_multi_probe_scores_highly_on_fully_separable_datasets(seed): + device = torch.device("cpu") torch.manual_seed(seed) class1_center = 10 * torch.randn(64) class2_center = 10 * torch.randn(64) @@ -214,7 +219,7 @@ def test_train_multi_probe_scores_highly_on_fully_separable_datasets(seed): y_test = y[300:] probe = train_multi_probe( - x_train, y_train, num_probes=3, num_epochs=100, batch_size=32 + x_train, y_train, num_probes=3, num_epochs=100, batch_size=32, device=device ) train_preds = probe(x_train) > 0 @@ -229,6 +234,7 @@ def test_train_multi_probe_scores_highly_on_fully_separable_datasets(seed): @pytest.mark.parametrize("seed", range(5)) def test_train_multi_probe_scores_highly_on_noisy_datasets(seed): + device = torch.device("cpu") torch.manual_seed(seed) class1_center = 0.5 * torch.randn(64) class2_center = 0.5 * torch.randn(64) @@ -261,7 +267,7 @@ def test_train_multi_probe_scores_highly_on_noisy_datasets(seed): y_test = y[500:] probe = train_multi_probe( - x_train, y_train, num_probes=3, num_epochs=100, batch_size=128 + x_train, y_train, num_probes=3, num_epochs=100, batch_size=128, device=device ) train_preds = probe(x_train) > 0 @@ -325,15 +331,11 @@ def test_create_dataset_probe_training(): def test_gen_and_save_df_acts_probing(mock_to_csv, mock_model, 
tmp_path): dataset = [ ( - SpellingPrompt( - base="The word 'cat' is spelled:", answer=" c-a-t", word="cat" - ), + SpellingPrompt(base="The word 'cat' is spelled:", answer=" c-a-t", word="cat"), 0, ), ( - SpellingPrompt( - base="The word 'dog' is spelled:", answer=" d-o-g", word="dog" - ), + SpellingPrompt(base="The word 'dog' is spelled:", answer=" d-o-g", word="dog"), 1, ), ] @@ -405,11 +407,12 @@ def test_train_linear_probe_for_task(): def test_gen_probe_stats(): + device = torch.device("cpu") probe = LinearProbe(input_dim=768, num_outputs=26) X_val = torch.rand(100, 768) y_val = torch.randint(0, 26, (100,)) - results = gen_probe_stats(probe, X_val, y_val) + results = gen_probe_stats(probe, X_val, y_val, device=device) assert len(results) == 26 for result in results: diff --git a/tests/test_absorption.py b/tests/test_absorption.py index 52561de..0b1af87 100644 --- a/tests/test_absorption.py +++ b/tests/test_absorption.py @@ -9,9 +9,7 @@ from sae_bench_utils.testing_utils import validate_eval_output_format_file test_data_dir = "tests/test_data/absorption" -expected_results_filename = os.path.join( - test_data_dir, "absorption_expected_results.json" -) +expected_results_filename = os.path.join(test_data_dir, "absorption_expected_results.json") expected_probe_results_filename = os.path.join( test_data_dir, "absorption_expected_probe_results.json" ) @@ -42,6 +40,7 @@ def test_end_to_end_different_seed(): device = "mps" else: device = "cuda" if torch.cuda.is_available() else "cpu" + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" print(f"Using device: {device}") @@ -52,6 +51,8 @@ def test_end_to_end_different_seed(): max_k_value=10, prompt_template="{word} has the first letter:", prompt_token_pos=-6, + llm_batch_size=512, + llm_dtype="float32", ) selected_saes_dict = get_saes_from_regex(TEST_RELEASE, TEST_SAE_NAME) print(f"Selected SAEs: {selected_saes_dict}") @@ -61,24 +62,22 @@ def test_end_to_end_different_seed(): selected_saes_dict=selected_saes_dict, device=device, output_path=test_data_dir, - force_rerun=False, + force_rerun=True, ) path_to_eval_results = os.path.join( test_data_dir, f"{TEST_RELEASE}_{TEST_SAE_NAME}_eval_results.json" ) - validate_eval_output_format_file( - path_to_eval_results, eval_output_type=AbsorptionEvalOutput - ) + validate_eval_output_format_file(path_to_eval_results, eval_output_type=AbsorptionEvalOutput) # New checks for the updated JSON structure assert isinstance(run_results, dict), "run_results should be a dictionary" # Find the correct key in the new structure actual_result_key = f"{TEST_RELEASE}_{TEST_SAE_NAME}" - actual_mean_absorption_rate = run_results[actual_result_key]["eval_result_metrics"][ - "mean" - ]["mean_absorption_score"] + actual_mean_absorption_rate = run_results[actual_result_key]["eval_result_metrics"]["mean"][ + "mean_absorption_score" + ] # Load expected results and compare with open(expected_results_filename, "r") as f: @@ -87,7 +86,4 @@ def test_end_to_end_different_seed(): expected_mean_absorption_rate = expected_results["eval_result_metrics"]["mean"][ "mean_absorption_score" ] - assert ( - abs(actual_mean_absorption_rate - expected_mean_absorption_rate) - < TEST_TOLERANCE - ) + assert abs(actual_mean_absorption_rate - expected_mean_absorption_rate) < TEST_TOLERANCE