
Commit 0a09700

Store predictions

1 parent 9bd5290 commit 0a09700

7 files changed: +140 -23 lines

README.md (+2 -1)

@@ -53,6 +53,7 @@ Mediterranean analysis
 ```
 python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/train -n 6 -p ana_data -s 2022-01-01 -e 2024-04-30
 python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/val -n 6 -p ana_data -s 2024-05-01 -e 2024-06-30
+python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/test -n 17 -p ana_data -s 2024-07-22 -e 2024-08-12 --forecast
 ```

 ERA5
@@ -63,7 +64,7 @@ python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/sa

 Forecast data
 ```
-python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/test -n 16 -p ana_data -s 2024-06-30 -e 2024-08-10 --forecast
+python prepare_states.py -d data/mediterranean/raw/forecast -o data/mediterranean/samples/test -p for_data -s 2024-07-24 -e 2024-08-01 --forecast
 python prepare_states.py -d data/mediterranean/raw/ens -o data/mediterranean/samples/test -p ens_forcing -s 2024-07-01 -e 2024-08-11 --forecast
 python prepare_states.py -d data/mediterranean/raw/aifs -o data/mediterranean/samples/test -p aifs_forcing -s 2024-06-01 -e 2024-08-11 --forecast
 ```

download_data.py (+3 -1)

@@ -261,6 +261,8 @@ def download_forecast(

     filename = f"{path_prefix}/{start_date.strftime('%Y%m%d')}.npy"

+    initial_date = start_date - timedelta(days=2)
+
     all_data = []
     for dataset_id, variables in datasets.items():
         # Load ocean physics dataset for all dates at once
@@ -270,7 +272,7 @@ def download_forecast(
             dataset_part="default",
             service="arco-geo-series",
             variables=variables,
-            start_datetime=start_date.strftime("%Y-%m-%dT00:00:00"),
+            start_datetime=initial_date.strftime("%Y-%m-%dT00:00:00"),
             end_datetime=end_date.strftime("%Y-%m-%dT00:00:00"),
             minimum_depth=constants.DEPTHS[0],
             maximum_depth=constants.DEPTHS[-1],
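
The two-day shift matters because the model steps forward from two conditioning states (predict_step consumes X_{t-1} and X_t, per ar_model.py below), so the download window has to start two days before the first forecast date. A minimal sketch of the offset, with an illustrative date:

```python
from datetime import datetime, timedelta

# If the first forecast date is 2024-07-24, requesting from 2024-07-22
# includes the two prior daily states used as initial conditions.
start_date = datetime(2024, 7, 24)
initial_date = start_date - timedelta(days=2)
print(initial_date.strftime("%Y-%m-%dT00:00:00"))  # 2024-07-22T00:00:00
```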

neural_lam/constants.py (+4 -2)

@@ -6,13 +6,15 @@

 # Log prediction error for these lead times
 VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 4])
-TEST_STEP_LOG_ERRORS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
+TEST_STEP_LOG_ERRORS = np.array(
+    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+)

 # Sample lengths
 SAMPLE_LEN = {
     "train": 6,
     "val": 6,
-    "test": 16,
+    "test": 17,
 }

 # Log these metrics to wandb as scalar values for
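
These two constants move in lockstep: with two initial states per sample (an assumption consistent with the prepare_states.py changes below), a 17-state test sample yields 15 forecast steps, which is exactly the range the extended TEST_STEP_LOG_ERRORS covers. A small consistency check:

```python
import numpy as np

TEST_STEP_LOG_ERRORS = np.arange(1, 16)  # lead times 1..15
SAMPLE_LEN_TEST = 17                     # states per test sample
N_INIT_STATES = 2                        # assumed number of conditioning states

# The longest logged lead time must fit inside a sample
assert SAMPLE_LEN_TEST - N_INIT_STATES == TEST_STEP_LOG_ERRORS[-1]
```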

neural_lam/models/ar_model.py (+44)

@@ -27,10 +27,12 @@ def __init__(self, args):
         self.save_hyperparameters()
         self.optimizer = args.optimizer
         self.lr = args.lr
+        self.batch_size = args.batch_size
         self.epochs = args.epochs
         self.scheduler = args.scheduler
         self.initial_lr = args.initial_lr
         self.warmup_epochs = args.warmup_epochs
+        self.store_pred = args.store_pred

         # Load static features for grid/data
         static_data_dict = utils.load_static_data(args.dataset)
@@ -99,6 +101,9 @@ def __init__(self, args):
         # For storing spatial loss maps during evaluation
         self.spatial_loss_maps = []

+        # For storing predictions under sample names
+        self.sample_names = []
+
     def configure_optimizers(self):
         if self.optimizer == "adamw":
             opt = torch.optim.AdamW(
@@ -166,6 +171,12 @@ def expand_to_batch(x, batch_size):
         """
         return x.unsqueeze(0).expand(batch_size, -1, -1)

+    def set_sample_names(self, dataset):
+        """
+        Set sample names for evaluation
+        """
+        self.sample_names = dataset.sample_names
+
     def predict_step(self, prev_state, prev_prev_state, forcing):
         """
         Step state one step ahead using prediction model, X_{t-1}, X_t -> X_{t+1}
@@ -393,6 +404,10 @@ def test_step(self, batch, batch_idx):
         self.spatial_loss_maps.append(log_spatial_losses)
         # (B, N_log, num_grid_nodes)

+        # Store predictions
+        if self.store_pred:
+            self.store_predictions(batch_idx, prediction)
+
         # Plot example predictions (on rank 0 only)
         if (
             self.trainer.is_global_zero
@@ -407,6 +422,35 @@ def test_step(self, batch, batch_idx):
             batch, n_additional_examples, prediction=prediction
         )

+    def store_predictions(self, batch_idx, prediction):
+        """
+        Store predictions for a batch
+
+        batch_idx: index of the batch in the dataloader
+        prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction.
+        """
+        sample_names = [
+            self.sample_names[idx]
+            for idx in range(
+                batch_idx * self.batch_size,
+                # Use prediction.shape[0] so a final partial batch
+                # does not index past the end of sample_names
+                batch_idx * self.batch_size + prediction.shape[0],
+            )
+        ]
+
+        # Rescale to original data scale
+        prediction_rescaled = prediction * self.data_std + self.data_mean
+
+        pred_dir = os.path.join(wandb.run.dir, "predictions")
+        os.makedirs(pred_dir, exist_ok=True)
+
+        # Save pred as .npy files
+        for i, sample_name in enumerate(sample_names):
+            np.save(
+                os.path.join(pred_dir, f"{sample_name}.npy"),
+                prediction_rescaled[i].cpu().numpy(),
+            )
+
     def plot_examples(self, batch, n_examples, prediction=None):
         """
         Plot the first n_examples forecasts from batch
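
Each test sample ends up as one .npy file, named after its sample name, inside the active wandb run directory. A sketch of reading a stored prediction back (the run directory and sample name are illustrative, not from this commit):

```python
import numpy as np

# Path layout follows store_predictions above:
# <wandb run dir>/predictions/<sample_name>.npy
pred = np.load(
    "wandb/run-20240801_120000-abc123/files/predictions/for_data_20240724.npy"
)
print(pred.shape)  # expected (pred_steps, num_grid_nodes, d_f), in data units
```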

neural_lam/weather_dataset.py (+5 -1)

@@ -48,7 +48,11 @@ def __init__(
             else (
                 "ana_data_*.npy"
                 if data_subset == "analysis"
-                else "*_data_*.npy"
+                else (
+                    "for_data_*.npy"
+                    if data_subset == "forecast" and split == "test"
+                    else "*_data_*.npy"
+                )
             )
         )
         sample_paths = glob.glob(
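
The nested conditionals amount to a lookup from (data_subset, split) to a glob pattern. A rough equivalent covering only the branches visible in this hunk (the outer branch above the shown context is elided in the diff and not reproduced here):

```python
def sample_pattern(data_subset, split):
    # Sketch of the visible branches only
    if data_subset == "analysis":
        return "ana_data_*.npy"
    if data_subset == "forecast" and split == "test":
        return "for_data_*.npy"
    return "*_data_*.npy"
```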

prepare_states.py (+74 -17)

@@ -46,8 +46,8 @@ def prepare_states(

     # Process each file, concatenate with the next t-1 files
     for i in range(len(files) - n_states + 1):
-        # Name as today's date
-        out_filename = f"{prefix}_{os.path.basename(files[i+1])}"
+        # Name as first forecasted date
+        out_filename = f"{prefix}_{os.path.basename(files[i + 2])}"
         out_file = os.path.join(out_directory, out_filename)

         if os.path.isfile(out_file):
@@ -132,27 +132,22 @@ def prepare_states_with_boundary(

     # Process each file, concatenate with the next t-1 files
     for i in range(len(files) - n_states + 1):
-        today = os.path.basename(files[i + 1])
-        out_filename = f"{prefix}_{today}"
+        forecast_date = os.path.basename(files[i + 2])
+        out_filename = f"{prefix}_{forecast_date}"
         out_file = os.path.join(out_directory, out_filename)

-        if os.path.isfile(out_file):
-            continue
-
         # Stack analysis states
         state_sequence = [np.load(files[i + j]) for j in range(n_states)]
         full_state = np.stack(state_sequence, axis=0)
         print("full state", full_state.shape)  # (n_states, N_grid, d_features)

-        forecast_file = files[i + 1].replace("analysis", "forecast")
-        forecast_data = np.load(forecast_file)
-        forecast_len = forecast_data.shape[0]
+        forecast_file = files[i + 2].replace("analysis", "forecast")
+        forecast_data = np.load(forecast_file)[2:]
         print(
             "forecast before", forecast_data.shape
         )  # (forecast_len, N_grid, d_features)

-        assert n_states >= forecast_len, "n_states less than forecast length"
-        extra_states = n_states - 1 - forecast_data.shape[0]
+        extra_states = 5
         last_forecast_state = forecast_data[-1]
         repeated_forecast_states = np.repeat(
             last_forecast_state[np.newaxis, ...], extra_states, axis=0
@@ -162,21 +157,65 @@ def prepare_states_with_boundary(
         )
         print(
             "forecast after", forecast_data.shape
-        )  # (n_states - 1, N_grid, d_features)
+        )  # (n_states - 2, N_grid, d_features)

         # Concatenate preceding day analysis state with forecast data
         forecast_data = np.concatenate(
-            (state_sequence[:1], forecast_data), axis=0
+            (state_sequence[:2], forecast_data), axis=0
         )  # (n_states, N_grid, d_features)

         full_state = (
             full_state * (1 - border_mask) + forecast_data * border_mask
         )

-        np.save(out_file, full_state)
+        np.save(out_file, full_state.astype(np.float32))
         print(f"Saved states to: {out_file}")


+def prepare_forecast(in_directory, out_directory, prefix, start_date, end_date):
+    """
+    Prepare forecast data by repeating the last state.
+    """
+    forecast_dir = in_directory
+
+    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
+    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
+
+    os.makedirs(out_directory, exist_ok=True)
+
+    # Get files sorted by date
+    forecast_files = sorted(
+        glob(os.path.join(forecast_dir, "*.npy")),
+        key=lambda x: datetime.strptime(os.path.basename(x)[:8], "%Y%m%d"),
+    )
+    forecast_files = [
+        f
+        for f in forecast_files
+        if start_dt
+        <= datetime.strptime(os.path.basename(f)[:8], "%Y%m%d")
+        <= end_dt
+    ]
+
+    for forecast_file in forecast_files:
+        # Load the current forecast data
+        forecast_data = np.load(forecast_file)
+        print(forecast_data.shape)
+
+        last_forecast_state = forecast_data[-1]
+        repeated_forecast_states = np.repeat(
+            last_forecast_state[np.newaxis, ...], repeats=5, axis=0
+        )
+        forecast_data = np.concatenate(
+            [forecast_data, repeated_forecast_states], axis=0
+        )
+
+        # Save concatenated data
+        out_filename = f"{prefix}_{os.path.basename(forecast_file)}"
+        out_file = os.path.join(out_directory, out_filename)
+        np.save(out_file, forecast_data)
+        print(f"Saved forecast to: {out_file}")
+
+
 def prepare_forcing(in_directory, out_directory, prefix, start_date, end_date):
     """
     Prepare atmospheric forcing data from forecasts.
@@ -205,6 +244,14 @@ def prepare_forcing(in_directory, out_directory, prefix, start_date, end_date):
         forecast_date = datetime.strptime(
             os.path.basename(forecast_file)[:8], "%Y%m%d"
         )
+
+        # Get files for the pre-preceding day
+        prepreceding_day_file = os.path.join(
+            forecast_dir,
+            (forecast_date - timedelta(days=2)).strftime("%Y%m%d") + ".npy",
+        )
+        prepreceding_day_data = np.load(prepreceding_day_file)[0:1]
+
         # Get files for the preceding day
         preceding_day_file = os.path.join(
             forecast_dir,
@@ -217,12 +264,14 @@ def prepare_forcing(in_directory, out_directory, prefix, start_date, end_date):

         print(preceding_day_data.shape, current_forecast_data.shape)

+        prepreceding_day_data = prepreceding_day_data[:, :, :4]
         preceding_day_data = preceding_day_data[:, :, :4]
         current_forecast_data = current_forecast_data[:, :, :4]

         # Concatenate all data along the time axis
         concatenated_forcing = np.concatenate(
-            [preceding_day_data, current_forecast_data], axis=0
+            [prepreceding_day_data, preceding_day_data, current_forecast_data],
+            axis=0,
         )

         # Save concatenated data
@@ -303,7 +352,7 @@ def main():
                 args.start_date,
                 args.end_date,
             )
-        else:
+        elif args.data_dir.endswith("analysis"):
             prepare_states_with_boundary(
                 args.data_dir,
                 args.static_dir,
@@ -313,6 +362,14 @@ def main():
                 args.start_date,
                 args.end_date,
             )
+        else:
+            prepare_forecast(
+                args.data_dir,
+                args.out_dir,
+                args.prefix,
+                args.start_date,
+                args.end_date,
+            )
     else:
         prepare_states(
             args.data_dir,
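
The padding in prepare_forecast lines up with SAMPLE_LEN["test"] = 17: repeating the final state five times extends a raw forecast to the expected sample length. A shape sketch with made-up dimensions (the raw forecast length of 12 is an assumption, not shown in this commit):

```python
import numpy as np

# Hypothetical raw forecast: 12 daily states, 100 grid nodes, 8 features
forecast = np.zeros((12, 100, 8))
padded = np.concatenate(
    [forecast, np.repeat(forecast[-1][np.newaxis], 5, axis=0)], axis=0
)
print(padded.shape)  # (17, 100, 8), matching SAMPLE_LEN["test"] = 17
```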

train_model.py (+8 -1)

@@ -185,7 +185,7 @@ def main():
     parser.add_argument(
         "--data_subset",
         type=str,
-        choices=["analysis", "reanalysis"],
+        choices=["analysis", "reanalysis", "forecast"],
         default=None,
         help="Type of data to use: 'analysis' or 'reanalysis' (default: None)",
     )
@@ -259,6 +259,12 @@ def main():
         help="Number of example predictions to plot during evaluation "
         "(default: 1)",
     )
+    parser.add_argument(
+        "--store_pred",
+        type=int,
+        default=0,
+        help="Whether or not to store predictions (default: 0 (no))",
+    )
     args = parser.parse_args()

     # Asserts for arguments
@@ -391,6 +397,7 @@ def main():
         shuffle=False,
         num_workers=args.n_workers,
     )
+    model.set_sample_names(eval_loader.dataset)

     print(f"Running evaluation on {args.eval}")
     trainer.test(model=model, dataloaders=eval_loader)
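
Taken together, an evaluation run that stores rescaled predictions might be invoked as follows (the --dataset and --eval flags are assumed from args.dataset and args.eval in the surrounding code; only --data_subset and --store_pred are confirmed by this diff):

```
python train_model.py --dataset mediterranean --eval test --data_subset forecast --store_pred 1
```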
