-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 319300f
Showing
5 changed files
with
10,953 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.venv/ | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
sae_lens | ||
transformer-lens | ||
torch==2.3.0 | ||
einops==0.8.0 | ||
matplotlib==3.8.4 | ||
numpy==1.26.4 | ||
pandas==2.1.2 | ||
plotly==5.22.0 | ||
tqdm==4.66.4 | ||
pytest==8.3.2 | ||
nbformat==5.10.4 | ||
ipykernel==6.29.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<torch.autograd.grad_mode.set_grad_enabled at 0x13f1b51d0>" | ||
] | ||
}, | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# Standard imports\n", | ||
"import os\n", | ||
"import torch\n", | ||
"from tqdm import tqdm\n", | ||
"import plotly.express as px\n", | ||
"\n", | ||
"# Imports for displaying vis in Colab / notebook\n", | ||
"import webbrowser\n", | ||
"import http.server\n", | ||
"import socketserver\n", | ||
"import threading\n", | ||
"PORT = 8000\n", | ||
"\n", | ||
"torch.set_grad_enabled(False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Device: mps\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"if torch.backends.mps.is_available():\n", | ||
" device = \"mps\"\n", | ||
"else:\n", | ||
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | ||
"\n", | ||
"\n", | ||
"print(f\"Device: {device}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/Users/adamkarvonen/SAE_bench_template/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n", | ||
"/Users/adamkarvonen/SAE_bench_template/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", | ||
" warnings.warn(\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loaded pretrained model pythia-70m-deduped into HookedTransformer\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from datasets import load_dataset \n", | ||
"from transformer_lens import HookedTransformer\n", | ||
"from sae_lens import SAE\n", | ||
"\n", | ||
"model = HookedTransformer.from_pretrained(\"pythia-70m-deduped\", device = device)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"SAEConfig(architecture='standard', d_in=512, d_sae=16384, activation_fn_str='topk', apply_b_dec_to_input=True, finetuning_scaling_factor=False, context_size=128, model_name='pythia-70m-deduped', hook_name='blocks.4.hook_resid_post', hook_layer=4, hook_head_index=None, prepend_bos=True, dataset_path='monology/pile-uncopyrighted', dataset_trust_remote_code=False, normalize_activations='none', dtype='float32', device='mps', sae_lens_training_version=None, activation_fn_kwargs={'k': 80}, neuronpedia_id=None, model_from_pretrained_kwargs={})\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n", | ||
"# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n", | ||
"# We also return the feature sparsities which are stored in HF for convenience. \n", | ||
"sae, cfg_dict, sparsity = SAE.from_pretrained(\n", | ||
" release = \"sae_bench_pythia70m_sweep_topk_ctx128_0730\",\n", | ||
" sae_id = \"blocks.4.hook_resid_post__trainer_10\",\n", | ||
" device = device\n", | ||
")\n", | ||
"sae = sae.to(device=device)\n", | ||
"\n", | ||
"print(sae.cfg)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.