From 2b1e2b6ee45c226ebf9e57a0f38ea468731f710f Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 1 Oct 2024 14:08:26 -0500 Subject: [PATCH] Use sklearn by default, except for training on all SAE latents --- sparse_probing/src/main.py | 7 +- sparse_probing/src/probe_training.py | 98 ++++++++++++++++++---------- 2 files changed, 69 insertions(+), 36 deletions(-) diff --git a/sparse_probing/src/main.py b/sparse_probing/src/main.py index 7919c77..97e903b 100644 --- a/sparse_probing/src/main.py +++ b/sparse_probing/src/main.py @@ -120,10 +120,11 @@ def run_eval( all_test_acts_BLD, sae, config.sae_batch_size, llm_dtype ) - sae_probes, sae_test_accuracies = probe_training.train_probe_on_activations( + _, sae_test_accuracies = probe_training.train_probe_on_activations( all_sae_train_acts_BF, all_sae_test_acts_BF, select_top_k=None, + use_sklearn=False, ) results_dict["custom_eval_results"][sae_name] = {} @@ -166,7 +167,7 @@ def run_eval( random.seed(config.random_seed) torch.manual_seed(config.random_seed) - # populate selected_saes_dict + # populate selected_saes_dict using config values for release in config.sae_releases: if "gemma-scope" in release: config.selected_saes_dict[release] = ( @@ -185,7 +186,7 @@ def run_eval( # run the evaluation on all selected SAEs results_dict = run_eval(config, device) - # create output filename + # create output filename and save results checkpoints_str = "" if config.include_checkpoints: checkpoints_str = "_with_checkpoints" diff --git a/sparse_probing/src/probe_training.py b/sparse_probing/src/probe_training.py index 3aabf26..a69c39b 100644 --- a/sparse_probing/src/probe_training.py +++ b/sparse_probing/src/probe_training.py @@ -5,6 +5,7 @@ from beartype import beartype from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score +import copy import dataset_info @@ -78,7 +79,7 @@ def get_top_k_mean_diff_mask( @jaxtyped(typechecker=beartype) -def apply_topk_mask_gpu( +def apply_topk_mask_zero_dims( acts_BD: Float[torch.Tensor, "batch_size d_model"], mask_D: Bool[torch.Tensor, "d_model"], ) -> Float[torch.Tensor, "batch_size k"]: @@ -89,7 +90,7 @@ def apply_topk_mask_gpu( @jaxtyped(typechecker=beartype) -def apply_topk_mask_sklearn( +def apply_topk_mask_reduce_dim( acts_BD: Float[torch.Tensor, "batch_size d_model"], mask_D: Bool[torch.Tensor, "d_model"], ) -> Float[torch.Tensor, "batch_size k"]: @@ -111,6 +112,9 @@ def train_sklearn_probe( verbose: bool = False, l1_ratio: Optional[float] = None, ) -> tuple[LogisticRegression, float]: + train_inputs = train_inputs.to(dtype=torch.float32) + test_inputs = test_inputs.to(dtype=torch.float32) + # Convert torch tensors to numpy arrays train_inputs_np = train_inputs.cpu().numpy() train_labels_np = train_labels.cpu().numpy() @@ -153,6 +157,7 @@ def test_sklearn_probe( labels: Int[torch.Tensor, "dataset_size"], probe: LogisticRegression, ) -> float: + inputs = inputs.to(dtype=torch.float32) inputs_np = inputs.cpu().numpy() labels_np = labels.cpu().numpy() predictions = probe.predict(inputs_np) @@ -206,20 +211,31 @@ def train_probe_gpu( dim: int, batch_size: int, epochs: int, - device: str, - model_dtype: torch.dtype, lr: float, verbose: bool = False, l1_penalty: Optional[float] = None, + early_stopping_patience: int = 10, ) -> tuple[Probe, float]: + """We have a GPU training function for training on all SAE features, which was very slow (1 minute+) on CPU.""" + device = train_inputs.device + model_dtype = train_inputs.dtype + + print(f"Training probe with dim: {dim}, device: {device}, dtype: {model_dtype}") + probe = Probe(dim, model_dtype).to(device) optimizer = torch.optim.AdamW(probe.parameters(), lr=lr) criterion = nn.BCEWithLogitsLoss() + best_test_accuracy = 0.0 + best_probe = None + patience_counter = 0 for epoch in range(epochs): + indices = torch.randperm(len(train_inputs)) + for i in range(0, len(train_inputs), batch_size): - acts_BD = train_inputs[i : i + batch_size] - labels_B = train_labels[i : i + batch_size] + batch_indices = indices[i : i + batch_size] + acts_BD = train_inputs[batch_indices] + labels_B = train_labels[batch_indices] logits_B = probe(acts_BD) loss = criterion( logits_B, labels_B.clone().detach().to(device=device, dtype=model_dtype) @@ -236,12 +252,23 @@ def train_probe_gpu( train_accuracy = test_probe_gpu(train_inputs, train_labels, batch_size, probe) test_accuracy = test_probe_gpu(test_inputs, test_labels, batch_size, probe) - if epoch == epochs - 1 and verbose: + if test_accuracy > best_test_accuracy: + best_test_accuracy = test_accuracy + best_probe = copy.deepcopy(probe) + patience_counter = 0 + else: + patience_counter += 1 + + if verbose: print( - f"\nEpoch {epoch + 1}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}\n" + f"Epoch {epoch + 1}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}" ) - return probe, test_accuracy + if patience_counter >= early_stopping_patience: + print(f"GPU probe training early stopping triggered after {epoch + 1} epochs") + break + + return best_probe, best_test_accuracy @jaxtyped(typechecker=beartype) @@ -249,7 +276,11 @@ def train_probe_on_activations( train_activations: dict[str, Float[torch.Tensor, "train_dataset_size d_model"]], test_activations: dict[str, Float[torch.Tensor, "test_dataset_size d_model"]], select_top_k: Optional[int] = None, -) -> tuple[dict[str, LogisticRegression], dict[str, float]]: + use_sklearn: bool = True, +) -> tuple[dict[str, Optional[LogisticRegression]], dict[str, float]]: + """Train a probe on the given activations and return the probe and test accuracies for each profession. + use_sklearn is a flag to use sklearn's LogisticRegression model instead of a custom PyTorch model. + We use sklearn by default. probe training on GPU is only for training a probe on all SAE features.""" torch.set_grad_enabled(True) probes, test_accuracies = {}, {} @@ -261,34 +292,35 @@ def train_probe_on_activations( if select_top_k is not None: activation_mask_D = get_top_k_mean_diff_mask(train_acts, train_labels, select_top_k) - train_acts = apply_topk_mask_sklearn(train_acts, activation_mask_D) - test_acts = apply_topk_mask_sklearn(test_acts, activation_mask_D) + train_acts = apply_topk_mask_reduce_dim(train_acts, activation_mask_D) + test_acts = apply_topk_mask_reduce_dim(test_acts, activation_mask_D) activation_dim = train_acts.shape[1] print(f"Num non-zero elements: {activation_dim}") - probe, test_accuracy = train_sklearn_probe( - train_acts, - train_labels, - test_acts, - test_labels, - verbose=False, - ) - - # probe, test_accuracy = train_probe_gpu( - # train_acts, - # train_labels, - # test_acts, - # test_labels, - # dim=activation_dim, - # batch_size=probe_batch_size, - # epochs=epochs, - # device=device, - # model_dtype=model_dtype, - # lr=lr, - # verbose=False, - # ) + if use_sklearn: + probe, test_accuracy = train_sklearn_probe( + train_acts, + train_labels, + test_acts, + test_labels, + verbose=False, + ) + else: + probe, test_accuracy = train_probe_gpu( + train_acts, + train_labels, + test_acts, + test_labels, + dim=activation_dim, + batch_size=250, + epochs=100, + lr=1e-2, + verbose=False, + early_stopping_patience=10, + ) + probe = None print(f"Test accuracy for {profession}: {test_accuracy}")