Skip to content

Commit

Permalink
Remove redundant `with torch.no_grad()` block
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Nov 10, 2024
1 parent 777e9d4 commit 528959f
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions sae_bench_utils/activation_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,28 +120,25 @@ def get_feature_activation_sparsity(
total_tokens = 0

for i in tqdm(range(0, tokens.shape[0], batch_size)):
with torch.no_grad():
tokens_BL = tokens[i : i + batch_size]
_, cache = model.run_with_cache(
tokens_BL, stop_at_layer=layer + 1, names_filter=hook_name
)
resid_BLD: Float[torch.Tensor, "batch pos d_model"] = cache[hook_name]
tokens_BL = tokens[i : i + batch_size]
_, cache = model.run_with_cache(tokens_BL, stop_at_layer=layer + 1, names_filter=hook_name)
resid_BLD: Float[torch.Tensor, "batch pos d_model"] = cache[hook_name]

sae_act_BLF: Float[torch.Tensor, "batch pos d_sae"] = sae.encode(resid_BLD)
# make act to zero or one
sae_act_BLF = (sae_act_BLF > 0).to(dtype=torch.float32)
sae_act_BLF: Float[torch.Tensor, "batch pos d_sae"] = sae.encode(resid_BLD)
# make act to zero or one
sae_act_BLF = (sae_act_BLF > 0).to(dtype=torch.float32)

if mask_bos_pad_eos_tokens:
attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, model.tokenizer)
else:
attn_mask_BL = torch.ones_like(tokens_BL, dtype=torch.bool)
if mask_bos_pad_eos_tokens:
attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, model.tokenizer)
else:
attn_mask_BL = torch.ones_like(tokens_BL, dtype=torch.bool)

sae_act_BLF = sae_act_BLF * attn_mask_BL[:, :, None]
total_tokens += attn_mask_BL.sum().item()
sae_act_BLF = sae_act_BLF * attn_mask_BL[:, :, None]
total_tokens += attn_mask_BL.sum().item()

sae_act_F = einops.reduce(sae_act_BLF, "B L F -> F", "sum")
sae_act_F = einops.reduce(sae_act_BLF, "B L F -> F", "sum")

sae_acts.append(sae_act_F)
sae_acts.append(sae_act_F)

total_sae_acts_F = torch.stack(sae_acts).sum(dim=0)
return total_sae_acts_F / total_tokens
Expand Down

0 comments on commit 528959f

Please sign in to comment.