-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from adamkarvonen/add_baselines
Add baselines
- Loading branch information
Showing
18 changed files
with
810 additions
and
362 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,79 @@ | ||
## SAE Bench: Template for custom evals | ||
# SAE Bench | ||
|
||
This repo contains the template we would like to use for the SAE Bench project. The `template.ipynb` explains the input to your custom eval (SAEs hosted on SAELens) and the output (a standardized results file). | ||
## Table of Contents | ||
- [Installation](#installation) | ||
- [Overview](#overview) | ||
- [Running Evaluations](#running-evaluations) | ||
- [Custom SAE Usage](#custom-sae-usage) | ||
- [Training Your Own SAEs](#training-your-own-saes) | ||
- [Graphing Results](#graphing-results) | ||
|
||
### Installation | ||
Set up a virtual environment running on `python 3.11.9`. | ||
Set up a virtual environment with python >= 3.10. | ||
``` | ||
git clone https://github.com/adamkarvonen/SAE_Bench_Template.git | ||
cd SAE_Bench_Template | ||
pip install -e . | ||
``` | ||
|
||
### Quick start: | ||
|
||
1. Browse through `template.ipynb` to learn about input and output of your metric. | ||
2. Execute `main.py` in `evals/sparse_probing` to see an example implementation. | ||
3. Inspect sparse probing results with `graphing.ipynb`. | ||
All evals can be ran with current batch sizes on Gemma-2-2B on a 24GB VRAM GPU (e.g. a RTX 3090). | ||
|
||
## Overview | ||
|
||
The `evals/sparse_probing` folder contains a full example implementation of a custom eval. In `main.py`, we have a function that takes a list of SAELens SAE names (defined in `eval_config.py`) and an sae release and returns a dictionary of results in a standard format. This folder contains some helper functions, like pre-computing model activations in `utils/activation_collection.py`, that might be useful for you, too! We try to reuse functions as much as possible across evals to reduce bugs. Let Adam and Can know if you've implemented a helper function that might be useful for other evals as well (like autointerp, feature scoring). | ||
SAE Bench is a comprehensive suite of 8 evaluations for Sparse Autoencoder (SAE) models: | ||
- **Feature Absorption** | ||
- **AutoInterp** | ||
- **L0 / Loss Recovered** | ||
- **RAVEL** | ||
- **SHIFT** | ||
- **TPP** | ||
- **Sparse Probing** | ||
- **Unlearning** | ||
|
||
`python3 main.py` (after `cd evals/sparse_probing/`) should run as is and demonstrate how to use our SAE Bench SAEs with Transformer Lens and SAE Lens. It will also generate a results file which can be graphed using `graphing.ipynb`. | ||
### Supported Models and SAEs | ||
- **SAE Lens Pretrained Models**: Supports evaluations on any SAE Lens pretrained model. | ||
- **Custom SAEs**: Supports any general SAE object with `encode()` and `decode()` methods (see [Custom SAE Usage](#custom-sae-usage)). | ||
|
||
Here is what we would like to see from each eval: | ||
## Running Evaluations | ||
Each evaluation has an example command located in its respective `main.py` file. Here's how to run a sparse probing evaluation on a single SAE Bench Pythia-70M SAE: | ||
|
||
- Making sure we are returned both the results any config required for reproducibility (eg: eval config / function args). | ||
- Ensuring the code meets some minimum bar (isn't missing anything, isn't abysmally slow etc). | ||
- Ensuring we have example output to validate against. | ||
``` | ||
python evals/sparse_probing/main.py \ | ||
--sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \ | ||
--sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \ | ||
--model_name pythia-70m-deduped | ||
``` | ||
|
||
the results dictionary of you custom eval can be loaded in to `graphing.ipynb`, to create a wide variety of plots. We also already have the basic `L0 / Loss Recovered` metrics for every SAE, as specified in `template.ipynb`. | ||
The results will be saved to the evals/sparse_probing/results directory. | ||
|
||
For the purpose of validating evaluation outputs, we have `compare_run_results.ipynb`. Using this, you can run the same eval twice with the same input, and verify within a tolerance that it returns the same outputs. If the eval is fully deterministic, the results should be identical. | ||
We use regex patterns to select SAE Lens SAEs. For more examples of regex patterns, refer to `sae_selection.ipynb`. | ||
|
||
Once evaluations have been completed, please submit them as pull requests to this repo. | ||
Every eval folder contains an `eval_config.py`, which contains all relevant hyperparamters for that evaluation. The values are currently set to the default recommended values. | ||
|
||
## Eval Format | ||
For a tutorial of using SAE Lens SAEs, including calculating L0 and Loss Recovered and getting a set of tokens from The Pile, refer to this notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb | ||
|
||
Ideally, we would like to see something like `evals.sparse_probing.main.py`, which contains a `run_eval()` function. This function takes in hyperparameters and a `selected_saes_dict`, which is a dict of `sae_release` : `list[sae_names]` (as shown in template.ipynb). | ||
## Custom SAE Usage | ||
|
||
All evals and submodules will share the same dependencies, which are set in pyproject.toml. | ||
Our goal is to have first class support for custom SAEs as the field is rapidly evolving. Our evaluations can run on any SAE object with encode(), decode(), and a few config values. For example custom SAEs, refer to the `baselines/` folder. | ||
|
||
For a tutorial of using SAE Lens SAEs, including calculating L0 and Loss Recovered and getting a set of tokens from The Pile, refer to this notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb | ||
There are two ways to evaluate custom SAEs: | ||
|
||
## Custom SAE Usage | ||
1. **Using Evaluation Templates**: | ||
- Use the secondary `if __name__ == "__main__"` block in each `main.py` | ||
- Results are saved in SAE Bench format for easy visualization | ||
- Compatible with provided plotting tools | ||
|
||
For the sparse probing and SHIFT / TPP evals, we support evaluating any SAE object that has the following implemented, with inputs / outputs matching the SAELens SAE format: | ||
2. **Direct Function Calls**: | ||
- Use `run_eval_single_sae()` in each `main.py` | ||
- Simpler interface requiring only model, SAE, and config values | ||
- Graphing will require manual formatting | ||
|
||
``` | ||
sae.encode() | ||
sae.decode() | ||
sae.forward() | ||
sae.W_dec # nn.Parameter(d_sae, d_in), required for SHIFT, TPP, and Feature Absorption | ||
sae.device | ||
sae.dtype | ||
``` | ||
We currently have a suite of SAE Bench SAEs on layers 3 and 4 of Pythia-70M and layers 5, 12, and 19 of Gemma-2-2B, each trained on 200M tokens. These SAEs can serve as baselines for any new custom SAEs. We also have baseline eval results, saved at TODO. | ||
|
||
## Training Your Own SAEs | ||
|
||
You can replicate the training of our SAEs using scripts provided [here](https://github.com/canrager/dictionary_training/), or implement your own SAE, or make a change to one of our SAE implementations. Once you train your new version, you can benchmark against our existing SAEs for a true apples to apples comparison. | ||
|
||
Just pass the appropriate inputs to `run_eval_single_sae()`, referring to individual eval READMEs as needed. If you match our output format you can reuse our graphing notebook. | ||
## Graphing Results | ||
|
||
To run our baselines in pythia-70m and gemma-2-2b, refer to `if __name__ == "__main__":` in `shift_and_tpp/main.py`. | ||
To graph the results, refer to `graphing.ipynb`, which can graph the generated SAE Bench data. Note that many graphs plot SAEs by L0 and / or Loss Recovered. To obtain these scores, run `evals/core/main.py`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
import torch.nn as nn | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
|
||
@dataclass | ||
class SAEConfig: | ||
model_name: str | ||
d_in: int | ||
d_sae: int | ||
hook_layer: int | ||
hook_name: str | ||
context_size: int = 128 # Can be used for auto-interp | ||
hook_head_index: Optional[int] = None | ||
|
||
|
||
class IdentitySAE(nn.Module): | ||
def __init__(self, model_name: str, d_model: int, hook_layer: int): | ||
super().__init__() | ||
|
||
# Initialize W_enc and W_dec as identity matrices | ||
self.W_enc = nn.Parameter(torch.eye(d_model)) | ||
self.W_dec = nn.Parameter(torch.eye(d_model)) | ||
self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
self.dtype: torch.dtype = torch.float32 | ||
|
||
hook_name = f"blocks.{hook_layer}.hook_resid_post" | ||
|
||
# Initialize the configuration dataclass | ||
self.cfg = SAEConfig( | ||
model_name, d_in=d_model, d_sae=d_model, hook_name=hook_name, hook_layer=hook_layer | ||
) | ||
|
||
def encode(self, input_acts: torch.Tensor): | ||
acts = input_acts @ self.W_enc | ||
return acts | ||
|
||
def decode(self, acts: torch.Tensor): | ||
return acts @ self.W_dec | ||
|
||
def forward(self, acts): | ||
acts = self.encode(acts) | ||
recon = self.decode(acts) | ||
return recon | ||
|
||
# required as we have device and dtype class attributes | ||
def to(self, *args, **kwargs): | ||
super().to(*args, **kwargs) | ||
# Update the device and dtype attributes based on the first parameter | ||
device = kwargs.get("device", None) | ||
dtype = kwargs.get("dtype", None) | ||
|
||
# Update device and dtype if they were provided | ||
if device: | ||
self.device = device | ||
if dtype: | ||
self.dtype = dtype | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torch | ||
import torch.nn as nn | ||
from huggingface_hub import hf_hub_download | ||
import numpy as np | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
|
||
@dataclass | ||
class SAEConfig: | ||
model_name: str | ||
d_in: int | ||
d_sae: int | ||
hook_layer: int | ||
hook_name: str | ||
context_size: int = 128 # Can be used for auto-interp | ||
hook_head_index: Optional[int] = None | ||
|
||
|
||
class JumpReLUSAE(nn.Module): | ||
def __init__(self, d_model: int, d_sae: int, hook_layer: int, model_name: str = "gemma-2-2b"): | ||
super().__init__() | ||
self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae)) | ||
self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model)) | ||
self.threshold = nn.Parameter(torch.zeros(d_sae)) | ||
self.b_enc = nn.Parameter(torch.zeros(d_sae)) | ||
self.b_dec = nn.Parameter(torch.zeros(d_model)) | ||
self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
self.dtype: torch.dtype = torch.float32 | ||
|
||
hook_name = f"blocks.{hook_layer}.hook_resid_post" | ||
|
||
self.cfg = SAEConfig( | ||
model_name, d_in=d_model, d_sae=d_model, hook_name=hook_name, hook_layer=hook_layer | ||
) | ||
|
||
def encode(self, input_acts): | ||
pre_acts = input_acts @ self.W_enc + self.b_enc | ||
mask = pre_acts > self.threshold | ||
acts = mask * torch.nn.functional.relu(pre_acts) | ||
return acts | ||
|
||
def decode(self, acts): | ||
return acts @ self.W_dec + self.b_dec | ||
|
||
def forward(self, acts): | ||
acts = self.encode(acts) | ||
recon = self.decode(acts) | ||
return recon | ||
|
||
# required as we have device and dtype class attributes | ||
def to(self, *args, **kwargs): | ||
super().to(*args, **kwargs) | ||
# Update the device and dtype attributes based on the first parameter | ||
device = kwargs.get("device", None) | ||
dtype = kwargs.get("dtype", None) | ||
|
||
# Update device and dtype if they were provided | ||
if device: | ||
self.device = device | ||
if dtype: | ||
self.dtype = dtype | ||
return self | ||
|
||
|
||
def load_jumprelu_sae(repo_id: str, filename: str, layer: int) -> JumpReLUSAE: | ||
path_to_params = hf_hub_download( | ||
repo_id=repo_id, | ||
filename=filename, | ||
force_download=False, | ||
) | ||
|
||
params = np.load(path_to_params) | ||
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()} | ||
|
||
sae = JumpReLUSAE(params["W_enc"].shape[0], params["W_enc"].shape[1], layer) | ||
sae.load_state_dict(pt_params) | ||
|
||
return sae | ||
|
||
|
||
if __name__ == "__main__": | ||
repo_id = "google/gemma-scope-2b-pt-res" | ||
filename = "layer_20/width_16k/average_l0_71/params.npz" | ||
|
||
sae = load_jumprelu_sae(repo_id, filename, 20) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.