Lower peak memory usage to fit on a 3090
adamkarvonen committed Oct 22, 2024
1 parent 46d9510 commit 1fecf15
Showing 2 changed files with 20 additions and 19 deletions.
evals/shift_and_tpp/eval_config.py (2 changes: 1 addition & 1 deletion)
@@ -31,7 +31,7 @@ class EvalConfig:
     probe_epochs: int = 20
     probe_lr: float = 1e-3
 
-    sae_batch_size: int = 250
+    sae_batch_size: int = 125
 
     # This is for spurrious correlation removal
     chosen_class_indices = [
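For scale: activation memory grows linearly with `sae_batch_size`, so halving it from 250 to 125 roughly halves the eval's largest per-batch tensors. A rough back-of-envelope sketch, where `seq_len`, `d_model`, and float32 storage are illustrative assumptions rather than values from this repo:

```python
# Hypothetical estimate of how sae_batch_size scales activation memory.
# seq_len, d_model, and float32 storage are illustrative assumptions only.
def activation_gib(batch_size: int, seq_len: int = 128, d_model: int = 4096) -> float:
    bytes_per_float32 = 4
    return batch_size * seq_len * d_model * bytes_per_float32 / 2**30

for bs in (250, 125):
    print(f"sae_batch_size={bs}: ~{activation_gib(bs):.2f} GiB per activation tensor")
```

The absolute numbers depend on the real model and context length, but the linear scaling is what brings the run under the 24 GB of an RTX 3090.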
evals/sparse_probing/probe_training.py (37 changes: 19 additions & 18 deletions)
@@ -7,6 +7,7 @@
 from jaxtyping import Bool, Float, Int, jaxtyped
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import accuracy_score
+import math
 
 import sae_bench_utils.dataset_info as dataset_info

@@ -40,24 +41,30 @@ def prepare_probe_data(

     if spurious_corr:
         if class_name in dataset_info.PAIRED_CLASS_KEYS.keys():
-            negative_acts = all_activations[dataset_info.PAIRED_CLASS_KEYS[class_name]]
+            selected_negative_acts_BD = all_activations[dataset_info.PAIRED_CLASS_KEYS[class_name]]
         elif class_name in dataset_info.PAIRED_CLASS_KEYS.values():
             reversed_dict = {v: k for k, v in dataset_info.PAIRED_CLASS_KEYS.items()}
-            negative_acts = all_activations[reversed_dict[class_name]]
+            selected_negative_acts_BD = all_activations[reversed_dict[class_name]]
         else:
             raise ValueError(f"Class {class_name} not found in paired class keys.")
     else:
         # Collect all negative class activations and labels
-        negative_acts = []
-        for idx, acts in all_activations.items():
-            if idx != class_name:
-                negative_acts.append(acts)
+        selected_negative_acts_BD = []
+        negative_keys = [k for k in all_activations.keys() if k != class_name]
+        num_neg_classes = len(negative_keys)
+        samples_per_class = math.ceil(num_positive / num_neg_classes)
 
-        negative_acts = torch.cat(negative_acts)
+        for negative_class_name in negative_keys:
+            sample_indices = torch.randperm(len(all_activations[negative_class_name]))[
+                :samples_per_class
+            ]
+            selected_negative_acts_BD.append(all_activations[negative_class_name][sample_indices])
+
+        selected_negative_acts_BD = torch.cat(selected_negative_acts_BD)
 
     # Randomly select num_positive samples from negative class
-    indices = torch.randperm(len(negative_acts))[:num_positive]
-    selected_negative_acts_BD = negative_acts[indices]
+    indices = torch.randperm(len(selected_negative_acts_BD))[:num_positive]
+    selected_negative_acts_BD = selected_negative_acts_BD[indices]
 
     assert selected_negative_acts_BD.shape == positive_acts_BD.shape

@@ -311,17 +318,11 @@ def train_probe_on_activations(
     probes, test_accuracies = {}, {}
 
     for profession in train_activations.keys():
-        train_acts, train_labels = prepare_probe_data(
-            train_activations, profession, spurious_corr
-        )
-        test_acts, test_labels = prepare_probe_data(
-            test_activations, profession, spurious_corr
-        )
+        train_acts, train_labels = prepare_probe_data(train_activations, profession, spurious_corr)
+        test_acts, test_labels = prepare_probe_data(test_activations, profession, spurious_corr)
 
         if select_top_k is not None:
-            activation_mask_D = get_top_k_mean_diff_mask(
-                train_acts, train_labels, select_top_k
-            )
+            activation_mask_D = get_top_k_mean_diff_mask(train_acts, train_labels, select_top_k)
             train_acts = apply_topk_mask_reduce_dim(train_acts, activation_mask_D)
             test_acts = apply_topk_mask_reduce_dim(test_acts, activation_mask_D)

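The substantive change in `prepare_probe_data` is where subsampling happens: the old branch concatenated every negative class's activations before sampling `num_positive` rows, while the new branch samples `ceil(num_positive / num_neg_classes)` rows per class and concatenates only those. A minimal standalone sketch of the two strategies, using made-up class counts and dimensions:

```python
import math

import torch

# Made-up sizes for illustration: 10 classes of 10,000 activations x 512 dims.
all_activations = {f"class_{i}": torch.randn(10_000, 512) for i in range(10)}
class_name, num_positive = "class_0", 10_000

# Old strategy: concatenate every negative class, then subsample. The
# temporary concatenated tensor holds ~90,000 x 512 floats.
negative_acts = torch.cat([acts for k, acts in all_activations.items() if k != class_name])
old_selected = negative_acts[torch.randperm(len(negative_acts))[:num_positive]]

# New strategy: subsample each class first, then concatenate. The temporary
# tensor holds only ~10,000 x 512 floats, roughly 9x smaller here.
negative_keys = [k for k in all_activations if k != class_name]
samples_per_class = math.ceil(num_positive / len(negative_keys))
per_class = [
    all_activations[k][torch.randperm(len(all_activations[k]))[:samples_per_class]]
    for k in negative_keys
]
new_selected = torch.cat(per_class)
new_selected = new_selected[torch.randperm(len(new_selected))[:num_positive]]

assert old_selected.shape == new_selected.shape
```

Because `math.ceil` rounds up, the per-class route can collect slightly more than `num_positive` rows, which is why the final `randperm(...)[:num_positive]` trim is kept after the branch.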
