diff --git a/sat_pred/loss.py b/sat_pred/loss.py index 3cdd2c6..cd5a81d 100644 --- a/sat_pred/loss.py +++ b/sat_pred/loss.py @@ -32,6 +32,9 @@ def name(self) -> str: def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Return loss""" + target = target.clone() + target[target==-1] = float('nan') + loss = 0 for scale in self.scales: diff --git a/sat_pred/training_module.py b/sat_pred/training_module.py index 2895c0d..32620a0 100644 --- a/sat_pred/training_module.py +++ b/sat_pred/training_module.py @@ -127,12 +127,7 @@ def __init__( self.video_plot_t0_times = video_plot_t0_times self.video_crop_plots = video_crop_plots self.multi_gpu = multi_gpu - - @staticmethod - def _minus_one_to_nan(y: torch.Tensor) -> None: - """Replace -1 values in tensor with NaNs in-place""" - y[y==-1] = torch.nan - + def _calculate_common_losses( self, y: torch.Tensor, @@ -147,9 +142,11 @@ def _calculate_common_losses( losses = {} - mse_loss = torch.nanmean(F.mse_loss(y_hat, y, reduction="none")) - mae_loss = torch.nanmean(F.l1_loss(y_hat, y, reduction="none")) - ssim_loss = torch.nanmean(1-self.ssim_func(y_hat, y)) # need to maximise SSIM + mask = y==-1 + + mse_loss = F.mse_loss(y_hat, y, reduction="none")[~mask].mean() + mae_loss = F.l1_loss(y_hat, y, reduction="none")[~mask].mean() + ssim_loss = (1-self.ssim_func(y_hat, y))[~mask].mean() # need to maximise SSIM losses = { "MSE": mse_loss, @@ -205,10 +202,6 @@ def training_step(self, batch, batch_idx: int) -> None | torch.Tensor: y_hat = self.model(X) del X - # Replace the -1 (filled) values in y with NaNs - # This operation is in-place - self._minus_one_to_nan(y) - losses = self._calculate_common_losses(y, y_hat) losses = {f"{k}/train": v for k, v in losses.items()} @@ -237,10 +230,6 @@ def validation_step(self, batch: dict, batch_idx: int): X, y = batch y_hat = self.model(X) del X - - # Replace the -1 (filled) values in y with NaNs - # This operation is in-place - self._minus_one_to_nan(y) losses = 
self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat))