Use sklearn by default, except for training on all SAE latents
adamkarvonen committed Oct 1, 2024
1 parent 4cd9cff commit 2b1e2b6
Showing 2 changed files with 69 additions and 36 deletions.
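
A note on one of the supporting changes below: `train_sklearn_probe` and `test_sklearn_probe` now cast their inputs to float32 before converting to numpy. A plausible reason (an assumption, not stated in the commit) is that LLM activations are often stored in bfloat16, which numpy cannot represent. A minimal illustration:

```python
import torch

acts = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in activations in an LLM dtype
# acts.cpu().numpy()                            # would fail: numpy has no bfloat16 dtype
acts_np = acts.to(dtype=torch.float32).cpu().numpy()  # the cast the diff adds before sklearn
print(acts_np.dtype)  # float32
```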
sparse_probing/src/main.py (7 changes: 4 additions & 3 deletions)
@@ -120,10 +120,11 @@ def run_eval(
all_test_acts_BLD, sae, config.sae_batch_size, llm_dtype
)

sae_probes, sae_test_accuracies = probe_training.train_probe_on_activations(
_, sae_test_accuracies = probe_training.train_probe_on_activations(
all_sae_train_acts_BF,
all_sae_test_acts_BF,
select_top_k=None,
use_sklearn=False,
)

results_dict["custom_eval_results"][sae_name] = {}
@@ -166,7 +167,7 @@ def run_eval(
random.seed(config.random_seed)
torch.manual_seed(config.random_seed)

# populate selected_saes_dict
# populate selected_saes_dict using config values
for release in config.sae_releases:
if "gemma-scope" in release:
config.selected_saes_dict[release] = (
@@ -185,7 +186,7 @@
# run the evaluation on all selected SAEs
results_dict = run_eval(config, device)

# create output filename
# create output filename and save results
checkpoints_str = ""
if config.include_checkpoints:
checkpoints_str = "_with_checkpoints"
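The `run_eval` call above probes on every SAE latent (`select_top_k=None`), so the probe input width is the SAE dictionary size (the `F` in `all_sae_train_acts_BF`) rather than `d_model`; per the docstring in `probe_training.py`, training at that width was very slow on CPU, which is why this call site passes `use_sklearn=False`. A self-contained sketch of the shapes involved, with a mock ReLU encoder standing in for the real SAE (all sizes illustrative):

```python
import torch

d_model, d_sae = 768, 16384                            # illustrative model and SAE widths
W_enc = torch.randn(d_model, d_sae) / d_model ** 0.5   # mock encoder, not the repo's SAE
acts_BD = torch.randn(512, d_model)                    # stand-in residual-stream activations
sae_acts_BF = torch.relu(acts_BD @ W_enc)              # shape (512, 16384): one column per SAE latent
print(sae_acts_BF.shape)
# With select_top_k=None the probe sees all 16384 columns, hence the GPU path.
```
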
sparse_probing/src/probe_training.py (98 changes: 65 additions & 33 deletions)
@@ -5,6 +5,7 @@
from beartype import beartype
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import copy

import dataset_info

@@ -78,7 +79,7 @@ def get_top_k_mean_diff_mask(


@jaxtyped(typechecker=beartype)
def apply_topk_mask_gpu(
def apply_topk_mask_zero_dims(
acts_BD: Float[torch.Tensor, "batch_size d_model"],
mask_D: Bool[torch.Tensor, "d_model"],
) -> Float[torch.Tensor, "batch_size k"]:
@@ -89,7 +90,7 @@ def apply_topk_mask_gpu(


@jaxtyped(typechecker=beartype)
def apply_topk_mask_sklearn(
def apply_topk_mask_reduce_dim(
acts_BD: Float[torch.Tensor, "batch_size d_model"],
mask_D: Bool[torch.Tensor, "d_model"],
) -> Float[torch.Tensor, "batch_size k"]:
@@ -111,6 +112,9 @@ def train_sklearn_probe(
verbose: bool = False,
l1_ratio: Optional[float] = None,
) -> tuple[LogisticRegression, float]:
train_inputs = train_inputs.to(dtype=torch.float32)
test_inputs = test_inputs.to(dtype=torch.float32)

# Convert torch tensors to numpy arrays
train_inputs_np = train_inputs.cpu().numpy()
train_labels_np = train_labels.cpu().numpy()
@@ -153,6 +157,7 @@ def test_sklearn_probe(
labels: Int[torch.Tensor, "dataset_size"],
probe: LogisticRegression,
) -> float:
inputs = inputs.to(dtype=torch.float32)
inputs_np = inputs.cpu().numpy()
labels_np = labels.cpu().numpy()
predictions = probe.predict(inputs_np)
@@ -206,20 +211,31 @@ def train_probe_gpu(
dim: int,
batch_size: int,
epochs: int,
device: str,
model_dtype: torch.dtype,
lr: float,
verbose: bool = False,
l1_penalty: Optional[float] = None,
early_stopping_patience: int = 10,
) -> tuple[Probe, float]:
"""We have a GPU training function for training on all SAE features, which was very slow (1 minute+) on CPU."""
device = train_inputs.device
model_dtype = train_inputs.dtype

print(f"Training probe with dim: {dim}, device: {device}, dtype: {model_dtype}")

probe = Probe(dim, model_dtype).to(device)
optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
criterion = nn.BCEWithLogitsLoss()

best_test_accuracy = 0.0
best_probe = None
patience_counter = 0
for epoch in range(epochs):
indices = torch.randperm(len(train_inputs))

for i in range(0, len(train_inputs), batch_size):
acts_BD = train_inputs[i : i + batch_size]
labels_B = train_labels[i : i + batch_size]
batch_indices = indices[i : i + batch_size]
acts_BD = train_inputs[batch_indices]
labels_B = train_labels[batch_indices]
logits_B = probe(acts_BD)
loss = criterion(
logits_B, labels_B.clone().detach().to(device=device, dtype=model_dtype)
@@ -236,20 +252,35 @@
train_accuracy = test_probe_gpu(train_inputs, train_labels, batch_size, probe)
test_accuracy = test_probe_gpu(test_inputs, test_labels, batch_size, probe)

if epoch == epochs - 1 and verbose:
if test_accuracy > best_test_accuracy:
best_test_accuracy = test_accuracy
best_probe = copy.deepcopy(probe)
patience_counter = 0
else:
patience_counter += 1

if verbose:
print(
f"\nEpoch {epoch + 1}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}\n"
f"Epoch {epoch + 1}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}"
)

return probe, test_accuracy
if patience_counter >= early_stopping_patience:
print(f"GPU probe training early stopping triggered after {epoch + 1} epochs")
break

return best_probe, best_test_accuracy


@jaxtyped(typechecker=beartype)
def train_probe_on_activations(
train_activations: dict[str, Float[torch.Tensor, "train_dataset_size d_model"]],
test_activations: dict[str, Float[torch.Tensor, "test_dataset_size d_model"]],
select_top_k: Optional[int] = None,
) -> tuple[dict[str, LogisticRegression], dict[str, float]]:
use_sklearn: bool = True,
) -> tuple[dict[str, Optional[LogisticRegression]], dict[str, float]]:
"""Train a probe on the given activations and return the probe and test accuracies for each profession.
use_sklearn is a flag to use sklearn's LogisticRegression model instead of a custom PyTorch model.
We use sklearn by default. probe training on GPU is only for training a probe on all SAE features."""
torch.set_grad_enabled(True)

probes, test_accuracies = {}, {}
@@ -261,34 +292,35 @@ def train_probe_on_activations(

if select_top_k is not None:
activation_mask_D = get_top_k_mean_diff_mask(train_acts, train_labels, select_top_k)
train_acts = apply_topk_mask_sklearn(train_acts, activation_mask_D)
test_acts = apply_topk_mask_sklearn(test_acts, activation_mask_D)
train_acts = apply_topk_mask_reduce_dim(train_acts, activation_mask_D)
test_acts = apply_topk_mask_reduce_dim(test_acts, activation_mask_D)

activation_dim = train_acts.shape[1]

print(f"Num non-zero elements: {activation_dim}")

probe, test_accuracy = train_sklearn_probe(
train_acts,
train_labels,
test_acts,
test_labels,
verbose=False,
)

# probe, test_accuracy = train_probe_gpu(
# train_acts,
# train_labels,
# test_acts,
# test_labels,
# dim=activation_dim,
# batch_size=probe_batch_size,
# epochs=epochs,
# device=device,
# model_dtype=model_dtype,
# lr=lr,
# verbose=False,
# )
if use_sklearn:
probe, test_accuracy = train_sklearn_probe(
train_acts,
train_labels,
test_acts,
test_labels,
verbose=False,
)
else:
probe, test_accuracy = train_probe_gpu(
train_acts,
train_labels,
test_acts,
test_labels,
dim=activation_dim,
batch_size=250,
epochs=100,
lr=1e-2,
verbose=False,
early_stopping_patience=10,
)
probe = None

print(f"Test accuracy for {profession}: {test_accuracy}")

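The reworked `train_probe_gpu` above infers device and dtype from its inputs, reshuffles indices each epoch with `torch.randperm`, keeps a `copy.deepcopy` of the best-scoring probe, and stops after `early_stopping_patience` epochs without improvement. A minimal self-contained sketch of that training pattern, with a plain `nn.Linear` and synthetic data standing in for the repo's `Probe` class and real activations (hyperparameters mirror the defaults passed in `train_probe_on_activations`):

```python
import copy
import torch
import torch.nn as nn

def train_with_early_stopping(train_x, train_y, test_x, test_y,
                              epochs=100, batch_size=250, lr=1e-2, patience=10):
    device, dtype = train_x.device, train_x.dtype  # inferred from inputs, as in train_probe_gpu
    probe = nn.Linear(train_x.shape[1], 1, device=device, dtype=dtype)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    best_acc, best_probe, patience_counter = 0.0, None, 0
    for epoch in range(epochs):
        indices = torch.randperm(len(train_x))  # reshuffle every epoch
        for i in range(0, len(train_x), batch_size):
            batch = indices[i : i + batch_size]
            logits = probe(train_x[batch]).squeeze(-1)
            loss = criterion(logits, train_y[batch].to(dtype))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():  # evaluate on the held-out split
            acc = ((probe(test_x).squeeze(-1) > 0) == test_y.bool()).float().mean().item()
        if acc > best_acc:  # checkpoint the best probe seen so far
            best_acc, best_probe, patience_counter = acc, copy.deepcopy(probe), 0
        else:
            patience_counter += 1
        if patience_counter >= patience:  # early stopping
            break

    return best_probe, best_acc

# Illustrative synthetic run on a linearly separable task:
x = torch.randn(1000, 32)
y = (x @ torch.randn(32) > 0).float()
probe, acc = train_with_early_stopping(x[:800], y[:800], x[800:], y[800:])
print(f"best test accuracy: {acc:.3f}")
```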
