
Commit

Merge branch 'main' into callum/autointerp
adamkarvonen authored Nov 10, 2024
2 parents 9b2e909 + 36fb3ba commit 466a37d
Showing 151 changed files with 97,432 additions and 1,596 deletions.
19 changes: 18 additions & 1 deletion .gitignore
@@ -23,4 +23,21 @@ poetry.lock
!/sparse_probing/results/example_pythia-70m-deduped_layer_4_eval_results.json
!/sparse_probing/results/example_gemma-2-2b_layer_19_eval_results.json
!/sparse_probing/results/example_gemma-2-2b_layer_19_with_checkpoints_eval_results.json
evals/absorption/results/
Pipfile*
poetry.lock

auth/*

evals/ravel/data/
evals/ravel/models/

evals/absorption/results/

# unlearning: the forget dataset cannot be uploaded
*/unlearning/data/bio-forget-corpus.jsonl


**/images/
**/results/
**/artifacts/
**/test_results/
19 changes: 18 additions & 1 deletion README.md
@@ -40,4 +40,21 @@ Ideally, we would like to see something like `evals.sparse_probing.main.py`, whi

All evals and submodules will share the same dependencies, which are set in pyproject.toml.

For a tutorial of using SAE Lens SAEs, including calculating L0 and Loss Recovered and getting a set of tokens from The Pile, refer to this notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb
For a tutorial of using SAE Lens SAEs, including calculating L0 and Loss Recovered and getting a set of tokens from The Pile, refer to this notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb

## Custom SAE Usage

For the sparse probing and SHIFT / TPP evals, we support evaluating any SAE object that has the following implemented, with inputs / outputs matching the SAELens SAE format:

```
sae.encode()
sae.decode()
sae.forward()
sae.W_dec # nn.Parameter(d_sae, d_in), required for SHIFT, TPP, and Feature Absorption
sae.device
sae.dtype
```
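
As a rough illustration of the interface listed above, the following sketch checks whether an object exposes the required members before it is handed to an eval. The helper `missing_sae_members` and the stand-in objects are hypothetical and not part of this repository; a real SAE would be a torch module whose `W_dec` is an `nn.Parameter` of shape `(d_sae, d_in)`.

```python
from types import SimpleNamespace

# Member names are taken from the list above. The helper itself is a
# hypothetical convenience, not something this repository provides.
REQUIRED_METHODS = ("encode", "decode", "forward")
REQUIRED_ATTRIBUTES = ("W_dec", "device", "dtype")


def missing_sae_members(sae) -> list[str]:
    """Return the names from the required interface that `sae` lacks."""
    missing = [
        name for name in REQUIRED_METHODS
        if not callable(getattr(sae, name, None))
    ]
    missing += [name for name in REQUIRED_ATTRIBUTES if not hasattr(sae, name)]
    # W_dec is described as (d_sae, d_in), so a 2-D shape is expected
    # whenever the attribute exposes one.
    shape = getattr(getattr(sae, "W_dec", None), "shape", None)
    if shape is not None and len(shape) != 2:
        missing.append("W_dec (expected 2-D shape (d_sae, d_in))")
    return missing


# Stand-in objects for illustration only.
dummy_sae = SimpleNamespace(
    encode=lambda x: x,
    decode=lambda x: x,
    forward=lambda x: x,
    W_dec=SimpleNamespace(shape=(16, 4)),
    device="cpu",
    dtype="float32",
)
incomplete_sae = SimpleNamespace(encode=lambda x: x, device="cpu")

print(missing_sae_members(dummy_sae))
print(missing_sae_members(incomplete_sae))
```

This only checks that the members exist; it does not verify that inputs and outputs match the SAELens SAE format, which still has to be confirmed against the individual eval.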

Pass the appropriate inputs to `run_eval_single_sae()`, referring to the individual eval READMEs as needed. If your results match our output format, you can reuse our graphing notebook.

To run our baselines on pythia-70m and gemma-2-2b, refer to `if __name__ == "__main__":` in `shift_and_tpp/main.py`.
@@ -0,0 +1 @@
{"mean_correct": 1.0, "total_correct": 73, "is_correct": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "output_probs": [[0.00011396408081054688, 9.441375732421875e-05, 0.984375, 0.0001659393310546875], [0.0005645751953125, 0.84375, 0.001739501953125, 0.00135040283203125], [0.00057220458984375, 0.9140625, 0.000270843505859375, 0.00023937225341796875], [0.0022125244140625, 0.953125, 0.00183868408203125, 0.00118255615234375], [0.000240325927734375, 0.000186920166015625, 0.000255584716796875, 0.91796875], [0.0003719329833984375, 0.000652313232421875, 0.9765625, 0.00121307373046875], [0.005645751953125, 0.0081787109375, 0.65234375, 0.00171661376953125], [0.0026092529296875, 0.000545501708984375, 0.0004825592041015625, 0.9296875], [5.364418029785156e-05, 0.00014591217041015625, 0.98046875, 0.00147247314453125], [0.000499725341796875, 0.0004425048828125, 0.90625, 0.00136566162109375], [0.765625, 0.0042724609375, 0.0027618408203125, 0.01495361328125], [0.0001735687255859375, 9.870529174804688e-05, 0.0004711151123046875, 0.96484375], [0.000225067138671875, 0.0001983642578125, 0.91796875, 0.0004482269287109375], [0.953125, 0.000766754150390625, 0.00098419189453125, 0.000362396240234375], [8.630752563476562e-05, 9.822845458984375e-05, 0.00016117095947265625, 0.8984375], [0.000644683837890625, 0.0001735687255859375, 0.00011205673217773438, 0.90625], [0.00164794921875, 0.002716064453125, 0.96875, 0.00421142578125], [0.9765625, 0.000370025634765625, 0.0001983642578125, 0.0001201629638671875], [0.98828125, 0.00148773193359375, 0.000659942626953125, 0.0009002685546875], [0.001007080078125, 0.000949859619140625, 0.0025787353515625, 0.9765625], [8.821487426757812e-05, 
0.0002536773681640625, 0.97265625, 0.0006103515625], [0.005035400390625, 0.0004138946533203125, 0.9609375, 0.000873565673828125], [0.96875, 0.0001354217529296875, 9.918212890625e-05, 3.886222839355469e-05], [0.0084228515625, 0.58984375, 0.02587890625, 0.0027313232421875], [0.9453125, 0.000629425048828125, 0.000408172607421875, 0.000316619873046875], [0.0145263671875, 0.00604248046875, 0.8984375, 0.0068359375], [0.0013885498046875, 0.0013885498046875, 0.984375, 0.0021514892578125], [0.0059814453125, 0.00408935546875, 0.00408935546875, 0.94140625], [0.00019359588623046875, 0.000339508056640625, 0.0003185272216796875, 0.94921875], [0.0035247802734375, 0.91796875, 0.0011444091796875, 0.000507354736328125], [0.00555419921875, 0.0015869140625, 0.0103759765625, 0.82421875], [0.9921875, 0.00031280517578125, 0.00014781951904296875, 8.392333984375e-05], [0.00109100341796875, 0.000377655029296875, 0.00115966796875, 0.93359375], [0.9765625, 1.9669532775878906e-05, 1.233816146850586e-05, 1.537799835205078e-05], [0.0001888275146484375, 0.0002593994140625, 0.98828125, 0.000659942626953125], [0.94921875, 0.000762939453125, 0.0005950927734375, 0.0003833770751953125], [0.92578125, 0.001312255859375, 0.00061798095703125, 0.0002574920654296875], [5.936622619628906e-05, 3.3855438232421875e-05, 0.95703125, 0.00022125244140625], [0.96875, 0.0002536773681640625, 8.726119995117188e-05, 6.389617919921875e-05], [0.000453948974609375, 0.0002765655517578125, 0.001312255859375, 0.9296875], [0.000637054443359375, 0.000720977783203125, 0.00067901611328125, 0.95703125], [0.01458740234375, 0.006072998046875, 0.796875, 0.0064697265625], [0.0021820068359375, 0.0016937255859375, 0.003173828125, 0.87890625], [0.00130462646484375, 0.8671875, 0.00189971923828125, 0.001678466796875], [0.0005035400390625, 0.91015625, 0.00014400482177734375, 0.0001354217529296875], [0.000782012939453125, 0.9140625, 0.00041961669921875, 0.00019741058349609375], [0.0007781982421875, 0.0002536773681640625, 0.96875, 
0.000286102294921875], [0.007049560546875, 0.921875, 0.0062255859375, 0.00244140625], [0.00022029876708984375, 9.775161743164062e-05, 0.000560760498046875, 0.8984375], [0.0004177093505859375, 0.000392913818359375, 0.0018768310546875, 0.97265625], [0.001678466796875, 0.92578125, 0.00189971923828125, 0.001678466796875], [0.0001773834228515625, 0.00012969970703125, 0.0007476806640625, 0.92578125], [0.00136566162109375, 0.96484375, 0.00347900390625, 0.000286102294921875], [0.0013885498046875, 0.87109375, 0.00457763671875, 0.001678466796875], [0.9609375, 0.003936767578125, 0.002532958984375, 0.0019683837890625], [0.9609375, 0.00022220611572265625, 0.00013446807861328125, 0.00023651123046875], [0.000820159912109375, 0.0005645751953125, 0.95703125, 0.0003643035888671875], [0.000141143798828125, 0.0002193450927734375, 0.94921875, 0.000263214111328125], [0.000492095947265625, 0.94921875, 0.000461578369140625, 0.000408172607421875], [0.97265625, 0.000370025634765625, 0.000347137451171875, 0.0002384185791015625], [0.0023956298828125, 0.00186920166015625, 0.96484375, 0.0028839111328125], [0.00011873245239257812, 0.00013446807861328125, 0.00087738037109375, 0.96484375], [0.00104522705078125, 0.89453125, 0.000598907470703125, 0.000720977783203125], [0.0034637451171875, 0.9609375, 0.003692626953125, 0.00106048583984375], [0.000690460205078125, 0.0002880096435546875, 0.9765625, 0.00057220458984375], [6.532669067382812e-05, 0.00015735626220703125, 0.9921875, 0.00013828277587890625], [0.87890625, 0.0023193359375, 0.0010986328125, 0.0006256103515625], [0.0283203125, 0.007171630859375, 0.0086669921875, 0.828125], [0.79296875, 0.0272216796875, 0.03955078125, 0.03955078125], [0.9140625, 0.00128936767578125, 0.000690460205078125, 0.0005035400390625], [0.97265625, 8.7738037109375e-05, 9.965896606445312e-05, 5.6743621826171875e-05], [0.0013427734375, 0.0011138916015625, 0.953125, 0.00162506103515625], [0.032958984375, 0.546875, 0.004730224609375, 0.003692626953125]]}
@@ -0,0 +1 @@
{"mean_correct": 1.0, "total_correct": 9, "is_correct": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "output_probs": [[0.0012359619140625, 0.0264892578125, 0.7734375, 0.0016937255859375], [0.91796875, 0.0016632080078125, 0.00177001953125, 0.00274658203125], [0.0006561279296875, 0.000396728515625, 0.000789642333984375, 0.98046875], [9.918212890625e-05, 0.0001735687255859375, 0.96875, 0.0004444122314453125], [0.0018463134765625, 0.95703125, 0.007293701171875, 0.002227783203125], [0.00144195556640625, 0.004180908203125, 0.84765625, 0.00185394287109375], [0.87109375, 0.0026092529296875, 0.0037841796875, 0.0021514892578125], [0.921875, 0.000698089599609375, 0.00095367431640625, 0.00051116943359375], [0.91796875, 0.003753662109375, 0.0042724609375, 0.003753662109375]]}
@@ -0,0 +1 @@
{"mean_correct": 0.9999999403953552, "total_correct": 107, "is_correct": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "output_probs": [[0.010986328125, 0.76953125, 0.006256103515625, 0.0021514892578125], [0.0003070831298828125, 0.000270843505859375, 0.000537872314453125, 0.8046875], [0.000583648681640625, 0.0004825592041015625, 0.9296875, 0.0004825592041015625], [0.003082275390625, 0.00128936767578125, 0.0012054443359375, 0.85546875], [0.98828125, 0.000213623046875, 0.0001888275146484375, 0.00011444091796875], [0.00028228759765625, 0.000438690185546875, 0.00135040283203125, 0.8984375], [0.037109375, 0.0093994140625, 0.0093994140625, 0.84375], [0.0006103515625, 0.0003490447998046875, 0.000946044921875, 0.9765625], [0.000701904296875, 0.87109375, 0.0009613037109375, 0.000545501708984375], [0.000484466552734375, 0.0001678466796875, 0.93359375, 0.000621795654296875], [0.000274658203125, 0.00017642974853515625, 0.92578125, 0.0003299713134765625], [0.00154876708984375, 0.0028839111328125, 0.0023956298828125, 0.80078125], [0.0001983642578125, 0.9140625, 0.0001983642578125, 0.00018596649169921875], [0.98828125, 5.1021575927734375e-05, 6.151199340820312e-05, 1.0669231414794922e-05], [0.0062255859375, 0.81640625, 0.0021514892578125, 0.0013885498046875], [0.984375, 3.933906555175781e-05, 2.5391578674316406e-05, 3.075599670410156e-05], [0.9921875, 9.5367431640625e-05, 6.532669067382812e-05, 3.2901763916015625e-05], [0.00091552734375, 0.000591278076171875, 0.9453125, 0.0009765625], 
[0.0002880096435546875, 0.000392913818359375, 0.97265625, 0.000270843505859375], [0.0028839111328125, 0.0017547607421875, 0.0032806396484375, 0.8515625], [0.9375, 0.000148773193359375, 9.584426879882812e-05, 7.009506225585938e-05], [0.015869140625, 0.71875, 0.0062255859375, 0.004547119140625], [0.000347137451171875, 0.0003261566162109375, 0.000446319580078125, 0.859375], [0.84765625, 0.0009918212890625, 0.007354736328125, 0.000640869140625], [0.96484375, 8.678436279296875e-05, 6.341934204101562e-05, 2.3365020751953125e-05], [0.002166748046875, 0.00139617919921875, 0.002777099609375, 0.87109375], [0.95703125, 7.200241088867188e-05, 3.600120544433594e-05, 4.935264587402344e-05], [0.00171661376953125, 0.0024871826171875, 0.0072021484375, 0.9453125], [0.00016117095947265625, 0.000125885009765625, 0.0002346038818359375, 0.95703125], [0.005401611328125, 0.00154876708984375, 0.91015625, 0.00083160400390625], [8.487701416015625e-05, 0.00016880035400390625, 0.94140625, 0.00019168853759765625], [9.965896606445312e-05, 0.00010585784912109375, 0.00021076202392578125, 0.8046875], [0.0026702880859375, 0.83984375, 0.002349853515625, 0.0011138916015625], [0.00011920928955078125, 0.8515625, 0.0004711151123046875, 0.00025177001953125], [0.8984375, 0.00049591064453125, 0.0004673004150390625, 0.000438690185546875], [0.00018787384033203125, 0.0001468658447265625, 0.000423431396484375, 0.984375], [0.00225830078125, 0.0017547607421875, 0.96875, 0.001983642578125], [0.002471923828125, 0.0013275146484375, 0.01043701171875, 0.94140625], [0.8125, 0.004547119140625, 0.0024261474609375, 0.0021514892578125], [0.006103515625, 0.01373291015625, 0.002105712890625, 0.62109375], [0.00421142578125, 0.70703125, 0.00186920166015625, 0.00093841552734375], [0.0020294189453125, 0.92578125, 0.00130462646484375, 0.00048065185546875], [0.00016689300537109375, 0.0001220703125, 0.875, 0.0002593994140625], [0.001495361328125, 0.93359375, 0.0004558563232421875, 0.000377655029296875], [3.5762786865234375e-05, 
2.0265579223632812e-05, 0.9453125, 0.00021839141845703125], [0.0001850128173828125, 0.00012683868408203125, 0.0004444122314453125, 0.96875], [0.033203125, 0.0213623046875, 0.005096435546875, 0.75390625], [0.00421142578125, 0.8515625, 0.001861572265625, 0.00099945068359375], [0.00023555755615234375, 9.822845458984375e-05, 0.9609375, 6.771087646484375e-05], [0.00016498565673828125, 3.695487976074219e-05, 0.98046875, 0.000507354736328125], [0.005096435546875, 0.00146484375, 0.97265625, 0.000782012939453125], [0.0018157958984375, 0.77734375, 0.0018157958984375, 0.00070953369140625], [0.0010223388671875, 0.0002593994140625, 0.0007476806640625, 0.9296875], [0.2734375, 0.047607421875, 0.03076171875, 0.2138671875], [0.00081634521484375, 0.000339508056640625, 0.001953125, 0.7890625], [0.97265625, 3.886222839355469e-05, 6.818771362304688e-05, 3.6716461181640625e-05], [0.00112152099609375, 0.0009307861328125, 0.006439208984375, 0.95703125], [0.00142669677734375, 0.7890625, 0.0004634857177734375, 0.0002651214599609375], [0.0028533935546875, 0.8984375, 0.00173187255859375, 0.00112152099609375], [0.00028228759765625, 0.001190185546875, 0.0028533935546875, 0.83984375], [0.01068115234375, 0.0014495849609375, 0.0037078857421875, 0.96484375], [0.0006866455078125, 0.00225830078125, 0.0025634765625, 0.85546875], [0.01953125, 0.0021820068359375, 0.006317138671875, 0.828125], [0.0020599365234375, 0.001708984375, 0.01043701171875, 0.94140625], [0.96875, 0.0032806396484375, 0.00225830078125, 0.0012054443359375], [0.000823974609375, 0.9609375, 0.00060272216796875, 0.0002841949462890625], [0.00115203857421875, 0.000579833984375, 0.0027618408203125, 0.984375], [0.000213623046875, 0.0001010894775390625, 0.0001010894775390625, 0.87109375], [0.00010395050048828125, 0.0002346038818359375, 0.953125, 0.00020694732666015625], [0.0013580322265625, 0.0010528564453125, 0.004730224609375, 0.95703125], [0.0012054443359375, 0.0004444122314453125, 0.91015625, 0.0004444122314453125], [0.00102996826171875, 
0.00021648406982421875, 0.94140625, 0.00040435791015625], [0.000888824462890625, 0.9765625, 0.00016498565673828125, 9.393692016601562e-05], [0.00121307373046875, 0.001373291015625, 0.00543212890625, 0.97265625], [0.00101470947265625, 0.921875, 0.00115203857421875, 0.0006561279296875], [0.0011138916015625, 0.7890625, 0.0011138916015625, 0.000812530517578125], [0.00628662109375, 0.0010223388671875, 0.9296875, 0.0024566650390625], [0.96484375, 0.0001850128173828125, 0.0002689361572265625, 0.00010538101196289062], [0.000568389892578125, 0.00026702880859375, 0.0009918212890625, 0.90625], [0.95703125, 0.000873565673828125, 0.000720977783203125, 0.000530242919921875], [0.00909423828125, 0.00970458984375, 0.004302978515625, 0.38671875], [0.00035858154296875, 0.88671875, 0.00035858154296875, 0.00014972686767578125], [0.004791259765625, 0.0025634765625, 0.10888671875, 0.8046875], [0.8984375, 0.002227783203125, 0.001190185546875, 0.0005645751953125], [0.00017261505126953125, 0.0001621246337890625, 0.0003643035888671875, 0.9609375], [0.000560760498046875, 0.000247955322265625, 0.00098419189453125, 0.94921875], [0.006988525390625, 0.006988525390625, 0.0189208984375, 0.859375], [0.9296875, 0.0023040771484375, 0.00139617919921875, 0.0007476806640625], [0.98828125, 6.151199340820312e-05, 3.719329833984375e-05, 2.4080276489257812e-05], [0.00016021728515625, 5.888938903808594e-05, 8.058547973632812e-05, 0.94921875], [6.723403930664062e-05, 6.29425048828125e-05, 0.953125, 0.0002346038818359375], [0.00015354156494140625, 0.0004444122314453125, 0.0004444122314453125, 0.91015625], [0.0047607421875, 0.002716064453125, 0.0107421875, 0.85546875], [0.953125, 0.002227783203125, 0.00152587890625, 0.00098419189453125], [0.000202178955078125, 0.000202178955078125, 0.00054931640625, 0.9921875], [0.00982666015625, 0.88671875, 0.034423828125, 0.022216796875], [0.0005340576171875, 0.90625, 0.0004711151123046875, 0.000324249267578125], [0.000652313232421875, 0.000255584716796875, 
0.000614166259765625, 0.9765625], [0.00031280517578125, 0.9296875, 9.5367431640625e-05, 3.743171691894531e-05], [0.0004367828369140625, 0.00017070770263671875, 0.00063323974609375, 0.953125], [0.8125, 0.0003509521484375, 0.000396728515625, 0.000576019287109375], [0.984375, 0.0003509521484375, 0.0003299713134765625, 0.0002574920654296875], [0.00372314453125, 0.0027313232421875, 0.00787353515625, 0.9140625], [0.0013885498046875, 0.8671875, 0.000579833984375, 0.000308990478515625], [0.00032806396484375, 7.772445678710938e-05, 0.9765625, 0.00012874603271484375], [0.98046875, 0.0008392333984375, 0.000614166259765625, 0.00032806396484375], [0.0004119873046875, 0.6171875, 0.00022029876708984375, 0.00018310546875]]}