Skip to content

Commit c861500

Browse files
committed
Specify where preds are stored
1 parent bcd80c1 commit c861500

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

neural_lam/models/ar_model.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def __init__(self, args):
3333
self.initial_lr = args.initial_lr
3434
self.warmup_epochs = args.warmup_epochs
3535
self.store_pred = args.store_pred
36+
self.dataset = args.dataset
37+
self.run_id = args.run_id
3638

3739
# Load static features for grid/data
3840
static_data_dict = utils.load_static_data(args.dataset)
@@ -441,7 +443,13 @@ def store_predictions(self, batch_idx, prediction):
441443
# Rescale to original data scale
442444
prediction_rescaled = prediction * self.data_std + self.data_mean
443445

444-
pred_dir = os.path.join(wandb.run.dir, "predictions")
446+
if self.run_id is None:
447+
pred_dir = os.path.join(wandb.run.dir, "predictions")
448+
else:
449+
pred_dir = os.path.join(
450+
"data", self.dataset, "predictions", self.run_id
451+
)
452+
445453
os.makedirs(pred_dir, exist_ok=True)
446454

447455
# Save pred as .npy files

0 commit comments

Comments
 (0)