Skip to content

Commit

Permalink
correctly save and load state dict
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Nov 18, 2024
1 parent 2317de9 commit 35d64c8
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions baselines/pca_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def forward(self, acts):
def save_state_dict(self, file_path: str):
"""Save the encoder and decoder to a file."""
torch.save(
- {"W_enc": self.W_enc.data, "W_dec": self.W_dec.data, "mean": self.mean}, file_path
+ {"W_enc": self.W_enc.data, "W_dec": self.W_dec.data, "mean": self.mean.data}, file_path
)

def load_from_file(self, file_path: str):
"""Load the encoder and decoder from a file."""
state_dict = torch.load(file_path, map_location=self.device)
self.W_enc.data = state_dict["W_enc"]
self.W_dec.data = state_dict["W_dec"]
- self.mean = state_dict["mean"]
+ self.mean.data = state_dict["mean"]

# required as we have device and dtype class attributes
def to(self, *args, **kwargs):
Expand Down Expand Up @@ -269,6 +269,8 @@ def fit_PCA_gpu(
# pca = fit_PCA(pca, model, tokens_BL, llm_batch_size, pca_batch_size)
pca = fit_PCA_gpu(pca, model, tokens_BL, llm_batch_size, pca_batch_size)

+ pca.load_from_file(f"pca_{model_name}_blocks.{layer}.hook_resid_post.pt")

pca.to(device=device)

test_input = torch.randn(1, 128, d_model, device=device, dtype=torch.float32)
Expand Down

0 comments on commit 35d64c8

Please sign in to comment.