Commit

add sae encode function
callummcdougall committed Oct 19, 2024
1 parent 4b23575 commit b247f91
Showing 6 changed files with 599 additions and 442 deletions.
1 change: 1 addition & 0 deletions evals/autointerp/README.md
@@ -7,6 +7,7 @@ There are 4 Python files in this folder:
 - `config.py` - this contains the config class for AutoInterp.
 - `main.py` - this contains the main `AutoInterp` class, as well as the functions which are the interface to the rest of the SAEBench codebase.
 - `demo.py` - you can run this via `python demo.py --api_key YOUR_API_KEY` to see an example output & how the function works. It creates & saves a log file (I've left the output of those files in the repo, so you can see what they look like).
+- `sae_encode.py` - this contains a temporary replacement for the `encode` method in SAELens, until [my PR](https://github.com/jbloomAus/SAELens/pull/334) is merged.

## Summary of how it works

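The new `sae_encode.py` file is not expanded in this diff, so here is a minimal sketch of what a standalone SAE encode function typically computes, assuming a standard ReLU sparse autoencoder. The parameter names `W_enc`, `b_enc`, and `b_dec` follow common SAE conventions and are not taken from this commit; the actual implementation may differ:

```python
import torch


def sae_encode(
    x: torch.Tensor,       # model activations, shape (batch, d_in)
    W_enc: torch.Tensor,   # encoder weights, shape (d_in, d_sae)
    b_enc: torch.Tensor,   # encoder bias, shape (d_sae,)
    b_dec: torch.Tensor,   # decoder bias, shape (d_in,)
) -> torch.Tensor:
    # Standard SAE encoder: subtract the decoder bias, project into the
    # latent space, then apply ReLU to get sparse feature activations.
    return torch.relu((x - b_dec) @ W_enc + b_enc)


# Tiny example: a batch of 2 activation vectors with d_in=4, d_sae=8.
x = torch.randn(2, 4)
W_enc = torch.randn(4, 8)
b_enc = torch.zeros(8)
b_dec = torch.zeros(4)
acts = sae_encode(x, W_enc, b_enc, b_dec)  # shape (2, 8), all entries >= 0
```

SAELens SAEs with other architectures (e.g. gated or JumpReLU) use different activation functions, which is part of why a drop-in `encode` replacement was needed here.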
21 changes: 5 additions & 16 deletions evals/autointerp/demo.py
@@ -2,25 +2,18 @@
from pathlib import Path

import torch

from evals.autointerp.config import AutoInterpConfig
from evals.autointerp.main import run_eval

# Set up command-line argument parsing
parser = argparse.ArgumentParser(description="Run AutoInterp evaluation.")
-parser.add_argument(
-    "--api_key", type=str, required=True, help="API key for the evaluation."
-)
+parser.add_argument("--api_key", type=str, required=True, help="API key for the evaluation.")
args = parser.parse_args()

api_key = args.api_key # Use the API key supplied via command line

-device = torch.device(
-    "mps"
-    if torch.backends.mps.is_available()
-    else "cuda"
-    if torch.cuda.is_available()
-    else "cpu"
-)
+device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

selected_saes_dict = {
"gpt2-small-res-jb": ["blocks.7.hook_resid_pre"],
@@ -31,18 +24,14 @@
cfg = AutoInterpConfig(model_name="gpt2-small", override_latents=[9, 11, 15, 16873])
save_logs_path = Path(__file__).parent / "logs_4.txt"
save_logs_path.unlink(missing_ok=True)
-results = run_eval(
-    cfg, selected_saes_dict, device, api_key, save_logs_path=save_logs_path
-)
+results = run_eval(cfg, selected_saes_dict, str(device), api_key, save_logs_path=save_logs_path)
print(results)

# ! Demo 2: 100 randomly chosen latents
cfg = AutoInterpConfig(model_name="gpt2-small", n_latents=100)
save_logs_path = Path(__file__).parent / "logs_100.txt"
save_logs_path.unlink(missing_ok=True)
-results = run_eval(
-    cfg, selected_saes_dict, device, api_key, save_logs_path=save_logs_path
-)
+results = run_eval(cfg, selected_saes_dict, str(device), api_key, save_logs_path=save_logs_path)
print(results)

# python demo.py --api_key "YOUR_API_KEY"
