diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml
index d0512d7..6a69868 100644
--- a/configs/callbacks/default.yaml
+++ b/configs/callbacks/default.yaml
@@ -11,9 +11,9 @@ early_stopping:
   _target_: lightning.pytorch.callbacks.EarlyStopping
   # name of the logged metric which determines when model is improving
-  monitor: "MSE/val"
+  monitor: "MAE/val"
   mode: "min" # can be "max" or "min"
-  patience: 10 # how many epochs (or val check periods) of not improving until training stops
+  patience: 30 # how many epochs (or val check periods) of not improving until training stops
   min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
 
 learning_rate_monitor:
diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml
deleted file mode 100644
index e69de29..0000000
diff --git a/configs/callbacks/wandb.yaml b/configs/callbacks/wandb.yaml
deleted file mode 100644
index c6ae21d..0000000
--- a/configs/callbacks/wandb.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-defaults:
-  - default.yaml
-
-watch_model:
-  _target_: src.callbacks.wandb_callbacks.WatchModel
-  log: "all"
-  log_freq: 100
-
-upload_code_as_artifact:
-  _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
-  code_dir: ${work_dir}/src
-
-upload_ckpts_as_artifact:
-  _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
-  ckpt_dir: "checkpoints/"
-  upload_best_only: True
-
-log_f1_precision_recall_heatmap:
-  _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
-
-log_confusion_matrix:
-  _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
-
-log_image_predictions:
-  _target_: src.callbacks.wandb_callbacks.LogImagePredictions
-  num_samples: 8
diff --git a/configs/config.yaml b/configs/config.yaml
index c3e9c6e..1193ce5 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -31,4 +31,4 @@ print_config: True
 # disable python warnings if they annoy you
 ignore_warnings: True
 
-seed: 2727831
+seed: 27267831
diff --git a/configs/datamodule/default.yaml b/configs/datamodule/default.yaml
index 297d946..41c3e5a 100644
--- a/configs/datamodule/default.yaml
+++ b/configs/datamodule/default.yaml
@@ -7,12 +7,12 @@ zarr_path:
   #- /mnt/disks/nwp_rechunk/sat/2022_nonhrv.zarr
   - /mnt/disks/nwp_rechunk/sat/2023_nonhrv.zarr
 
-history_mins: 30
-forecast_mins: 35
-sample_freq_mins: 5
+history_mins: 165
+forecast_mins: 180
+sample_freq_mins: 15
 
 train_period: ["2019-01-01", "2022-01-01"]
 val_period: ["2023-01-01", "2023-12-31"]
 num_workers: 16
 prefetch_factor: 2
-batch_size: 2
\ No newline at end of file
+batch_size: 1
\ No newline at end of file
diff --git a/configs/model/default.yaml b/configs/model/default.yaml
index aca4c19..cea1519 100644
--- a/configs/model/default.yaml
+++ b/configs/model/default.yaml
@@ -4,7 +4,7 @@ model:
   _partial_: True
 
 num_channels: 11
-history_mins: 30
-forecast_mins: 35
-sample_freq_mins: 5
+history_mins: 165
+forecast_mins: 180
+sample_freq_mins: 15
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
index 20111b7..2ba0df5 100644
--- a/configs/trainer/default.yaml
+++ b/configs/trainer/default.yaml
@@ -11,8 +11,8 @@ num_sanity_val_steps: 8
 fast_dev_run: false
 #profiler: 'simple'
 
-accumulate_grad_batches: 4
+accumulate_grad_batches: 8
 #val_check_interval: 800
-limit_train_batches: 1000
-limit_val_batches: 200
+limit_train_batches: 2000
+limit_val_batches: 400
 log_every_n_steps: 10
diff --git a/sat_pred/dataloader.py b/sat_pred/dataloader.py
index fa2371c..ebf910f 100644
--- a/sat_pred/dataloader.py
+++ b/sat_pred/dataloader.py
@@ -82,6 +82,7 @@ def __init__(
         history_mins: int,
         forecast_mins: int,
         sample_freq_mins: int,
+        preshuffle: bool = False
     ):
         """A torch Dataset for loading past and future satellite data
 
@@ -92,6 +93,7 @@ def __init__(
             history_mins: How many minutes of history will be used as input features
             forecast_mins: How many minutes of future will be used as target features
             sample_freq_mins: The sample frequency to use for the satellite data
+            preshuffle: Whether to shuffle the valid t0 times once on creation - useful for validation
         """
 
         # Load the sat zarr file or list of files and slice the data to the given period
@@ -105,6 +107,9 @@ def __init__(
         # there would be a missing timestamp in the sat data required for the sample
         self.valid_t0_times = find_valid_t0_times(self.ds, history_mins, forecast_mins, sample_freq_mins)
 
+        if preshuffle:
+            self.valid_t0_times = pd.to_datetime(np.random.permutation(self.valid_t0_times))
+
         self.history_mins = history_mins
         self.forecast_mins = forecast_mins
         self.sample_freq_mins = sample_freq_mins
@@ -115,7 +120,13 @@ def __len__(self):
 
     def __getitem__(self, idx):
 
-        t0 = self.valid_t0_times[idx]
+        if isinstance(idx, str):
+            t0 = pd.Timestamp(idx)
+            assert t0 in self.valid_t0_times
+        elif isinstance(idx, int):
+            t0 = self.valid_t0_times[idx]
+        else:
+            raise ValueError(f"Unrecognised type {type(idx)}")
 
         ds_sel = self.ds.sel(
             time=slice(
@@ -141,7 +152,6 @@ def __getitem__(self, idx):
         X = np.nan_to_num(X, nan=-1)
         y = np.nan_to_num(y, nan=-1)
 
-
         return X.astype(np.float32), y.astype(np.float32)
@@ -198,7 +208,7 @@ def __init__(
             persistent_workers=False,
         )
 
-    def _make_dataset(self, start_date, end_date):
+    def _make_dataset(self, start_date, end_date, preshuffle=False):
         dataset = SatelliteDataset(
             self.zarr_path,
             start_date,
@@ -206,6 +216,7 @@ def _make_dataset(self, start_date, end_date):
             self.history_mins,
             self.forecast_mins,
             self.sample_freq_mins,
+            preshuffle=preshuffle,
         )
 
         return dataset
@@ -217,8 +228,8 @@ def train_dataloader(self):
 
     def val_dataloader(self):
         """Construct val dataloader"""
-        dataset = self._make_dataset(*self.val_period)
-        return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
+        dataset = self._make_dataset(*self.val_period, preshuffle=True)
+        return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
 
     def test_dataloader(self):
         """Construct test dataloader"""
diff --git a/sat_pred/model.py b/sat_pred/model.py
index 9dc3586..2accd33 100644
--- a/sat_pred/model.py
+++ b/sat_pred/model.py
@@ -15,7 +15,8 @@ import yaml
 import numpy as np
 import matplotlib.pyplot as plt
-
+import wandb
+from torch.utils.data import default_collate
 
 logger = logging.getLogger(__name__)
@@ -81,13 +82,46 @@ def __call__(self, model):
         """Return optimizer"""
         return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)
 
+
+class AdamWReduceLROnPlateau:
+    """AdamW optimizer and reduce on plateau scheduler"""
+
+    def __init__(
+        self, lr=0.0005, patience=10, factor=0.2, threshold=2e-4, step_freq=None, **opt_kwargs
+    ):
+        """AdamW optimizer and reduce on plateau scheduler"""
+        self.lr = lr
+        self.patience = patience
+        self.factor = factor
+        self.threshold = threshold
+        self.step_freq = step_freq
+        self.opt_kwargs = opt_kwargs
+
+    def __call__(self, model):
+
+        opt = torch.optim.AdamW(
+            model.parameters(), lr=self.lr, **self.opt_kwargs
+        )
+        sch = {
+            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
+                opt,
+                factor=self.factor,
+                patience=self.patience,
+                threshold=self.threshold,
+            ),
+            "monitor": "MAE/val",
+        }
+
+        return [opt], [sch]
+
 
 def plot_sat_images(y, y_hat, channel_inds=[8, 1], n_frames=6):
 
     y = y.cpu().numpy()
     y_hat = y_hat.cpu().numpy()
 
-    y[y<=0] = np.nan
-    y_hat[y_hat<=0] = np.nan
+    mask = y<0
+    y[mask] = np.nan
+    y_hat[mask] = np.nan
 
     seq_len = y.shape[1]
@@ -118,6 +152,32 @@ def plot_sat_images(y, y_hat, channel_inds=[8, 1], n_frames=6):
     plt.xlabel("Frame number")
     plt.tight_layout()
     return fig
+
+
+def upload_video(y, y_hat, video_name, channel_nums=[8, 1], fps=1):
+
+    y = y.cpu().numpy()
+    y_hat = y_hat.cpu().numpy()
+
+    mask = y<0
+    y[mask] = 0
+    y_hat[mask] = 0
+
+    channel_frames = []
+
+    for channel_num in channel_nums:
+        y_frames = y.transpose(1,0,2,3)[:, channel_num:channel_num+1, ::-1, ::-1]
+        y_hat_frames = y_hat.transpose(1,0,2,3)[:, channel_num:channel_num+1, ::-1, ::-1]
+        channel_frames.append(
+            np.concatenate(
+                [y_hat_frames, y_frames],
+                axis=3
+            )
+        )
+
+    channel_frames = np.concatenate(channel_frames, axis=2)
+    channel_frames = channel_frames.clip(0, None)
+    channel_frames = np.repeat(channel_frames, 3, axis=1)*255
+    wandb.log({video_name: wandb.Video(channel_frames, fps=fps)})
 
 
 class TrainingModule(pl.LightningModule):
@@ -130,7 +190,7 @@ def __init__(
         history_mins: int,
         forecast_mins: int,
         sample_freq_mins: int,
-        optimizer = AdamW(),
+        optimizer = AdamWReduceLROnPlateau(),
     ):
         """tbc
@@ -158,10 +218,19 @@ def __init__(
             forecast_len=self.forecast_len,
         )
 
+    def _filter_missing_targets(self, y, y_hat):
+
+        mask = y==-1
+        y = y[~mask]
+        y_hat = y_hat[~mask]
+        return y, y_hat
+
     def _calculate_common_losses(self, y, y_hat):
         """Calculate losses common to train, test, and val"""
+
+        y, y_hat = self._filter_missing_targets(y, y_hat)
 
         losses = {}
 
         # calculate mse, mae
@@ -177,26 +246,12 @@ def _calculate_common_losses(self, y, y_hat):
 
         return losses
 
-    def _step_mae_and_mse(self, y, y_hat, dict_key_root):
-        """Calculate the MSE and MAE at each forecast step"""
-        losses = {}
-
-        mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
-        mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
-
-        losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)})
-        losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)})
-
-        return losses
-
     def _calculate_val_losses(self, y, y_hat):
         """Calculate additional validation losses"""
 
         losses = {}
 
-        # Log the loss at each time horizon
-        #losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon"))
-
         return losses
 
     def _calculate_test_losses(self, y, y_hat):
@@ -238,7 +293,7 @@ def training_step(self, batch, batch_idx):
 
         self._training_accumulate_log(batch, batch_idx, losses, y_hat)
 
-        return losses["MSE/train"]
+        return losses["MAE/train"]
 
     def validation_step(self, batch: dict, batch_idx):
         """Run validation step"""
@@ -257,22 +312,36 @@ def validation_step(self, batch: dict, batch_idx):
             on_epoch=True,
         )
 
-        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
+        return logged_losses
 
-        if batch_idx in [0, 1]:
-            fig = plot_sat_images(y[0], y_hat[0], channel_inds=[8, 1, 2], n_frames=6)
-
-            plot_name = f"val_samples/batch_idx_{batch_idx}"
+    def on_validation_epoch_start(self):
+
+        val_dataset = self.trainer.val_dataloaders.dataset
+
+        dates = ["2023-06-01T12:00", "2023-04-05T09:00", "2023-08-05T16:00"]
+
+        X, y = default_collate([val_dataset[date] for date in dates])
+        X = X.to(self.device)
+        y = y.to(self.device)
+        y_hat = self.model(X)
+
+        for i in range(len(dates)):
+
+            plot_name = f"val_sample_plots/{dates[i]}"
+            fig = plot_sat_images(y[i], y_hat[i], channel_inds=[8, 1, 2], n_frames=6)
             self.logger.experiment.log({plot_name: wandb.Image(fig)})
             plt.close(fig)
 
-        return logged_losses
+            video_name = f"val_sample_videos/{dates[i]}_channel_8"
+            upload_video(y[i], y_hat[i], video_name, channel_nums=[8])
+
+            video_name = f"val_sample_videos/{dates[i]}_channel_1"
+            upload_video(y[i], y_hat[i], video_name, channel_nums=[1])
 
     def on_validation_epoch_end(self):
         """Run on epoch end"""
-        return
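
A minimal usage sketch of the dataset changes above (illustrative, not part of the diff). The positional argument order is assumed from the SatelliteDataset(...) call in _make_dataset, and the zarr path and timestamps are placeholders:

    from torch.utils.data import default_collate
    from sat_pred.dataloader import SatelliteDataset

    # Placeholder path and period; argument order as in _make_dataset above
    dataset = SatelliteDataset(
        "/mnt/disks/nwp_rechunk/sat/2023_nonhrv.zarr",
        "2023-01-01",     # start_date
        "2023-12-31",     # end_date
        165,              # history_mins
        180,              # forecast_mins
        15,               # sample_freq_mins
        preshuffle=True,  # permute valid_t0_times once at construction
    )

    # Integer indexing picks the idx-th valid t0 time as before; string indexing
    # now selects a specific init time, which must be in dataset.valid_t0_times
    X, y = dataset[0]
    X, y = dataset["2023-06-01T12:00"]

    # Collate a fixed list of dates into one batch, as on_validation_epoch_start does
    dates = ["2023-06-01T12:00", "2023-04-05T09:00"]
    X_batch, y_batch = default_collate([dataset[date] for date in dates])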