first commit
callummcdougall committed Oct 16, 2024
1 parent c57eef7 commit 4b23575
Showing 8 changed files with 1,237 additions and 1 deletion.
19 changes: 19 additions & 0 deletions evals/autointerp/README.md
# AutoInterp

## File structure

There are 4 Python files in this folder. The main ones are:

- `config.py` - this contains the config class for AutoInterp.
- `main.py` - this contains the main `AutoInterp` class, as well as the functions (e.g. `run_eval`) which are the interface to the rest of the SAEBench codebase.
- `demo.py` - you can run this via `python demo.py --api_key YOUR_API_KEY` to see example output and how the functions work. It creates and saves log files (the outputs of those files are left in the repo, so you can see what they look like).

## Summary of how it works

### Generation phase

We run a batch through the model & SAE, getting activation values. We take some number of sequences from the top of the activation distribution, and also sample some number of sequences from the rest of the distribution with probability proportional to their activation (this is a stand-in for quantile sampling, which should be more compatible with e.g. Gated models that won't have values in all quantiles). We format the activating tokens in these sequences using `<<token>>` syntax, then send them to the LLM and ask for an explanation, as sketched below.
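
Here's a minimal sketch of the sampling and formatting logic just described, assuming per-sequence max activations are available in a 1-D tensor; the function names (`pick_generation_examples`, `format_sequence`) are illustrative rather than the actual SAEBench implementation:

```python
import torch

def pick_generation_examples(acts: torch.Tensor, n_top: int, n_iw: int) -> torch.Tensor:
    """acts: (n_sequences,) max activation of this latent on each sequence."""
    top_idx = acts.topk(n_top).indices
    # Importance-weighted sampling from the rest of the distribution, with
    # probability proportional to activation (the stand-in for quantile sampling)
    rest = torch.ones_like(acts, dtype=torch.bool)
    rest[top_idx] = False
    rest_idx = rest.nonzero().squeeze(-1)
    weights = acts[rest_idx].clamp(min=0.0)
    iw_idx = rest_idx[torch.multinomial(weights, n_iw, replacement=False)]
    return torch.cat([top_idx, iw_idx])

def format_sequence(str_toks: list[str], acts: list[float], threshold: float) -> str:
    """Wraps tokens which activate above the threshold in <<...>> delimiters."""
    return "".join(f"<<{tok}>>" if act > threshold else tok for tok, act in zip(str_toks, acts))
```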

### Scoring phase

We select some number of top sequences & importance-weighted sampled sequences (as in the generation phase), but also include some sequences chosen randomly from the rest of the distribution. We shuffle these together and give them to the LLM as a numbered list, asking it to return a comma-separated list of the indices of the sequences which it thinks will activate this feature (see the sketch below).
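
And a hedged sketch of the scoring setup: the prompt wording and the helper names (`build_scoring_prompt`, `parse_prediction`) below are placeholders for illustration, not the exact prompts used:

```python
import random

def build_scoring_prompt(examples: list[str], explanation: str, seed: int = 0) -> str:
    """Shuffles the mixed examples together and presents them as a numbered list."""
    shuffled = examples[:]
    random.Random(seed).shuffle(shuffled)
    numbered = "\n".join(f"{i + 1}. {ex}" for i, ex in enumerate(shuffled))
    return (
        f"Feature explanation: {explanation}\n\n"
        f"{numbered}\n\n"
        "Return a comma-separated list of the numbers of the sequences "
        "which you think will activate this feature."
    )

def parse_prediction(response: str) -> list[int]:
    """Parses a reply like '1, 4, 7' into [1, 4, 7], ignoring non-integer tokens."""
    return [int(tok) for tok in response.replace(" ", "").split(",") if tok.isdigit()]
```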
93 changes: 93 additions & 0 deletions evals/autointerp/config.py
```python
from dataclasses import dataclass


@dataclass
class AutoInterpConfig:
    """
    Controls all parameters for how autointerp will work.

    Arguments:
        model_name: The name of the model to use
        n_latents: The number of latents to use
        override_latents: The latents to use (overrides n_latents if supplied)
        seed: The seed to use for all randomness
        buffer: The size of the buffer to use for scoring
        no_overlap: Whether to prevent overlapping sequences in the scoring phase
        act_threshold_frac: The fraction of the maximum activation to use as the activation threshold
        total_tokens: The total number of tokens we'll gather data for
        batch_size: The batch size to use for the scoring phase
        scoring: Whether to perform the scoring phase, or just return explanations
        max_tokens_in_explanation: The maximum number of tokens to allow in an explanation
        use_demos_in_explanation: Whether to use demonstrations in the explanation prompt
        n_top_ex_for_generation: The number of top activating sequences to use for the generation phase
        n_iw_sampled_ex_for_generation: The number of importance-sampled sequences to use for the generation
            phase (this is a replacement for quantile sampling)
        n_top_ex_for_scoring: The number of top sequences to use for scoring
        n_random_ex_for_scoring: The number of random sequences to use for scoring
        n_iw_sampled_ex_for_scoring: The number of importance-sampled sequences to use for scoring
    """

    # Important stuff
    model_name: str
    n_latents: int | None = None
    override_latents: list[int] | None = None
    seed: int = 0

    # Main stuff
    buffer: int = 10
    no_overlap: bool = True
    act_threshold_frac: float = 0.01
    total_tokens: int = 10_000_000
    batch_size: int = 512  # split up total tokens into batches of this size
    scoring: bool = True
    max_tokens_in_explanation: int = 30
    use_demos_in_explanation: bool = True

    # Sequences included in generation phase
    n_top_ex_for_generation: int = 10
    n_iw_sampled_ex_for_generation: int = 5

    # Sequences included in scoring phase
    n_top_ex_for_scoring: int = 4
    n_random_ex_for_scoring: int = 10
    n_iw_sampled_ex_for_scoring: int = 0

    def __post_init__(self):
        if self.n_latents is None:
            assert self.override_latents is not None
            self.latents = self.override_latents
            self.n_latents = len(self.latents)
        else:
            assert self.override_latents is None
            self.latents = None

    @property
    def n_top_ex(self) -> int:
        """When fetching data, we get the top examples for generation & scoring simultaneously."""
        return self.n_top_ex_for_generation + self.n_top_ex_for_scoring

    @property
    def max_tokens_in_prediction(self) -> int:
        """Predictions take the form of comma-separated numbers, which should all be single tokens."""
        # ~2 tokens per "number + comma" pair, plus a small buffer
        return 2 * self.n_ex_for_scoring + 5

    @property
    def n_ex_for_generation(self) -> int:
        return self.n_top_ex_for_generation + self.n_iw_sampled_ex_for_generation

    @property
    def n_ex_for_scoring(self) -> int:
        """For the scoring phase, we use a randomly shuffled mix of top-k activations and random sequences."""
        return self.n_top_ex_for_scoring + self.n_random_ex_for_scoring + self.n_iw_sampled_ex_for_scoring

    @property
    def n_iw_sampled_ex(self) -> int:
        return self.n_iw_sampled_ex_for_generation + self.n_iw_sampled_ex_for_scoring

    @property
    def n_correct_for_scoring(self) -> int:
        return self.n_top_ex_for_scoring + self.n_iw_sampled_ex_for_scoring
```
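
A quick illustration of how the derived properties combine the defaults above (the import path matches the one used in `demo.py`):

```python
from evals.autointerp.config import AutoInterpConfig

cfg = AutoInterpConfig(model_name="gpt2-small", override_latents=[9, 11, 15, 16873])

assert cfg.n_latents == 4  # inferred from override_latents in __post_init__
assert cfg.n_ex_for_generation == 15  # 10 top + 5 importance-sampled
assert cfg.n_ex_for_scoring == 14  # 4 top + 10 random + 0 importance-sampled
assert cfg.max_tokens_in_prediction == 33  # 2 * 14 + 5
```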
48 changes: 48 additions & 0 deletions evals/autointerp/demo.py
```python
import argparse
from pathlib import Path

import torch

from evals.autointerp.config import AutoInterpConfig
from evals.autointerp.main import run_eval

# Set up command-line argument parsing
parser = argparse.ArgumentParser(description="Run AutoInterp evaluation.")
parser.add_argument(
    "--api_key", type=str, required=True, help="API key for the evaluation."
)
args = parser.parse_args()

api_key = args.api_key  # Use the API key supplied via command line

device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

selected_saes_dict = {
    "gpt2-small-res-jb": ["blocks.7.hook_resid_pre"],
}
torch.set_grad_enabled(False)

# ! Demo 1: just 4 specially chosen latents
cfg = AutoInterpConfig(model_name="gpt2-small", override_latents=[9, 11, 15, 16873])
save_logs_path = Path(__file__).parent / "logs_4.txt"
save_logs_path.unlink(missing_ok=True)
results = run_eval(
    cfg, selected_saes_dict, device, api_key, save_logs_path=save_logs_path
)
print(results)

# ! Demo 2: 100 randomly chosen latents
cfg = AutoInterpConfig(model_name="gpt2-small", n_latents=100)
save_logs_path = Path(__file__).parent / "logs_100.txt"
save_logs_path.unlink(missing_ok=True)
results = run_eval(
    cfg, selected_saes_dict, device, api_key, save_logs_path=save_logs_path
)
print(results)

# python demo.py --api_key "YOUR_API_KEY"
```