diff --git a/sae_bench_utils/activation_collection.py b/sae_bench_utils/activation_collection.py
index 835aa2d..e472b7b 100644
--- a/sae_bench_utils/activation_collection.py
+++ b/sae_bench_utils/activation_collection.py
@@ -120,28 +120,25 @@ def get_feature_activation_sparsity(
     total_tokens = 0
     for i in tqdm(range(0, tokens.shape[0], batch_size)):
-        with torch.no_grad():
-            tokens_BL = tokens[i : i + batch_size]
-            _, cache = model.run_with_cache(
-                tokens_BL, stop_at_layer=layer + 1, names_filter=hook_name
-            )
-            resid_BLD: Float[torch.Tensor, "batch pos d_model"] = cache[hook_name]
+        tokens_BL = tokens[i : i + batch_size]
+        _, cache = model.run_with_cache(tokens_BL, stop_at_layer=layer + 1, names_filter=hook_name)
+        resid_BLD: Float[torch.Tensor, "batch pos d_model"] = cache[hook_name]

-            sae_act_BLF: Float[torch.Tensor, "batch pos d_sae"] = sae.encode(resid_BLD)
-            # make act to zero or one
-            sae_act_BLF = (sae_act_BLF > 0).to(dtype=torch.float32)
+        sae_act_BLF: Float[torch.Tensor, "batch pos d_sae"] = sae.encode(resid_BLD)
+        # make act to zero or one
+        sae_act_BLF = (sae_act_BLF > 0).to(dtype=torch.float32)

-            if mask_bos_pad_eos_tokens:
-                attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, model.tokenizer)
-            else:
-                attn_mask_BL = torch.ones_like(tokens_BL, dtype=torch.bool)
+        if mask_bos_pad_eos_tokens:
+            attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, model.tokenizer)
+        else:
+            attn_mask_BL = torch.ones_like(tokens_BL, dtype=torch.bool)

-            sae_act_BLF = sae_act_BLF * attn_mask_BL[:, :, None]
-            total_tokens += attn_mask_BL.sum().item()
+        sae_act_BLF = sae_act_BLF * attn_mask_BL[:, :, None]
+        total_tokens += attn_mask_BL.sum().item()

-            sae_act_F = einops.reduce(sae_act_BLF, "B L F -> F", "sum")
+        sae_act_F = einops.reduce(sae_act_BLF, "B L F -> F", "sum")

-            sae_acts.append(sae_act_F)
+        sae_acts.append(sae_act_F)

     total_sae_acts_F = torch.stack(sae_acts).sum(dim=0)

     return total_sae_acts_F / total_tokens
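
The loop body in this hunk computes per-feature activation sparsity: binarize the SAE activations, zero out BOS/PAD/EOS positions, accumulate per-feature firing counts, and divide by the number of unmasked tokens. The following is a minimal, self-contained sketch of that computation using random tensors in place of the real model and SAE; the shapes `B`, `L`, `F` and the hand-built mask are illustrative assumptions, not values from this PR.

```python
import torch
import einops

B, L, F = 4, 16, 128  # batch, sequence length, number of SAE features (arbitrary)

sae_act_BLF = torch.relu(torch.randn(B, L, F))         # stand-in for sae.encode(resid_BLD)
fired_BLF = (sae_act_BLF > 0).to(dtype=torch.float32)  # binarize: did each feature fire?

attn_mask_BL = torch.ones(B, L, dtype=torch.bool)      # stand-in for get_bos_pad_eos_mask(...)
attn_mask_BL[:, 0] = False                             # e.g. ignore a BOS position

fired_BLF = fired_BLF * attn_mask_BL[:, :, None]       # zero out masked positions
total_tokens = attn_mask_BL.sum().item()

fired_F = einops.reduce(fired_BLF, "B L F -> F", "sum")  # per-feature firing counts
sparsity_F = fired_F / total_tokens                      # fraction of tokens each feature fires on
print(sparsity_F.shape)  # torch.Size([128])
```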