Define llm dtype in activation_collection.py
adamkarvonen committed Oct 1, 2024
1 parent 9b49ec4 commit 0f29194
Showing 3 changed files with 20 additions and 15 deletions.

sparse_probing/src/activation_collection.py: 5 additions & 0 deletions

@@ -13,6 +13,11 @@
"gemma-2-2b": 32,
}

LLM_NAME_TO_DTYPE = {
"pythia-70m-deduped": torch.float32,
"gemma-2-2b": torch.bfloat16,
}


@jaxtyped(typechecker=beartype)
@torch.no_grad

sparse_probing/src/eval_config.py: 9 additions & 10 deletions

@@ -5,14 +5,14 @@
 # I wanted to avoid the pip install -e . in the shared directory, but maybe that's the best way to do it
 import os
 import sys
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 from formatting_utils import filter_sae_names
 
 
 @dataclass
 class EvalConfig:
     random_seed: int = 42
-    model_dtype: torch.dtype = None  # Set in __post_init__
 
     dataset_name: str = "bias_in_bios"
     chosen_classes: list[str] = field(default_factory=lambda: ["0", "1", "2", "6", "9"])
@@ -35,7 +35,7 @@ class EvalConfig:
     # include_checkpoints: bool = False
 
     ## Uncomment to run Gemma SAEs
 
     sae_release: str = "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824"
     model_name: str = "gemma-2-2b"
     layers: list[int] = field(default_factory=lambda: [19])
@@ -47,13 +47,12 @@ class EvalConfig:
     saes: list[str] = field(default_factory=list)
 
     def __post_init__(self):
-        if self.model_name == "pythia-70m-deduped":
-            self.model_dtype = torch.float32
-        elif self.model_name == "gemma-2-2b":
-            self.model_dtype = torch.bfloat16
-        else:
-            raise ValueError(f"Unknown model name: {self.model_name}")
-
-        self.saes = filter_sae_names(self.sae_release, self.layers, self.trainer_ids, self.include_checkpoints, drop_sae_bench_prefix=True)
+        self.saes = filter_sae_names(
+            self.sae_release,
+            self.layers,
+            self.trainer_ids,
+            self.include_checkpoints,
+            drop_sae_bench_prefix=True,
+        )
 
         print("SAEs: ", self.saes)

sparse_probing/src/main.py: 6 additions & 5 deletions

@@ -37,10 +37,12 @@ def run_eval(
     results_dict = {}
     results_dict["custom_eval_results"] = {}
 
+    llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
+    llm_dtype = activation_collection.LLM_NAME_TO_DTYPE[config.model_name]
+
     model = HookedTransformer.from_pretrained_no_processing(
-        config.model_name, device=device, dtype=config.model_dtype
+        config.model_name, device=device, dtype=llm_dtype
     )
-    llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[config.model_name]
 
     train_df, test_df = dataset_creation.load_huggingface_dataset(config.dataset_name)
     train_data, test_data = dataset_creation.get_multi_label_train_test_data(
@@ -108,10 +110,10 @@
         sae = sae.to(device=device)
 
         all_sae_train_acts_BF = activation_collection.get_sae_meaned_activations(
-            all_train_acts_BLD, sae, config.sae_batch_size, config.model_dtype
+            all_train_acts_BLD, sae, config.sae_batch_size, llm_dtype
         )
         all_sae_test_acts_BF = activation_collection.get_sae_meaned_activations(
-            all_test_acts_BLD, sae, config.sae_batch_size, config.model_dtype
+            all_test_acts_BLD, sae, config.sae_batch_size, llm_dtype
        )
 
         sae_probes, sae_test_accuracies = probe_training.train_probe_on_activations(
@@ -136,7 +138,6 @@
             average_test_accuracy(sae_top_k_test_accuracies)
         )
 
-    config.model_dtype = str(config.model_dtype)  # so it's json serializable
     results_dict["custom_eval_config"] = asdict(config)
     return results_dict

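For context, a minimal sketch of how the new lookup tables are consumed after this commit, condensing the main.py hunks above into one snippet. The transformer_lens import, the bare activation_collection import, and the hard-coded model name and device string are assumptions for illustration; main.py receives these from its own config and arguments.

from transformer_lens import HookedTransformer

import activation_collection

model_name = "gemma-2-2b"  # assumed; main.py reads config.model_name
device = "cuda"            # assumed; main.py passes its own device argument

# Per-model settings now come from shared lookup tables in activation_collection
# instead of being set on EvalConfig in __post_init__.
llm_batch_size = activation_collection.LLM_NAME_TO_BATCH_SIZE[model_name]  # 32
llm_dtype = activation_collection.LLM_NAME_TO_DTYPE[model_name]            # torch.bfloat16

# Load the model directly in the dtype chosen for it.
model = HookedTransformer.from_pretrained_no_processing(
    model_name, device=device, dtype=llm_dtype
)

Because the dtype no longer lives on EvalConfig, the config stays JSON-serializable as-is, which is why the config.model_dtype = str(config.model_dtype) line in main.py could be dropped.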