Commit 4b23575 (1 parent: c57eef7), showing 8 changed files with 1,237 additions and 1 deletion.
evals/autointerp/README.md
@@ -0,0 +1,19 @@
# AutoInterp

## File structure

There are 4 Python files in this folder:

- `config.py` - this contains the config class for AutoInterp.
- `main.py` - this contains the main `AutoInterp` class, as well as the functions which are the interface to the rest of the SAEBench codebase.
- `demo.py` - you can run this via `python demo.py --api_key YOUR_API_KEY` to see an example output & how the function works. It creates & saves a log file (I've left the output of those files in the repo, so you can see what they look like).

## Summary of how it works

### Generation phase
We run a batch through the model & SAE, getting activation values. We take some number of sequences from the top of the activation distribution, and also sample some number of sequences from the rest of the distribution with probability proportional to their activation (this is a stand-in for quantile sampling, and should be more compatible with e.g. Gated models, which won't have values in all quantiles). We take these sequences and format the activating tokens using `<<token>>` syntax, then pass them to the explainer LLM and ask it for an explanation.
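As a rough illustration, here is a minimal sketch of that selection and highlighting step. The helper names and parameters (`select_sequences`, `highlight`, `n_top`, `n_iw`, `act_threshold`) are hypothetical, not the functions in `main.py`, and edge cases such as all-zero activations are ignored:

```python
import torch


def select_sequences(acts: torch.Tensor, n_top: int, n_iw: int) -> list[int]:
    """Pick sequence indices: the n_top sequences with the highest max activation, plus
    n_iw importance-weighted samples from the rest (probability proportional to activation)."""
    max_acts = acts.max(dim=-1).values          # (n_seqs,) max activation per sequence
    top_idx = max_acts.topk(n_top).indices      # top of the activation distribution
    rest_mask = torch.ones_like(max_acts, dtype=torch.bool)
    rest_mask[top_idx] = False
    rest_idx = rest_mask.nonzero().squeeze(-1)
    probs = max_acts[rest_idx].clamp(min=0.0)   # sample remaining sequences proportionally to activation
    iw_idx = rest_idx[torch.multinomial(probs / probs.sum(), n_iw, replacement=False)]
    return top_idx.tolist() + iw_idx.tolist()


def highlight(tokens: list[str], token_acts: torch.Tensor, act_threshold: float) -> str:
    """Wrap tokens whose activation exceeds the threshold in <<...>> syntax for the LLM prompt."""
    return "".join(f"<<{tok}>>" if a > act_threshold else tok for tok, a in zip(tokens, token_acts.tolist()))
```

The threshold here would correspond to `act_threshold_frac` times the latent's maximum activation (see `config.py` below).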

### Scoring phase

We select some number of top sequences and importance-sampled sequences (as in the generation phase), but also include some sequences chosen randomly from the rest of the distribution. We shuffle these together and give them to the LLM as a numbered list, and ask the LLM to return a comma-separated list of the indices of the sequences which it thinks will activate this feature.
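A sketch of how the scoring prompt and metric could be assembled. The helper names are hypothetical and the exact prompt wording and scoring rule live in `main.py`; this version simply grades the fraction of sequences classified correctly:

```python
import random


def build_scoring_prompt(activating: list[str], random_exs: list[str], seed: int = 0) -> tuple[str, set[int]]:
    """Shuffle activating and random sequences into one numbered list; return the list text
    and the ground-truth (1-indexed) positions of the activating sequences."""
    labelled = [(ex, True) for ex in activating] + [(ex, False) for ex in random_exs]
    random.Random(seed).shuffle(labelled)
    numbered = "\n".join(f"{i + 1}. {ex}" for i, (ex, _) in enumerate(labelled))
    truth = {i + 1 for i, (_, is_activating) in enumerate(labelled) if is_activating}
    return numbered, truth


def score_reply(llm_reply: str, truth: set[int], n_total: int) -> float:
    """Parse the LLM's comma-separated indices and return the fraction of sequences it
    classified correctly (activating ones picked, non-activating ones left out)."""
    predicted = {int(tok) for tok in llm_reply.split(",") if tok.strip().isdigit()}
    n_correct = len(predicted & truth) + (n_total - len(predicted | truth))
    return n_correct / n_total
```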
evals/autointerp/config.py
@@ -0,0 +1,93 @@
from dataclasses import dataclass


@dataclass
class AutoInterpConfig:
    """
    Controls all parameters for how autointerp will work.

    Arguments:
        model_name: The name of the model to use
        n_latents: The number of latents to use
        override_latents: The latents to use (overrides n_latents if supplied)
        seed: The seed to use for all randomness
        buffer: The size of the buffer to use for scoring
        no_overlap: Whether to prevent overlapping sequences in the scoring phase
        act_threshold_frac: The fraction of the maximum activation to use as the activation threshold
        total_tokens: The total number of tokens we'll gather data for
        batch_size: The batch size to use for the scoring phase
        scoring: Whether to perform the scoring phase, or just return explanations
        max_tokens_in_explanation: The maximum number of tokens to allow in an explanation
        use_demos_in_explanation: Whether to use demonstrations in the explanation prompt
        n_top_ex_for_generation: The number of top activating sequences to use for the generation phase
        n_iw_sampled_ex_for_generation: The number of importance-sampled sequences to use for the generation
            phase (this is a replacement for quantile sampling)
        n_top_ex_for_scoring: The number of top sequences to use for scoring
        n_random_ex_for_scoring: The number of random sequences to use for scoring
        n_iw_sampled_ex_for_scoring: The number of importance-sampled sequences to use for scoring
    """

    # Important stuff
    model_name: str
    n_latents: int | None = None
    override_latents: list[int] | None = None
    seed: int = 0

    # Main stuff
    buffer: int = 10
    no_overlap: bool = True
    act_threshold_frac: float = 0.01
    total_tokens: int = 10_000_000
    batch_size: int = 512  # split up total tokens into batches of this size
    scoring: bool = True
    max_tokens_in_explanation: int = 30
    use_demos_in_explanation: bool = True

    # Sequences included in generation phase
    n_top_ex_for_generation: int = 10
    n_iw_sampled_ex_for_generation: int = 5

    # Sequences included in scoring phase
    n_top_ex_for_scoring: int = 4
    n_random_ex_for_scoring: int = 10
    n_iw_sampled_ex_for_scoring: int = 0

    def __post_init__(self):
        if self.n_latents is None:
            assert self.override_latents is not None
            self.latents = self.override_latents
            self.n_latents = len(self.latents)
        else:
            assert self.override_latents is None
            self.latents = None

    @property
    def n_top_ex(self):
        """When fetching data, we get the top examples for generation & scoring simultaneously."""
        return self.n_top_ex_for_generation + self.n_top_ex_for_scoring

    @property
    def max_tokens_in_prediction(self) -> int:
        """Predictions take the form of comma-separated numbers, which should all be single tokens."""
        return 2 * self.n_ex_for_scoring + 5

    @property
    def n_ex_for_generation(self) -> int:
        return self.n_top_ex_for_generation + self.n_iw_sampled_ex_for_generation

    @property
    def n_ex_for_scoring(self) -> int:
        """For scoring phase, we use a randomly shuffled mix of top-k activations and random sequences."""
        return self.n_top_ex_for_scoring + self.n_random_ex_for_scoring + self.n_iw_sampled_ex_for_scoring

    @property
    def n_iw_sampled_ex(self) -> int:
        return self.n_iw_sampled_ex_for_generation + self.n_iw_sampled_ex_for_scoring

    @property
    def n_correct_for_scoring(self) -> int:
        return self.n_top_ex_for_scoring + self.n_iw_sampled_ex_for_scoring
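As an illustration of how the defaults above combine (the numbers follow directly from the dataclass; the latent indices are the ones used in `demo.py`):

```python
from evals.autointerp.config import AutoInterpConfig

cfg = AutoInterpConfig(model_name="gpt2-small", override_latents=[9, 11, 15, 16873])

print(cfg.n_latents)                 # 4, inferred from override_latents in __post_init__
print(cfg.n_ex_for_generation)       # 15 = 10 top + 5 importance-sampled
print(cfg.n_ex_for_scoring)          # 14 = 4 top + 10 random + 0 importance-sampled
print(cfg.max_tokens_in_prediction)  # 33 = 2 * 14 + 5
```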
evals/autointerp/demo.py
@@ -0,0 +1,48 @@
import argparse
from pathlib import Path

import torch
from evals.autointerp.config import AutoInterpConfig
from evals.autointerp.main import run_eval

# Set up command-line argument parsing
parser = argparse.ArgumentParser(description="Run AutoInterp evaluation.")
parser.add_argument(
    "--api_key", type=str, required=True, help="API key for the evaluation."
)
args = parser.parse_args()

api_key = args.api_key  # Use the API key supplied via command line

device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

selected_saes_dict = {
    "gpt2-small-res-jb": ["blocks.7.hook_resid_pre"],
}
torch.set_grad_enabled(False)

# ! Demo 1: just 4 specially chosen latents
cfg = AutoInterpConfig(model_name="gpt2-small", override_latents=[9, 11, 15, 16873])
save_logs_path = Path(__file__).parent / "logs_4.txt"
save_logs_path.unlink(missing_ok=True)
results = run_eval(
    cfg, selected_saes_dict, device, api_key, save_logs_path=save_logs_path
)
print(results)

# ! Demo 2: 100 randomly chosen latents
cfg = AutoInterpConfig(model_name="gpt2-small", n_latents=100)
save_logs_path = Path(__file__).parent / "logs_100.txt"
save_logs_path.unlink(missing_ok=True)
results = run_eval(
    cfg, selected_saes_dict, device, api_key, save_logs_path=save_logs_path
)
print(results)

# python demo.py --api_key "YOUR_API_KEY"