Commit

initial commit
adamkarvonen committed Sep 26, 2024 · 0 parents · commit 319300f
Showing 5 changed files with 10,953 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
.venv/
.DS_Store
12 changes: 12 additions & 0 deletions sparse_probing/requirements.txt
@@ -0,0 +1,12 @@
sae_lens
transformer-lens
torch==2.3.0
einops==0.8.0
matplotlib==3.8.4
numpy==1.26.4
pandas==2.1.2
plotly==5.22.0
tqdm==4.66.4
pytest==8.3.2
nbformat==5.10.4
ipykernel==6.29.5
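
The pins above are exact for the numerics-sensitive packages (torch, numpy, pandas, einops, plotly, tqdm), while sae_lens and transformer-lens are left unpinned. Below is an editor's sketch, not part of this commit, of a standard-library check that an installed environment matches a few of these pins; the package names and versions are copied from the file above, and the script itself is hypothetical.

from importlib import metadata

# Hypothetical helper (not part of commit 319300f): compare installed
# versions against a few of the exact pins in sparse_probing/requirements.txt.
pinned = {
    "torch": "2.3.0",
    "einops": "0.8.0",
    "numpy": "1.26.4",
    "pandas": "2.1.2",
}

for package, expected in pinned.items():
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        print(f"{package}: not installed (pinned to {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (pinned to {expected})"
    print(f"{package} {installed}: {status}")
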
156 changes: 156 additions & 0 deletions sparse_probing/src/probe_training.ipynb
@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0x13f1b51d0>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Standard imports\n",
"import os\n",
"import torch\n",
"from tqdm import tqdm\n",
"import plotly.express as px\n",
"\n",
"# Imports for displaying vis in Colab / notebook\n",
"import webbrowser\n",
"import http.server\n",
"import socketserver\n",
"import threading\n",
"PORT = 8000\n",
"\n",
"torch.set_grad_enabled(False)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device: mps\n"
]
}
],
"source": [
"if torch.backends.mps.is_available():\n",
" device = \"mps\"\n",
"else:\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"\n",
"print(f\"Device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/adamkarvonen/SAE_bench_template/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/Users/adamkarvonen/SAE_bench_template/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model pythia-70m-deduped into HookedTransformer\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"from transformer_lens import HookedTransformer\n",
"from sae_lens import SAE\n",
"\n",
"model = HookedTransformer.from_pretrained(\"pythia-70m-deduped\", device=device)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SAEConfig(architecture='standard', d_in=512, d_sae=16384, activation_fn_str='topk', apply_b_dec_to_input=True, finetuning_scaling_factor=False, context_size=128, model_name='pythia-70m-deduped', hook_name='blocks.4.hook_resid_post', hook_layer=4, hook_head_index=None, prepend_bos=True, dataset_path='monology/pile-uncopyrighted', dataset_trust_remote_code=False, normalize_activations='none', dtype='float32', device='mps', sae_lens_training_version=None, activation_fn_kwargs={'k': 80}, neuronpedia_id=None, model_from_pretrained_kwargs={})\n"
]
}
],
"source": [
"# The cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (e.g., instantiating an activation store).\n",
"# Note that this is not the same as the SAE's config dict; rather, it is whatever was in the HF repo, from which the SAE config dict can be extracted.\n",
"# The feature sparsities stored on HF are also returned for convenience.\n",
"sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
"    release=\"sae_bench_pythia70m_sweep_topk_ctx128_0730\",\n",
"    sae_id=\"blocks.4.hook_resid_post__trainer_10\",\n",
"    device=device,\n",
")\n",
"sae = sae.to(device=device)\n",
"\n",
"print(sae.cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
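
An editor's sketch follows, not code from commit 319300f: once the cells above have run, model, sae, and device are in scope, and the SAE can encode cached residual-stream activations into its 16,384-dimensional feature space, the raw material for sparse probing. The input string below is made up; to_tokens, run_with_cache, and SAE.encode are standard transformer_lens and sae_lens APIs.

# Editor's sketch: sparse feature activations for a single input.
# Assumes model, sae, and device exactly as defined in the notebook
# above; example_text is a hypothetical input.
example_text = "The quick brown fox jumps over the lazy dog."
tokens = model.to_tokens(example_text)  # (1, seq_len); BOS is prepended

# Cache activations at the hook point the SAE was trained on
# (blocks.4.hook_resid_post, per the printed SAEConfig).
_, cache = model.run_with_cache(tokens)
acts = cache[sae.cfg.hook_name]  # (1, seq_len, 512) for pythia-70m-deduped

# Encode into the sparse feature basis; with the topk activation
# (k=80), at most 80 of the 16384 features are nonzero per token.
feature_acts = sae.encode(acts)
print(feature_acts.shape)  # expected: torch.Size([1, seq_len, 16384])

Because the notebook disables gradients up front with torch.set_grad_enabled(False), this runs without building autograd graphs, which is all that fitting probes on frozen features requires.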