Commit 74c4e68

Prepare two types of atmospheric forcing
1 parent 6e56ab5 commit 74c4e68

File tree

README.md
neural_lam/weather_dataset.py
prepare_states.py
train_model.py

4 files changed, +135 -41 lines changed

README.md (+4 -4)
````diff
@@ -55,7 +55,7 @@ Mediterranean analysis
 ```
 python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/train -n 6 -p ana_data -s 2021-11-01 -e 2024-03-31
 python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/val -n 6 -p ana_data -s 2024-04-01 -e 2024-05-31
-python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/test -n 11 -p ana_data -s 2024-06-01 -e 2024-07-31
+python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/test -n 16 -p ana_data -s 2024-06-30 -e 2024-08-10
 ```

 ERA5
@@ -68,8 +68,8 @@ python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/sa

 ECMWF
 ```
-python prepare_states.py -d data/mediterranean/raw/ens -o data/mediterranean/samples/test -p ens_forcing -s 2024-06-01 -e 2024-07-31 --forecast_forcing
-python prepare_states.py -d data/mediterranean/raw/aifs -o data/mediterranean/samples/test -p aifs_forcing -s 2024-06-01 -e 2024-07-31 --forecast_forcing
+python prepare_states.py -d data/mediterranean/raw/ens -o data/mediterranean/samples/test -p ens_forcing -s 2024-07-01 -e 2024-08-11 --forecast_forcing
+python prepare_states.py -d data/mediterranean/raw/aifs -o data/mediterranean/samples/test -p aifs_forcing -s 2024-06-01 -e 2024-08-11 --forecast_forcing
 ```

 ### Create static features
@@ -135,7 +135,7 @@ For a full list of possible training options, check `python train_model.py --hel
 SeaCast was evaluated on 1 GPU using `--eval test`, and by choosing the correct data subset + loading the appropriate model:
 ```
 python train_model.py \
-    --data_subset reanalysis \
+    --data_subset analysis \
     --n_workers 4 \
     --batch_size 1 \
     --step_length 1 \
````
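The updated `-n 16` pairs with the sliding-window logic in `prepare_states.py` below (`range(len(files) - n_states + 1)`), which makes the resulting sample count easy to sanity-check. A minimal sketch, assuming one analysis file per day across the new test window:

```python
from datetime import date

# Sanity check, assuming one daily file in the updated test window
# (2024-06-30 .. 2024-08-10) and -n 16.
start, end = date(2024, 6, 30), date(2024, 8, 10)
n_states = 16

n_files = (end - start).days + 1     # 42 daily files
n_samples = n_files - n_states + 1   # mirrors range(len(files) - n_states + 1)
print(n_samples)                     # 27 sliding 16-day windows
```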

neural_lam/weather_dataset.py (+4 -1)
```diff
@@ -33,6 +33,7 @@ def __init__(
         standardize=True,
         subset=False,
         data_subset=None,
+        forcing_prefix="forcing",
     ):
         super().__init__()

@@ -92,6 +93,8 @@ def __init__(
         # If subsample index should be sampled
         self.random_subsample = False

+        self.forcing_prefix = forcing_prefix
+
     def update_pred_length(self, new_length):
         """
         Update prediction length
@@ -148,7 +151,7 @@ def __getitem__(self, idx):

         forcing_path = os.path.join(
             self.sample_dir_path,
-            f"forcing_{sample_datetime}.npy",
+            f"{self.forcing_prefix}_{sample_datetime}.npy",
         )
         atm_forcing = torch.tensor(
             np.load(forcing_path), dtype=torch.float32
```
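The dataset change is a pure filename switch: one sample directory can hold several forcing variants side by side, distinguished only by prefix. A minimal sketch of the resulting lookup, with hypothetical `sample_dir_path` and `sample_datetime` values (in `WeatherDataset` these come from the dataset config and the sample filenames):

```python
import os

# Hypothetical values for illustration only.
sample_dir_path = "data/mediterranean/samples/test"
sample_datetime = "20240701"

for forcing_prefix in ["forcing", "ens_forcing", "aifs_forcing"]:
    # Same construction as in __getitem__ above.
    forcing_path = os.path.join(
        sample_dir_path, f"{forcing_prefix}_{sample_datetime}.npy"
    )
    print(forcing_path)
```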

prepare_states.py (+116 -36)
```diff
@@ -43,7 +43,8 @@ def prepare_states(

     # Process each file, concatenate with the next t-1 files
     for i in range(len(files) - n_states + 1):
-        out_filename = f"{prefix}_{os.path.basename(files[i+2])}"
+        # Name as today's date
+        out_filename = f"{prefix}_{os.path.basename(files[i+1])}"
         out_file = os.path.join(out_directory, out_filename)

         if os.path.isfile(out_file):
@@ -61,63 +62,133 @@

         # Save concatenated data to the output directory
         np.save(out_file, full_state)
-        print(f"Saved concatenated file: {out_file}")
+        print(f"Saved states to: {out_file}")


 def prepare_forcing(in_directory, out_directory, prefix, start_date, end_date):
     """
-    Prepare atmospheric forcing data from HRES files.
+    Prepare atmospheric forcing data from forecasts.
     """
-    hres_dir = in_directory
+    forecast_dir = in_directory

     start_dt = datetime.strptime(start_date, "%Y-%m-%d")
     end_dt = datetime.strptime(end_date, "%Y-%m-%d")

     os.makedirs(out_directory, exist_ok=True)

-    # Get HRES files sorted by date
-    hres_files = sorted(
-        glob(os.path.join(hres_dir, "*.npy")),
+    # Get files sorted by date
+    forecast_files = sorted(
+        glob(os.path.join(forecast_dir, "*.npy")),
         key=lambda x: datetime.strptime(os.path.basename(x)[:8], "%Y%m%d"),
     )
-    hres_files = [
+    forecast_files = [
         f
-        for f in hres_files
+        for f in forecast_files
         if start_dt
         <= datetime.strptime(os.path.basename(f)[:8], "%Y%m%d")
         <= end_dt
     ]

-    for hres_file in hres_files:
-        hres_date = datetime.strptime(os.path.basename(hres_file)[:8], "%Y%m%d")
-        # Get files for the two preceding days
-        preceding_days_files = [
-            os.path.join(
-                hres_dir,
-                (hres_date - timedelta(days=i)).strftime("%Y%m%d") + ".npy",
-            )
-            for i in range(1, 3)
-        ]
+    for forecast_file in forecast_files:
+        forecast_date = datetime.strptime(
+            os.path.basename(forecast_file)[:8], "%Y%m%d"
+        )
+        # Get files for the preceding day
+        preceding_day_file = os.path.join(
+            forecast_dir,
+            (forecast_date - timedelta(days=1)).strftime("%Y%m%d") + ".npy",
+        )
+        preceding_day_data = np.load(preceding_day_file)[0:1]

-        # Load the first timestep from each preceding day's HRES file
-        init_states = []
-        for file_path in preceding_days_files:
-            data = np.load(file_path)
-            init_states.append(data[0:1])
+        # Load the current forecast data
+        current_forecast_data = np.load(forecast_file)[:15]

-        # Load the current HRES data
-        current_hres_data = np.load(hres_file)
+        print(preceding_day_data.shape, current_forecast_data.shape)

         # Concatenate all data along the time axis
         concatenated_forcing = np.concatenate(
-            init_states + [current_hres_data], axis=0
+            [preceding_day_data, current_forecast_data], axis=0
         )

         # Save concatenated data
-        out_filename = f"{prefix}_{os.path.basename(hres_file)}"
+        out_filename = f"{prefix}_{os.path.basename(forecast_file)}"
         out_file = os.path.join(out_directory, out_filename)
         np.save(out_file, concatenated_forcing)
-        print(f"Saved combined forcing data file: {out_file}")
+        print(f"Saved forcing states to: {out_file}")
+
+
+def prepare_aifs_forcing(
+    in_directory, out_directory, prefix, start_date, end_date
+):
+    """
+    Prepare atmospheric forcing data from AIFS forecasts (add SSR from ENS).
+    """
+    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
+    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
+
+    os.makedirs(out_directory, exist_ok=True)
+
+    # Get files sorted by date
+    forecast_files = sorted(
+        glob(os.path.join(in_directory, "*.npy")),
+        key=lambda x: datetime.strptime(os.path.basename(x)[:8], "%Y%m%d"),
+    )
+    forecast_files = [
+        f
+        for f in forecast_files
+        if start_dt
+        <= datetime.strptime(os.path.basename(f)[:8], "%Y%m%d")
+        <= end_dt
+    ]
+
+    ifs_variables = ["u10", "v10", "t2m", "msl", "ssr", "tp"]
+
+    for forecast_file in forecast_files:
+        forecast_date = datetime.strptime(
+            os.path.basename(forecast_file)[:8], "%Y%m%d"
+        )
+
+        # Load the preceding day's data
+        preceding_day_file = os.path.join(
+            in_directory,
+            (forecast_date - timedelta(days=1)).strftime("%Y%m%d") + ".npy",
+        )
+        preceding_day_data = np.load(preceding_day_file)[0:1]
+
+        # Insert SSR from ENS data
+        preceding_day_ens_data = np.load(
+            preceding_day_file.replace("aifs", "ens")
+        )[0:1]
+        preceding_day_ssr_data = preceding_day_ens_data[
+            ..., ifs_variables.index("ssr")
+        ]
+        preceding_day_data = np.insert(
+            preceding_day_data,
+            ifs_variables.index("ssr"),
+            preceding_day_ssr_data,
+            axis=-1,
+        )
+
+        # Load the current forecast data
+        current_forecast_data = np.load(forecast_file)[:15]
+
+        # Insert SSR from ENS data
+        ens_data = np.load(forecast_file.replace("aifs", "ens"))[:15]
+        ssr_data = ens_data[..., ifs_variables.index("ssr")]
+        current_forecast_data = np.insert(
+            current_forecast_data, ifs_variables.index("ssr"), ssr_data, axis=-1
+        )
+
+        # Concatenate preceding day data with current forecast data
+        aifs_data = np.concatenate(
+            [preceding_day_data, current_forecast_data], axis=0
+        )
+
+        # Save combined data
+        out_filename = f"{prefix}_{os.path.basename(forecast_file)}"
+        out_file = os.path.join(out_directory, out_filename)
+        np.save(out_file, aifs_data)
+        print(f"Saved forcing states to: {out_file}")


 def main():
@@ -176,13 +247,22 @@ def main():
     args = parser.parse_args()

     if args.forecast_forcing:
-        prepare_forcing(
-            args.data_dir,
-            args.out_dir,
-            args.prefix,
-            args.start_date,
-            args.end_date,
-        )
+        if args.data_dir.endswith("aifs"):
+            prepare_aifs_forcing(
+                args.data_dir,
+                args.out_dir,
+                args.prefix,
+                args.start_date,
+                args.end_date,
+            )
+        else:
+            prepare_forcing(
+                args.data_dir,
+                args.out_dir,
+                args.prefix,
+                args.start_date,
+                args.end_date,
+            )
     else:
         prepare_states(
             args.data_dir,
```
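Both forcing functions now emit 1 + 15 = 16 time steps per file (one state from the preceding day's forecast plus 15 steps of the current one), in line with the `-n 16` test samples above. `prepare_aifs_forcing` additionally splices the ENS `ssr` channel into the AIFS array, which is assumed here to arrive without it. A minimal sketch of that `np.insert` step on dummy shapes:

```python
import numpy as np

# Dummy shapes for illustration: 15 forecast steps on a 4x5 grid.
ifs_variables = ["u10", "v10", "t2m", "msl", "ssr", "tp"]
aifs = np.zeros((15, 4, 5, 5))  # AIFS array without the ssr channel
ens = np.ones((15, 4, 5, 6))    # full six-variable ENS forecast

# Pull ssr out of ENS and insert it at the matching channel index.
ssr = ens[..., ifs_variables.index("ssr")]                    # shape (15, 4, 5)
merged = np.insert(aifs, ifs_variables.index("ssr"), ssr, axis=-1)

print(merged.shape)                      # (15, 4, 5, 6)
print(np.allclose(merged[..., 4], 1.0))  # True: channel 4 now comes from ENS
```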

train_model.py (+11)
```diff
@@ -189,6 +189,13 @@ def main():
         default=None,
         help="Type of data to use: 'analysis' or 'reanalysis' (default: None)",
     )
+    parser.add_argument(
+        "--forcing_prefix",
+        type=str,
+        choices=["forcing", "ens_forcing", "aifs_forcing"],
+        default="forcing",
+        help="Type of forcing to use (default: forcing => ERA5 files)",
+    )
     parser.add_argument(
         "--loss",
         type=str,
@@ -278,6 +285,7 @@ def main():
             subsample_step=args.step_length,
             subset=bool(args.subset_ds),
             data_subset=args.data_subset,
+            forcing_prefix=args.forcing_prefix,
         ),
         args.batch_size,
         shuffle=True,
@@ -292,6 +300,7 @@ def main():
             subsample_step=args.step_length,
             subset=bool(args.subset_ds),
             data_subset=args.data_subset,
+            forcing_prefix=args.forcing_prefix,
         ),
         args.batch_size,
         shuffle=False,
@@ -375,6 +384,8 @@ def main():
             split="test",
             subsample_step=args.step_length,
             subset=bool(args.subset_ds),
+            data_subset=args.data_subset,
+            forcing_prefix=args.forcing_prefix,
         ),
         args.batch_size,
         shuffle=False,
```
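The new flag simply threads through to `WeatherDataset` at all three call sites. An isolated sketch of the parser addition (every other `train_model.py` argument omitted for brevity):

```python
import argparse

# Isolated sketch of the new flag; the surrounding train_model.py
# arguments are omitted.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--forcing_prefix",
    type=str,
    choices=["forcing", "ens_forcing", "aifs_forcing"],
    default="forcing",
    help="Type of forcing to use (default: forcing => ERA5 files)",
)

args = parser.parse_args(["--forcing_prefix", "aifs_forcing"])
print(args.forcing_prefix)  # aifs_forcing
```

At evaluation time this would presumably pair with the README change above, e.g. `--eval test --data_subset analysis --forcing_prefix ens_forcing` to evaluate on ENS-forced test samples.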
