Commit: Merge branch 'main' into packaging
chanind committed Jan 9, 2025
2 parents bb10234 + 9bbfdc5 commit 9bc22a4
Showing 2 changed files with 42 additions and 32 deletions.
sae_bench/custom_saes/README.md: 4 changes (3 additions & 1 deletion)
@@ -5,10 +5,12 @@ There are a few requirements for the SAE object. If your SAE object inherits `BaseSAE`...
- `W_dec` should have unit norm decoder vectors. Some SAE trainers do not enforce this. `BaseSAE` has a function `check_decoder_norms()`, which we recommend calling when loading the SAE. For an example of how to fix this, refer to `normalize_decoder()` in `relu_sae.py`.
- The SAE must have a `dtype` and `device` attribute.
- The SAE must have a `.cfg` field, which contains attributes like `d_sae` and `d_in`. The core evals utilize SAE Lens internals, and require a handful of blank fields, which are already set in the `CustomSaeConfig` dataclass.
- - In general, just pattern match to the `jump_relu` and `relu` implementations if adding your own.
+ - In general, we recommend modifying an existing SAE class, such as the `ReluSAE` class in `relu_sae.py`. You will have to modify `encode()` and `decode()`, and will probably have to add a function to load your state dict.

Refer to `SAEBench/sae_bench_demo.ipynb` for an example of how to compare a custom SAE with a baseline SAE and create some graphs. There is also a cell demonstrating how to run all evals on a selection of SAEs.

If your SAEs are trained with the [dictionary_learning repo](https://github.com/saprmarks/dictionary_learning), you can evaluate your SAEs by passing in the name of the HuggingFace repo containing your SAEs. Refer to `SAEBench/custom_saes/run_all_evals_dictionary_learning_saes.py`.

+ If you want a Python script to evaluate your custom SAEs, refer to `run_all_evals_custom_saes.py`.

If there are any pain points when using this repo with custom SAEs, please do not hesitate to reach out or raise an issue.
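For readers adding their own SAE, here is a minimal sketch of a class satisfying the bullet points above. `MyCustomSAE` and `MyCustomSAEConfig` are hypothetical names invented for this sketch; the repo's real `BaseSAE` and `CustomSaeConfig` carry more fields, so pattern-match against `relu_sae.py` rather than this code.

```python
import torch
import torch.nn as nn


class MyCustomSAEConfig:
    """Hypothetical stand-in for the repo's CustomSaeConfig; only the
    fields named in the README above are modeled here."""

    def __init__(self, model_name: str, d_in: int, d_sae: int, hook_name: str, hook_layer: int):
        self.model_name = model_name
        self.d_in = d_in
        self.d_sae = d_sae
        self.hook_name = hook_name
        self.hook_layer = hook_layer


class MyCustomSAE(nn.Module):
    """Minimal SAE meeting the requirements above: encode/decode methods,
    dtype and device attributes, a .cfg field, and unit-norm decoder rows."""

    def __init__(self, cfg: MyCustomSAEConfig, device: str = "cpu", dtype: torch.dtype = torch.float32):
        super().__init__()
        self.cfg = cfg
        self.device = torch.device(device)
        self.dtype = dtype
        self.W_enc = nn.Parameter(torch.randn(cfg.d_in, cfg.d_sae, device=self.device, dtype=dtype) * 0.1)
        self.W_dec = nn.Parameter(torch.randn(cfg.d_sae, cfg.d_in, device=self.device, dtype=dtype) * 0.1)
        self.b_enc = nn.Parameter(torch.zeros(cfg.d_sae, device=self.device, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg.d_in, device=self.device, dtype=dtype))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        return feature_acts @ self.W_dec + self.b_dec

    @torch.no_grad()
    def normalize_decoder(self) -> None:
        # Rescale each decoder row to unit norm, folding the scale into the
        # encoder so reconstructions are unchanged (ReLU is positively
        # homogeneous). Mirrors the normalize_decoder() idea in relu_sae.py.
        norms = self.W_dec.norm(dim=1, keepdim=True)
        self.W_dec /= norms
        self.W_enc *= norms.T
        self.b_enc *= norms.squeeze()
```

In practice you would replace the random initialization with a loader for your own state dict, per the last bullet above.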
sae_bench_demo.ipynb: 70 changes (39 additions & 31 deletions)
@@ -94,7 +94,7 @@
"hook_layer = 4\n",
"hook_name = f\"blocks.{hook_layer}.hook_resid_post\"\n",
"\n",
"sae = relu_sae.load_dictionary_learning_relu_sae(repo_id, baseline_filename, hook_layer, model_name, device, torch_dtype)\n",
"sae = relu_sae.load_dictionary_learning_relu_sae(repo_id, baseline_filename, model_name, device, torch_dtype, layer=hook_layer)\n",
"\n",
"print(f\"sae dtype: {sae.dtype}, device: {sae.device}\")\n",
"\n",
@@ -109,7 +109,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In our sae object we need to have a CustomSAEConfig. This contains some information which is used by the evals (hook_name, hook_layer, model_name, d_sae, etc). In addition, it contains information that is used by our plotting functions, like number of training tokens and architecture. For example, we should have the sae.cfg.architecture defined if we want to plot multiple SAE architectures."
"In our sae object we need to have a CustomSAEConfig. This contains some information which is used by the evals (hook_name, hook_layer, model_name, d_sae, etc). In addition, it contains information that is used by our plotting functions, like number of training tokens and architecture. For example, we should have the sae.cfg.architecture defined if we want to plot multiple SAE architectures.\n",
"\n",
"Note: Everything in this cell, except `architecture` and `training_tokens`, is done in the `BaseSAE` class that the `ReluSAE` inherits from. Because of this, we recommend that you modify an existing SAE class."
]
},
{
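Concretely, the cell described above might look roughly like the following sketch. The import path and the illustrative values are assumptions; only the field names come from the prose above, and `sae`, `model_name`, `hook_name`, and `hook_layer` are defined earlier in the notebook.

```python
# Sketch only: the real CustomSAEConfig constructor may differ.
from sae_bench.custom_saes.custom_sae_config import CustomSAEConfig  # hypothetical import path

sae.cfg = CustomSAEConfig(
    model_name,
    d_in=sae.W_dec.shape[1],
    d_sae=sae.W_dec.shape[0],
    hook_name=hook_name,
    hook_layer=hook_layer,
)

# Not set by BaseSAE; used only by the plotting functions.
sae.cfg.architecture = "relu"
sae.cfg.training_tokens = 200_000_000  # illustrative value
```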
@@ -280,9 +282,6 @@
"metadata": {},
"outputs": [],
"source": [
"eval_path = output_folders[\"sparse_probing\"]\n",
"\n",
"core_results_path = output_folders[\"core\"]\n",
"image_path = \"./images\"\n",
"\n",
"if not os.path.exists(image_path):\n",
@@ -295,17 +294,21 @@
"metadata": {},
"outputs": [],
"source": [
"custom_sae_ids = []\n",
"results_folders = [\"./eval_results\"]\n",
"\n",
"for sae_id, sae in custom_saes:\n",
" custom_sae_ids.append((sae_id, \"custom_sae\"))\n",
"eval_type = \"sparse_probing\"\n",
"\n",
"sae_lens_ids = []\n",
"eval_folders = []\n",
"core_folders = []\n",
"for results_folder in results_folders:\n",
" eval_folders.append(f\"{results_folder}/{eval_type}\")\n",
" core_folders.append(f\"{results_folder}/core\")\n",
"\n",
"for sae_id, sae_release in baseline_saes:\n",
" sae_lens_ids.append((sae_id, sae_release))\n",
"eval_filenames = graphing_utils.find_eval_results_files(eval_folders)\n",
"core_filenames = graphing_utils.find_eval_results_files(core_folders)\n",
"\n",
"graphing_sae_ids = custom_sae_ids + sae_lens_ids"
"print(f\"eval_filenames: {eval_filenames}\")\n",
"print(f\"core_filenames: {core_filenames}\")"
]
},
{
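For orientation, the directory layout this cell assumes, as a sketch; the `_eval_results.json` naming follows the filenames used elsewhere in this notebook, and the actual files depend on which evals you ran:

```
eval_results/
├── core/
│   └── <sae_id>_eval_results.json
└── sparse_probing/
    └── <sae_id>_eval_results.json
```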
@@ -321,9 +324,14 @@
"metadata": {},
"outputs": [],
"source": [
"raw_results_dict = graphing_utils.get_results_dict(graphing_sae_ids, eval_path, core_results_path)\n",
"eval_results_dict = graphing_utils.get_eval_results(eval_filenames)\n",
"core_results_dict = graphing_utils.get_eval_results(core_filenames)\n",
"\n",
"for sae in eval_results_dict:\n",
" eval_results_dict[sae].update(core_results_dict[sae])\n",
"\n",
"print(raw_results_dict.keys())"
"\n",
"print(eval_results_dict.keys())"
]
},
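One caveat on the merge loop above: `update()` assumes every SAE with sparse-probing results also has core results, and it overwrites any metric keys the two files share. A more defensive variant, as a sketch rather than code from the repo:

```python
# Skip SAEs with no core results instead of raising a KeyError.
for sae_name, metrics in eval_results_dict.items():
    core = core_results_dict.get(sae_name)
    if core is None:
        print(f"warning: no core results for {sae_name}")
        continue
    metrics.update(core)
```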
{
@@ -332,20 +340,12 @@
"metadata": {},
"outputs": [],
"source": [
"custom_sae_id = f\"{custom_sae_ids[0][0]}_{custom_sae_ids[0][1]}\".replace(\".\", \"_\")\n",
"baseline_sae_id = f\"{sae_lens_ids[0][0]}_{sae_lens_ids[0][1]}\"\n",
"\n",
"\n",
"baseline_filename = f\"{sae_lens_ids[0][0]}_{sae_lens_ids[0][1]}_eval_results.json\".replace(\"/\", \"_\")\n",
"baseline_filepath = os.path.join(eval_path, baseline_filename)\n",
"baseline_filepath = eval_filenames[0]\n",
"\n",
"with open(baseline_filepath, \"r\") as f:\n",
" baseline_sae_eval_results = json.load(f)\n",
"\n",
"custom_filename = f\"{custom_sae_ids[0][0]}_{custom_sae_ids[0][1]}_eval_results.json\".replace(\n",
" \"/\", \"_\"\n",
")\n",
"custom_filepath = os.path.join(eval_path, custom_filename)\n",
"custom_filepath = eval_filenames[1]\n",
"\n",
"with open(custom_filepath, \"r\") as f:\n",
" custom_sae_eval_results = json.load(f)\n",
@@ -363,7 +363,7 @@
" custom_sae_eval_results[\"eval_result_metrics\"][\"sae\"][f\"sae_top_{k}_test_accuracy\"],\n",
")\n",
"print(\n",
" f\"LLM top {k} accuracy was:\",\n",
" f\"LLM residual stream top {k} accuracy was:\",\n",
" baseline_sae_eval_results[\"eval_result_metrics\"][\"llm\"][f\"llm_top_{k}_test_accuracy\"],\n",
")"
]
@@ -384,9 +384,9 @@
"image_base_name = os.path.join(image_path, \"sparse_probing\")\n",
"\n",
"graphing_utils.plot_results(\n",
" graphing_sae_ids,\n",
" eval_path,\n",
" core_results_path,\n",
" eval_filenames,\n",
" core_filenames,\n",
" eval_type,\n",
" image_base_name,\n",
" k,\n",
" trainer_markers=trainer_markers,\n",
@@ -429,10 +429,18 @@
"outputs": [],
"source": [
"for eval_type in eval_types:\n",
"\n",
" eval_folders = []\n",
"\n",
" for results_folder in results_folders:\n",
" eval_folders.append(f\"{results_folder}/{eval_type}\")\n",
"\n",
" eval_filenames = graphing_utils.find_eval_results_files(eval_folders)\n",
"\n",
" graphing_utils.plot_results(\n",
" graphing_sae_ids,\n",
" output_folders[eval_type],\n",
" core_results_path,\n",
" eval_filenames,\n",
" core_filenames,\n",
" eval_type,\n",
" image_base_name,\n",
" k=10,\n",
" trainer_markers=trainer_markers,\n",
