
code dump
dfulu committed Jul 15, 2024
1 parent 47d4150 commit b4dc2b8
Showing 9 changed files with 127 additions and 73 deletions.
4 changes: 2 additions & 2 deletions configs/callbacks/default.yaml
@@ -11,9 +11,9 @@
early_stopping:
_target_: lightning.pytorch.callbacks.EarlyStopping
# name of the logged metric which determines when model is improving
monitor: "MSE/val"
monitor: "MAE/val"
mode: "min" # can be "max" or "min"
-patience: 10 # how many epochs (or val check periods) of not improving until training stops
+patience: 30 # how many epochs (or val check periods) of not improving until training stops
min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

learning_rate_monitor:
Empty file removed configs/callbacks/none.yaml
Empty file.
26 changes: 0 additions & 26 deletions configs/callbacks/wandb.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/config.yaml
@@ -31,4 +31,4 @@ print_config: True
# disable python warnings if they annoy you
ignore_warnings: True

-seed: 2727831
+seed: 27267831
8 changes: 4 additions & 4 deletions configs/datamodule/default.yaml
@@ -7,12 +7,12 @@ zarr_path:
#- /mnt/disks/nwp_rechunk/sat/2022_nonhrv.zarr
- /mnt/disks/nwp_rechunk/sat/2023_nonhrv.zarr

-history_mins: 30
-forecast_mins: 35
-sample_freq_mins: 5
+history_mins: 165
+forecast_mins: 180
+sample_freq_mins: 15

train_period: ["2019-01-01", "2022-01-01"]
val_period: ["2023-01-01", "2023-12-31"]
num_workers: 16
prefetch_factor: 2
-batch_size: 2
+batch_size: 1
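Note: assuming one frame every sample_freq_mins, the new settings correspond to roughly 165 / 15 = 11 past frames and 180 / 15 = 12 future frames per sample (versus 30 / 5 = 6 and 35 / 5 = 7 before), so each sample is considerably longer while the batch size drops to 1.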
6 changes: 3 additions & 3 deletions configs/model/default.yaml
@@ -4,7 +4,7 @@ model:
_partial_: True

num_channels: 11
-history_mins: 30
-forecast_mins: 35
-sample_freq_mins: 5
+history_mins: 165
+forecast_mins: 180
+sample_freq_mins: 15

6 changes: 3 additions & 3 deletions configs/trainer/default.yaml
@@ -11,8 +11,8 @@ num_sanity_val_steps: 8
fast_dev_run: false
#profiler: 'simple'

-accumulate_grad_batches: 4
+accumulate_grad_batches: 8
#val_check_interval: 800
-limit_train_batches: 1000
-limit_val_batches: 200
+limit_train_batches: 2000
+limit_val_batches: 400
log_every_n_steps: 10
21 changes: 16 additions & 5 deletions sat_pred/dataloader.py
@@ -82,6 +82,7 @@ def __init__(
history_mins: int,
forecast_mins: int,
sample_freq_mins: int,
+preshuffle: bool = False
):
"""A torch Dataset for loading past and future satellite data
@@ -92,6 +93,7 @@ def __init__(
history_mins: How many minutes of history will be used as input features
forecast_mins: How many minutes of future will be used as target features
sample_freq_mins: The sample frequency to use for the satellite data
+preshuffle: Whether to shuffle the data - useful for validation
"""

# Load the sat zarr file or list of files and slice the data to the given period
@@ -105,6 +107,9 @@ def __init__(
# there would be a missing timestamp in the sat data required for the sample
self.valid_t0_times = find_valid_t0_times(self.ds, history_mins, forecast_mins, sample_freq_mins)

+if preshuffle:
+    self.valid_t0_times = pd.to_datetime(np.random.permutation(self.valid_t0_times))

self.history_mins = history_mins
self.forecast_mins = forecast_mins
self.sample_freq_mins = sample_freq_mins
@@ -115,7 +120,13 @@ def __len__(self):


def __getitem__(self, idx):
-t0 = self.valid_t0_times[idx]
+if isinstance(idx, str):
+    t0 = pd.Timestamp(idx)
+    assert t0 in self.valid_t0_times
+elif isinstance(idx, int):
+    t0 = self.valid_t0_times[idx]
+else:
+    raise ValueError(f"Unrecognised type {type(idx)}")

ds_sel = self.ds.sel(
time=slice(
@@ -141,7 +152,6 @@ def __getitem__(self, idx):
X = np.nan_to_num(X, nan=-1)
y = np.nan_to_num(y, nan=-1)


return X.astype(np.float32), y.astype(np.float32)
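With the __getitem__ change above, a sample can now be fetched either by integer position or by a timestamp string, which must match one of the valid t0 times. A minimal usage sketch, assuming SatelliteDataset is importable from sat_pred.dataloader and using a placeholder zarr path and made-up dates:

from sat_pred.dataloader import SatelliteDataset

dataset = SatelliteDataset(
    "path/to/sat.zarr",   # placeholder path, not from this repo
    "2023-01-01",         # start of the period to slice
    "2023-12-31",         # end of the period to slice
    history_mins=165,
    forecast_mins=180,
    sample_freq_mins=15,
)

X, y = dataset[0]                    # index by position, as before
X, y = dataset["2023-06-01T12:00"]   # or by a valid t0 timestamp string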


@@ -198,14 +208,15 @@ def __init__(
persistent_workers=False,
)

-def _make_dataset(self, start_date, end_date):
+def _make_dataset(self, start_date, end_date, preshuffle=False):
dataset = SatelliteDataset(
self.zarr_path,
start_date,
end_date,
self.history_mins,
self.forecast_mins,
self.sample_freq_mins,
+preshuffle=preshuffle,
)
return dataset

@@ -217,8 +228,8 @@ def train_dataloader(self):
def val_dataloader(self):
"""Construct val dataloader"""

-dataset = self._make_dataset(*self.val_period)
-return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
+dataset = self._make_dataset(*self.val_period, preshuffle=True)
+return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
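The intent appears to be that the validation set is shuffled once when it is built and the DataLoader itself no longer shuffles, so together with limit_val_batches the same randomly ordered subset of t0 times is evaluated every epoch. A toy illustration of the preshuffle step, with made-up dates:

import numpy as np
import pandas as pd

valid_t0_times = pd.date_range("2023-06-01", periods=5, freq="15min")

# Permute once at construction time; iteration order is then fixed for the run
preshuffled = pd.to_datetime(np.random.permutation(valid_t0_times))
print(preshuffled)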

def test_dataloader(self):
"""Construct test dataloader"""
127 changes: 98 additions & 29 deletions sat_pred/model.py
@@ -15,7 +15,8 @@
import yaml
import numpy as np
import matplotlib.pyplot as plt

import wandb
from torch.utils.data import default_collate


logger = logging.getLogger(__name__)
@@ -81,13 +82,46 @@ def __call__(self, model):
"""Return optimizer"""
return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)


class AdamWReduceLROnPlateau:
"""AdamW optimizer and reduce on plateau scheduler"""

def __init__(
self, lr=0.0005, patience=10, factor=0.2, threshold=2e-4, step_freq=None, **opt_kwargs
):
"""AdamW optimizer and reduce on plateau scheduler"""
self.lr = lr
self.patience = patience
self.factor = factor
self.threshold = threshold
self.step_freq = step_freq
self.opt_kwargs = opt_kwargs

def __call__(self, model):

opt = torch.optim.AdamW(
model.parameters(), lr=self.lr, **self.opt_kwargs
)
sch = {
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
opt,
factor=self.factor,
patience=self.patience,
threshold=self.threshold,
),
"monitor": "MAE/val",
}

return [opt], [sch]
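The return value above is the ([optimizers], [scheduler configs]) pair that Lightning accepts from configure_optimizers, with the ReduceLROnPlateau scheduler keyed to the logged "MAE/val" metric. A small stand-alone sketch of how the factory behaves, assuming the class above is in scope (the linear model is just a stand-in, not part of this repo):

import torch

opt_factory = AdamWReduceLROnPlateau(lr=1e-4, patience=5, factor=0.5)
dummy_model = torch.nn.Linear(4, 1)      # stand-in for the real satellite model

[opt], [sch_cfg] = opt_factory(dummy_model)
print(type(opt).__name__)   # AdamW
print(sch_cfg["monitor"])   # MAE/val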


def plot_sat_images(y, y_hat, channel_inds=[8, 1], n_frames=6):
y = y.cpu().numpy()
y_hat = y_hat.cpu().numpy()

-y[y<=0] = np.nan
-y_hat[y_hat<=0] = np.nan
+mask = y<0
+y[mask] = np.nan
+y_hat[mask] = np.nan

seq_len = y.shape[1]

@@ -118,6 +152,32 @@ def plot_sat_images(y, y_hat, channel_inds=[8, 1], n_frames=6):
plt.xlabel("Frame number")
plt.tight_layout()
return fig


def upload_video(y, y_hat, video_name, channel_nums=[8, 1], fps=1):
y = y.cpu().numpy()
y_hat = y_hat.cpu().numpy()

mask = y<0
y[mask] = 0
y_hat[mask] = 0

channel_frames = []

for channel_num in channel_nums:
y_frames = y.transpose(1,0,2,3)[:, channel_num:channel_num+1, ::-1, ::-1]
y_hat_frames = y_hat.transpose(1,0,2,3)[:, channel_num:channel_num+1, ::-1, ::-1]
channel_frames.append(
np.concatenate(
[y_hat_frames, y_frames],
axis=3
)
)

channel_frames = np.concatenate(channel_frames, axis=2)
channel_frames = channel_frames.clip(0, None)
channel_frames = np.repeat(channel_frames, 3, axis=1)*255
wandb.log({video_name: wandb.Video(channel_frames, fps=fps)})


class TrainingModule(pl.LightningModule):
@@ -130,7 +190,7 @@ def __init__(
history_mins: int,
forecast_mins: int,
sample_freq_mins: int,
-optimizer = AdamW(),
+optimizer = AdamWReduceLROnPlateau(),
):
"""tbc
@@ -158,10 +218,19 @@ def __init__(
forecast_len=self.forecast_len,
)

def _filter_missing_targets(self, y, y_hat):

mask = y==-1
y = y[~mask]
y_hat = y_hat[~mask]
return y, y_hat
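Boolean masking flattens the tensors, so the filtered y and y_hat are 1-D and contain only the pixels whose target is not the -1 fill value; any mean-reduced MAE/MSE computed afterwards therefore ignores missing targets. A tiny illustration with toy values, not from the repo:

import torch

y     = torch.tensor([[0.2, -1.0], [0.5, 0.7]])   # -1 marks a missing target pixel
y_hat = torch.tensor([[0.3,  0.9], [0.4, 0.6]])

mask = y == -1
y_valid, y_hat_valid = y[~mask], y_hat[~mask]      # 1-D, missing pixel dropped

print(torch.mean(torch.abs(y_hat_valid - y_valid)))  # MAE over valid pixels only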


def _calculate_common_losses(self, y, y_hat):
"""Calculate losses common to train, test, and val"""


y, y_hat = self._filter_missing_targets(y, y_hat)
losses = {}

# calculate mse, mae
@@ -177,26 +246,12 @@ def _step_mae_and_mse(self, y, y_hat, dict_key_root):

return losses

-def _step_mae_and_mse(self, y, y_hat, dict_key_root):
-    """Calculate the MSE and MAE at each forecast step"""
-    losses = {}
-
-    mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
-    mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
-
-    losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)})
-    losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)})
-
-    return losses

def _calculate_val_losses(self, y, y_hat):
"""Calculate additional validation losses"""

losses = {}

-# Log the loss at each time horizon
-#losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon"))

return losses

def _calculate_test_losses(self, y, y_hat):
@@ -238,7 +293,7 @@ def training_step(self, batch, batch_idx):

self._training_accumulate_log(batch, batch_idx, losses, y_hat)

return losses["MSE/train"]
return losses["MAE/train"]

def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""
@@ -257,22 +312,36 @@
on_epoch=True,
)

-accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
+return logged_losses

-if batch_idx in [0, 1]:
-    fig = plot_sat_images(y[0], y_hat[0], channel_inds=[8, 1, 2], n_frames=6)
-
-    plot_name = f"val_samples/batch_idx_{batch_idx}"
+def on_validation_epoch_start(self):
+
+    val_dataset = self.trainer.val_dataloaders.dataset
+
+    dates = ["2023-06-01T12:00", "2023-04-05T09:00", "2023-08-05T16:00"]
+
+    X, y = default_collate([val_dataset[date] for date in dates])
+    X = X.to(self.device)
+    y = y.to(self.device)
+
+    y_hat = self.model(X)
+
+    for i in range(len(dates)):
+
+        plot_name = f"val_sample_plots/{dates[i]}"
+        fig = plot_sat_images(y[i], y_hat[i], channel_inds=[8, 1, 2], n_frames=6)
        self.logger.experiment.log({plot_name: wandb.Image(fig)})

        plt.close(fig)

-return logged_losses
+        video_name = f"val_sample_videos/{dates[i]}_channel_8"
+        upload_video(y[i], y_hat[i], video_name, channel_nums=[8])
+
+        video_name = f"val_sample_videos/{dates[i]}_channel_1"
+        upload_video(y[i], y_hat[i], video_name, channel_nums=[1])


def on_validation_epoch_end(self):
"""Run on epoch end"""

return

