By default we don't use a threshold for custom topk SAEs
adamkarvonen committed Feb 18, 2025
1 parent 0888d07 · commit 60579ed
Showing 1 changed file with 3 additions and 4 deletions.
sae_bench/custom_saes/topk_sae.py (7 changes: 3 additions & 4 deletions)
@@ -77,6 +77,7 @@ def load_dictionary_learning_topk_sae(
     dtype: torch.dtype,
     layer: int | None = None,
     local_dir: str = "downloaded_saes",
+    use_threshold_at_inference: bool = False,
 ) -> TopKSAE:
     assert "ae.pt" in filename

@@ -122,9 +123,7 @@ def load_dictionary_learning_topk_sae(
         "k": "k",
     }

-    use_threshold = "threshold" in pt_params
-
-    if use_threshold:
+    if "threshold" in pt_params:
         key_mapping["threshold"] = "threshold"

     # Create a new dictionary with renamed keys
@@ -145,7 +144,7 @@ def load_dictionary_learning_topk_sae(
         hook_layer=layer,  # type: ignore
         device=device,
         dtype=dtype,
-        use_threshold=use_threshold,
+        use_threshold=use_threshold_at_inference,
     )

     sae.load_state_dict(renamed_params)
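For reference, a minimal usage sketch of how a caller might opt back into the checkpoint's learned threshold after this change. Only filename, dtype, layer, local_dir, and use_threshold_at_inference are visible in this diff; the remaining arguments (repo_id, model_name, device) are assumptions about the rest of the loader's signature.

import torch

from sae_bench.custom_saes.topk_sae import load_dictionary_learning_topk_sae

# Usage sketch: after this commit, the loader ignores any "threshold" stored
# in the checkpoint unless the caller explicitly opts in at inference time.
sae = load_dictionary_learning_topk_sae(
    repo_id="some-org/some-sae-repo",      # hypothetical; not shown in this diff
    filename="resid_post_layer_12/ae.pt",  # hypothetical path; must contain "ae.pt"
    model_name="gemma-2-2b",               # hypothetical; not shown in this diff
    device="cuda",
    dtype=torch.float32,
    layer=12,                              # hypothetical layer index
    use_threshold_at_inference=True,       # new flag; defaults to False
)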
