From 4c46da6610b3bba7df2910cd0397a9926c4a9264 Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Thu, 20 Feb 2025 10:54:52 -0800
Subject: [PATCH] chore: fix updated torch types

---
 sae_bench/custom_saes/batch_topk_sae.py | 5 ++++-
 sae_bench/custom_saes/topk_sae.py | 5 ++++-
 sae_bench/evals/unlearning/utils/feature_activation.py | 2 +-
 sae_bench/sae_bench_utils/activation_collection.py | 4 ++--
 tests/unit/evals/absorption/test_k_sparse_probing.py | 4 ++--
 5 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/sae_bench/custom_saes/batch_topk_sae.py b/sae_bench/custom_saes/batch_topk_sae.py
index 37192b4..a0fb493 100644
--- a/sae_bench/custom_saes/batch_topk_sae.py
+++ b/sae_bench/custom_saes/batch_topk_sae.py
@@ -8,6 +8,9 @@


 class BatchTopKSAE(base_sae.BaseSAE):
+    threshold: torch.Tensor
+    k: torch.Tensor
+
     def __init__(
         self,
         d_in: int,
@@ -47,7 +50,7 @@ def encode(self, x: torch.Tensor):
             )
             return encoded_acts_BF

-        post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)
+        post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)  # type: ignore

         tops_acts_BK = post_topk.values
         top_indices_BK = post_topk.indices
diff --git a/sae_bench/custom_saes/topk_sae.py b/sae_bench/custom_saes/topk_sae.py
index 7b70707..ab600de 100644
--- a/sae_bench/custom_saes/topk_sae.py
+++ b/sae_bench/custom_saes/topk_sae.py
@@ -8,6 +8,9 @@


 class TopKSAE(base_sae.BaseSAE):
+    threshold: torch.Tensor
+    k: torch.Tensor
+
     def __init__(
         self,
         d_in: int,
@@ -49,7 +52,7 @@ def encode(self, x: torch.Tensor):
             )
             return encoded_acts_BF

-        post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)
+        post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)  # type: ignore

         tops_acts_BK = post_topk.values
         top_indices_BK = post_topk.indices
diff --git a/sae_bench/evals/unlearning/utils/feature_activation.py b/sae_bench/evals/unlearning/utils/feature_activation.py
index bbcf050..bfccd2d 100644
--- a/sae_bench/evals/unlearning/utils/feature_activation.py
+++ b/sae_bench/evals/unlearning/utils/feature_activation.py
@@ -95,7 +95,7 @@ def gather_target_act_hook(mod, inputs, outputs):
         target_act = outputs[0]
         return outputs

-    handle = model.model.layers[target_layer].register_forward_hook(
+    handle = model.model.layers[target_layer].register_forward_hook(  # type: ignore
         gather_target_act_hook
     )
     _ = model.forward(inputs)  # type: ignore
diff --git a/sae_bench/sae_bench_utils/activation_collection.py b/sae_bench/sae_bench_utils/activation_collection.py
index 0b2f303..4041e49 100644
--- a/sae_bench/sae_bench_utils/activation_collection.py
+++ b/sae_bench/sae_bench_utils/activation_collection.py
@@ -80,7 +80,7 @@ def activation_hook(resid_BLD: torch.Tensor, hook):

         if mask_bos_pad_eos_tokens:
             attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, model.tokenizer)
-            acts_BLD = acts_BLD * attn_mask_BL[:, :, None]
+            acts_BLD = acts_BLD * attn_mask_BL[:, :, None]  # type: ignore

         all_acts_BLD.append(acts_BLD)

@@ -375,7 +375,7 @@ def encode_precomputed_activations(
         sae_act_BLF = sae_act_BLF[:, :, selected_latents]

         if mask_bos_pad_eos_tokens:
-            attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, sae.model.tokenizer)
+            attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, sae.model.tokenizer)  # type: ignore
         else:
             attn_mask_BL = torch.ones_like(tokens_BL, dtype=torch.bool)

diff --git a/tests/unit/evals/absorption/test_k_sparse_probing.py b/tests/unit/evals/absorption/test_k_sparse_probing.py
index e90254e..6640c57 100644
--- a/tests/unit/evals/absorption/test_k_sparse_probing.py
+++ b/tests/unit/evals/absorption/test_k_sparse_probing.py
@@ -13,9 +13,9 @@


 def test_train_sparse_multi_probe_results_in_many_zero_weights():
     torch.set_grad_enabled(True)
-    x = torch.rand(1000, 100)
+    x = torch.rand(1000, 200)
     y = torch.randint(2, (1000, 3))
-    probe1 = train_sparse_multi_probe(x, y, l1_decay=0.01, device=torch.device("cpu"))
+    probe1 = train_sparse_multi_probe(x, y, l1_decay=0.03, device=torch.device("cpu"))
     probe2 = train_sparse_multi_probe(x, y, l1_decay=0.1, device=torch.device("cpu"))

     probe1_zero_weights = (probe1.weights.abs() < 1e-5).sum()