Merge pull request #33 from adamkarvonen/add_baselines
Add baselines
adamkarvonen authored Nov 14, 2024
2 parents 186bdb4 + e5b2ba4 commit 20c2a40
Showing 18 changed files with 810 additions and 362 deletions.
85 changes: 52 additions & 33 deletions README.md
@@ -1,60 +1,79 @@
# SAE Bench

## Table of Contents
- [Installation](#installation)
- [Overview](#overview)
- [Running Evaluations](#running-evaluations)
- [Custom SAE Usage](#custom-sae-usage)
- [Training Your Own SAEs](#training-your-own-saes)
- [Graphing Results](#graphing-results)

## Installation
Set up a virtual environment with Python >= 3.10.
```
git clone https://github.com/adamkarvonen/SAE_Bench_Template.git
cd SAE_Bench_Template
pip install -e .
```

All evals can be run with the current batch sizes on Gemma-2-2B on a 24GB VRAM GPU (e.g., an RTX 3090).

## Overview

SAE Bench is a comprehensive suite of 8 evaluations for Sparse Autoencoder (SAE) models:
- **Feature Absorption**
- **AutoInterp**
- **L0 / Loss Recovered**
- **RAVEL**
- **SHIFT**
- **TPP**
- **Sparse Probing**
- **Unlearning**

### Supported Models and SAEs
- **SAE Lens Pretrained Models**: Supports evaluations on any SAE Lens pretrained model.
- **Custom SAEs**: Supports any general SAE object with `encode()` and `decode()` methods (see [Custom SAE Usage](#custom-sae-usage)).
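
As a quick sanity check, a pretrained SAE Lens SAE can be loaded in a few lines. This is a minimal sketch; it assumes the SAE Bench release below is registered in your installed SAE Lens version:

```
from sae_lens import SAE

# Load one SAE Bench SAE from the SAE Lens registry (the release / id strings
# match the example command in the Running Evaluations section below).
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="sae_bench_pythia70m_sweep_standard_ctx128_0712",
    sae_id="blocks.4.hook_resid_post__trainer_10",
    device="cpu",
)
print(sae.cfg.d_sae)
```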

## Running Evaluations
Each evaluation has an example command located in its respective `main.py` file. Here's how to run a sparse probing evaluation on a single SAE Bench Pythia-70M SAE:

```
python evals/sparse_probing/main.py \
--sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \
--sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \
--model_name pythia-70m-deduped
```

The results will be saved to the `evals/sparse_probing/results` directory.

For validating evaluation outputs, use `compare_run_results.ipynb`: run the same eval twice with the same inputs and verify that the outputs match within a tolerance. If the eval is fully deterministic, the results should be identical.
We use regex patterns to select SAE Lens SAEs. For more examples of regex patterns, refer to `sae_selection.ipynb`.
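
As an illustration (a hypothetical sketch, not the exact selection code), the two patterns act as full-match filters: `--sae_regex_pattern` selects SAE Lens releases, and `--sae_block_pattern` selects SAE ids within a release:

```
import re

# Hypothetical sketch of the selection step: an SAE id is kept only if the
# block pattern fully matches it.
sae_block_pattern = r"blocks\.4\.hook_resid_post__trainer_(2|6|10)"

candidate_ids = [
    "blocks.4.hook_resid_post__trainer_2",
    "blocks.4.hook_resid_post__trainer_10",
    "blocks.10.hook_resid_post__trainer_2",
]
selected = [s for s in candidate_ids if re.fullmatch(sae_block_pattern, s)]
print(selected)  # the two layer-4 SAEs match; the layer-10 one does not
```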

Every eval folder contains an `eval_config.py` file with all relevant hyperparameters for that evaluation. The values are currently set to the recommended defaults.

For a tutorial on using SAE Lens SAEs, including calculating L0 and Loss Recovered and getting a set of tokens from The Pile, refer to this notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb

## Custom SAE Usage

Our goal is to have first-class support for custom SAEs, as the field is rapidly evolving. Our evaluations can run on any SAE object that implements `encode()`, `decode()`, and a few config values. For example custom SAE implementations, refer to the `baselines/` folder.

There are two ways to evaluate custom SAEs:

1. **Using Evaluation Templates**:
   - Use the secondary `if __name__ == "__main__"` block in each `main.py`
   - Results are saved in SAE Bench format for easy visualization
   - Compatible with the provided plotting tools

2. **Direct Function Calls**:
   - Use `run_eval_single_sae()` in each `main.py`
   - Simpler interface requiring only the model, SAE, and config values
   - Graphing will require manual formatting

For the sparse probing and SHIFT / TPP evals, the SAE object should have the following implemented, with inputs / outputs matching the SAE Lens SAE format:

```
sae.encode()
sae.decode()
sae.forward()
sae.W_dec  # nn.Parameter(d_sae, d_in); required for SHIFT, TPP, and Feature Absorption
sae.device
sae.dtype
```
We currently have a suite of SAE Bench SAEs on layers 3 and 4 of Pythia-70M and layers 5, 12, and 19 of Gemma-2-2B, each trained on 200M tokens. These SAEs can serve as baselines for any new custom SAEs. We also have baseline eval results, saved at TODO.

## Training Your Own SAEs

You can replicate the training of our SAEs using the scripts provided [here](https://github.com/canrager/dictionary_training/), implement your own SAE, or modify one of our SAE implementations. Once you train your new version, you can benchmark it against our existing SAEs for a true apples-to-apples comparison.

Then pass the appropriate inputs to `run_eval_single_sae()`, referring to the individual eval READMEs as needed. If you match our output format, you can reuse our graphing notebook.
## Graphing Results

To graph the results, refer to `graphing.ipynb`, which can plot the generated SAE Bench data. Note that many graphs plot SAEs by L0 and/or Loss Recovered; to obtain these scores, run `evals/core/main.py`.
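
Assuming `evals/core/main.py` follows the same CLI pattern as the sparse probing example above (an assumption, not a confirmed interface), the invocation would look like:

```
python evals/core/main.py \
--sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \
--sae_block_pattern "blocks.4.hook_resid_post__trainer_10" \
--model_name pythia-70m-deduped
```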
59 changes: 59 additions & 0 deletions baselines/identity_sae.py
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional


@dataclass
class SAEConfig:
    model_name: str
    d_in: int
    d_sae: int
    hook_layer: int
    hook_name: str
    context_size: int = 128  # Can be used for auto-interp
    hook_head_index: Optional[int] = None


class IdentitySAE(nn.Module):
    def __init__(self, model_name: str, d_model: int, hook_layer: int):
        super().__init__()

        # Initialize W_enc and W_dec as identity matrices
        self.W_enc = nn.Parameter(torch.eye(d_model))
        self.W_dec = nn.Parameter(torch.eye(d_model))
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype: torch.dtype = torch.float32

        hook_name = f"blocks.{hook_layer}.hook_resid_post"

        # Initialize the configuration dataclass
        self.cfg = SAEConfig(
            model_name, d_in=d_model, d_sae=d_model, hook_name=hook_name, hook_layer=hook_layer
        )

    def encode(self, input_acts: torch.Tensor):
        acts = input_acts @ self.W_enc
        return acts

    def decode(self, acts: torch.Tensor):
        return acts @ self.W_dec

    def forward(self, acts):
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

    # required as we have device and dtype class attributes
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # Update the device and dtype attributes if they were passed as keyword arguments
        device = kwargs.get("device", None)
        dtype = kwargs.get("dtype", None)

        if device:
            self.device = device
        if dtype:
            self.dtype = dtype
        return self
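

# Hedged usage sketch (illustrative; assumes Pythia-70M's d_model of 512).
# IdentitySAE reconstructs its input exactly, so it provides a
# "perfect reconstruction, no sparsity" reference point for the evals.
if __name__ == "__main__":
    sae = IdentitySAE(model_name="pythia-70m-deduped", d_model=512, hook_layer=4)
    acts = torch.randn(8, 128, 512)  # (batch, context, d_model)
    assert torch.allclose(sae(acts), acts)  # encode -> decode is the identity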
86 changes: 86 additions & 0 deletions baselines/jumprelu_sae.py
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import Optional


@dataclass
class SAEConfig:
    model_name: str
    d_in: int
    d_sae: int
    hook_layer: int
    hook_name: str
    context_size: int = 128  # Can be used for auto-interp
    hook_head_index: Optional[int] = None


class JumpReLUSAE(nn.Module):
    def __init__(self, d_model: int, d_sae: int, hook_layer: int, model_name: str = "gemma-2-2b"):
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype: torch.dtype = torch.float32

        hook_name = f"blocks.{hook_layer}.hook_resid_post"

        self.cfg = SAEConfig(
            model_name, d_in=d_model, d_sae=d_sae, hook_name=hook_name, hook_layer=hook_layer
        )

    def encode(self, input_acts):
        pre_acts = input_acts @ self.W_enc + self.b_enc
        # JumpReLU: zero out any activation at or below its learned threshold
        mask = pre_acts > self.threshold
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

    # required as we have device and dtype class attributes
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # Update the device and dtype attributes if they were passed as keyword arguments
        device = kwargs.get("device", None)
        dtype = kwargs.get("dtype", None)

        if device:
            self.device = device
        if dtype:
            self.dtype = dtype
        return self


def load_jumprelu_sae(repo_id: str, filename: str, layer: int) -> JumpReLUSAE:
    path_to_params = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_download=False,
    )

    params = np.load(path_to_params)
    # Keep tensors on CPU; .to() can move the module afterwards (an unconditional
    # .cuda() here would fail on CPU-only machines)
    pt_params = {k: torch.from_numpy(v) for k, v in params.items()}

    sae = JumpReLUSAE(params["W_enc"].shape[0], params["W_enc"].shape[1], layer)
    sae.load_state_dict(pt_params)

    return sae


if __name__ == "__main__":
    repo_id = "google/gemma-scope-2b-pt-res"
    filename = "layer_20/width_16k/average_l0_71/params.npz"

    sae = load_jumprelu_sae(repo_id, filename, 20)
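
    # Illustrative sanity check (an assumption-laden sketch): run a random batch
    # shaped like gemma-2-2b residual activations through the loaded SAE.
    sae = sae.to(sae.device)  # parameters load on CPU; move to the configured device
    x = torch.randn(2, sae.W_enc.shape[0], device=sae.device, dtype=sae.dtype)
    print(sae(x).shape)  # expected: torch.Size([2, 2304])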
10 changes: 10 additions & 0 deletions evals/absorption/eval_config.py
@@ -39,3 +39,13 @@ class AbsorptionEvalConfig(BaseEvalConfig):
        title="Model Name",
        description="Model name",
    )
    llm_batch_size: int = Field(
        default=32,
        title="LLM Batch Size",
        description="LLM batch size, inference only",
    )
    llm_dtype: str = Field(
        default="bfloat16",
        title="LLM Data Type",
        description="LLM data type",
    )
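
These fields are consumed when the eval config is constructed. A hypothetical sketch (it assumes the remaining `AbsorptionEvalConfig` fields all have defaults):

```
from evals.absorption.eval_config import AbsorptionEvalConfig

# Override the new baseline-related fields; other fields keep their defaults.
config = AbsorptionEvalConfig(
    model_name="pythia-70m-deduped",
    llm_batch_size=32,
    llm_dtype="bfloat16",
)
```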