Make it easier to use get_llm_activations() with other evals
adamkarvonen committed Nov 10, 2024
1 parent 36fb3ba · commit 1ed9a29
Showing 3 changed files with 69 additions and 33 deletions.
evals/shift_and_tpp/main.py (4 changes: 2 additions & 2 deletions)
@@ -463,10 +463,10 @@ def get_dataset_activations(
     )

     all_train_acts_BLD = activation_collection.get_all_llm_activations(
-        train_data, model, llm_batch_size, layer, hook_point
+        train_data, model, llm_batch_size, layer, hook_point, mask_bos_pad_eos_tokens=True
     )
     all_test_acts_BLD = activation_collection.get_all_llm_activations(
-        test_data, model, llm_batch_size, layer, hook_point
+        test_data, model, llm_batch_size, layer, hook_point, mask_bos_pad_eos_tokens=True
     )

     return all_train_acts_BLD, all_test_acts_BLD
evals/sparse_probing/main.py (4 changes: 2 additions & 2 deletions)
@@ -69,10 +69,10 @@ def get_dataset_activations(
     )

     all_train_acts_BLD = activation_collection.get_all_llm_activations(
-        train_data, model, llm_batch_size, layer, hook_point
+        train_data, model, llm_batch_size, layer, hook_point, mask_bos_pad_eos_tokens=True
     )
     all_test_acts_BLD = activation_collection.get_all_llm_activations(
-        test_data, model, llm_batch_size, layer, hook_point
+        test_data, model, llm_batch_size, layer, hook_point, mask_bos_pad_eos_tokens=True
     )

     return all_train_acts_BLD, all_test_acts_BLD
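Both evals above now pass mask_bos_pad_eos_tokens=True, so activations at BOS, PAD, and EOS positions come back as all-zero rows. Below is a minimal sketch of how downstream code might ignore those zeroed rows when mean-pooling activations for a probe; the helper name and the pooling use case are illustrative assumptions, not code from this commit:

# Illustrative sketch (not part of this commit): masked positions are all-zero rows,
# so a consumer can detect and skip them without threading a separate mask tensor.
import torch


def mean_pool_ignoring_zeros(acts_BLD: torch.Tensor) -> torch.Tensor:
    """Mean-pool over sequence positions, skipping rows that were zeroed out by masking."""
    keep_BL = (acts_BLD != 0).any(dim=-1).to(acts_BLD.dtype)  # 1.0 where the position was kept
    summed_BD = (acts_BLD * keep_BL[:, :, None]).sum(dim=1)
    counts_B1 = keep_BL.sum(dim=1, keepdim=True).clamp(min=1)  # avoid division by zero
    return summed_BD / counts_B1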
sae_bench_utils/activation_collection.py (94 changes: 65 additions & 29 deletions)
@@ -2,10 +2,11 @@
 import torch.nn as nn
 from tqdm import tqdm
 from typing import Callable, Optional
-from jaxtyping import Int, Float, jaxtyped, BFloat16
+from jaxtyping import Bool, Int, Float, jaxtyped
 from beartype import beartype
 import einops
 from transformer_lens import HookedTransformer
+from transformers import AutoTokenizer
 from sae_lens import SAE

 # Relevant at ctx len 128
@@ -21,48 +22,83 @@
 }


+# @jaxtyped(typechecker=beartype)  # TODO: jaxtyped struggles with the tokenizer
+@torch.no_grad
+def get_bos_pad_eos_mask(
+    tokens: Int[torch.Tensor, "dataset_size seq_len"], tokenizer: AutoTokenizer
+) -> Bool[torch.Tensor, "dataset_size seq_len"]:
+    mask = (
+        (tokens == tokenizer.pad_token_id)
+        | (tokens == tokenizer.eos_token_id)
+        | (tokens == tokenizer.bos_token_id)
+    ).to(dtype=torch.bool)
+    return ~mask
+
+
 @jaxtyped(typechecker=beartype)
 @torch.no_grad
-def get_all_llm_activations(
-    tokenized_inputs_dict: dict[str, dict[str, Int[torch.Tensor, "dataset_size seq_len"]]],
+def get_llm_activations(
+    tokens: Int[torch.Tensor, "dataset_size seq_len"],
     model: HookedTransformer,
     batch_size: int,
     layer: int,
     hook_name: str,
-    remove_bos_token: bool = True,
-) -> dict[str, Float[torch.Tensor, "dataset_size seq_len d_model"]]:
-    """VERY IMPORTANT NOTE: We zero out masked token activations in this function. Later, we ignore zeroed activations."""
-    all_classes_acts_BLD = {}
+    mask_bos_pad_eos_tokens: bool = False,
+) -> Float[torch.Tensor, "dataset_size seq_len d_model"]:
+    """Collects activations for an LLM model from a given layer for a given set of tokens.
+    VERY IMPORTANT NOTE: If mask_bos_pad_eos_tokens is True, we zero out activations for BOS, PAD, and EOS tokens.
+    Later, we ignore zeroed activations."""

-    for class_name in tokenized_inputs_dict:
-        all_acts_BLD = []
-        tokenized_inputs = tokenized_inputs_dict[class_name]
+    all_acts_BLD = []

-        for i in tqdm(
-            range(0, len(tokenized_inputs["input_ids"]), batch_size),
-            desc=f"Collecting activations for class {class_name}",
-        ):
-            tokens_BL = tokenized_inputs["input_ids"][i : i + batch_size]
-            attention_mask_BL = tokenized_inputs["attention_mask"][i : i + batch_size]
+    for i in tqdm(
+        range(0, len(tokens), batch_size),
+        desc="Collecting activations",
+    ):
+        tokens_BL = tokens[i : i + batch_size]

-            acts_BLD = None
+        acts_BLD = None

-            def activation_hook(resid_BLD: torch.Tensor, hook):
-                nonlocal acts_BLD
-                acts_BLD = resid_BLD
+        def activation_hook(resid_BLD: torch.Tensor, hook):
+            nonlocal acts_BLD
+            acts_BLD = resid_BLD

-            model.run_with_hooks(
-                tokens_BL, stop_at_layer=layer + 1, fwd_hooks=[(hook_name, activation_hook)]
-            )
+        model.run_with_hooks(
+            tokens_BL, stop_at_layer=layer + 1, fwd_hooks=[(hook_name, activation_hook)]
+        )

-            acts_BLD = acts_BLD * attention_mask_BL[:, :, None]
-            if remove_bos_token:
-                acts_BLD = acts_BLD[:, 1:, :]
-            all_acts_BLD.append(acts_BLD)
+        if mask_bos_pad_eos_tokens:
+            attn_mask_BL = get_bos_pad_eos_mask(tokens_BL, model.tokenizer)
+            acts_BLD = acts_BLD * attn_mask_BL[:, :, None]

-        all_acts_BLD = torch.cat(all_acts_BLD, dim=0)
+        all_acts_BLD.append(acts_BLD)

-        all_classes_acts_BLD[class_name] = all_acts_BLD
+    return torch.cat(all_acts_BLD, dim=0)
+
+
+@jaxtyped(typechecker=beartype)
+@torch.no_grad
+def get_all_llm_activations(
+    tokenized_inputs_dict: dict[str, dict[str, Int[torch.Tensor, "dataset_size seq_len"]]],
+    model: HookedTransformer,
+    batch_size: int,
+    layer: int,
+    hook_name: str,
+    mask_bos_pad_eos_tokens: bool = False,
+) -> dict[str, Float[torch.Tensor, "dataset_size seq_len d_model"]]:
+    """If we have a dictionary of tokenized inputs for different classes, this function collects activations for all classes.
+    We assume that the tokenized inputs have both the input_ids and attention_mask keys.
+    VERY IMPORTANT NOTE: We zero out masked token activations in this function. Later, we ignore zeroed activations."""
+    all_classes_acts_BLD = {}
+
+    for class_name in tokenized_inputs_dict:
+        tokens = tokenized_inputs_dict[class_name]["input_ids"]
+
+        acts_BLD = get_llm_activations(
+            tokens, model, batch_size, layer, hook_name, mask_bos_pad_eos_tokens
+        )
+
+        all_classes_acts_BLD[class_name] = acts_BLD

     return all_classes_acts_BLD
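With get_llm_activations() factored out of the per-class loop, another eval can collect activations from any [dataset_size, seq_len] token tensor without building a tokenized_inputs_dict. A minimal sketch of that direct usage follows; the model name, layer, hook name, and import path are assumptions for the example, not taken from this commit:

# Illustrative sketch (not from this commit): call get_llm_activations() directly
# on an arbitrary pre-tokenized batch. Model, layer, and hook name are assumptions.
import torch
from transformer_lens import HookedTransformer

from sae_bench_utils import activation_collection

model = HookedTransformer.from_pretrained("pythia-70m-deduped")
layer = 4
hook_name = f"blocks.{layer}.hook_resid_post"

# Any [dataset_size, seq_len] integer tensor works; here we tokenize a few strings ourselves.
tokens = model.to_tokens(["The quick brown fox", "jumps over the lazy dog"])

acts_BLD = activation_collection.get_llm_activations(
    tokens, model, batch_size=2, layer=layer, hook_name=hook_name,
    mask_bos_pad_eos_tokens=True,  # zero out activations at BOS/PAD/EOS positions
)
print(acts_BLD.shape)  # [2, seq_len, d_model]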

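The new get_bos_pad_eos_mask() returns True at positions to keep and False at BOS, PAD, and EOS positions. A toy demonstration of that convention, assuming a Hugging Face tokenizer and a pad-token fallback that are not part of this commit:

# Illustrative sketch (not from this commit): inspect the keep-mask on a padded batch.
from transformers import AutoTokenizer

from sae_bench_utils import activation_collection

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common fallback so pad_token_id is defined

batch = tokenizer(["short", "a slightly longer example"], padding=True, return_tensors="pt")
keep_mask = activation_collection.get_bos_pad_eos_mask(batch["input_ids"], tokenizer)
print(keep_mask)  # False at BOS/PAD/EOS positions, True everywhere else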
