Skip to content

Commit

Permalink
tweak for nans
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Dec 5, 2024
1 parent 6a06026 commit f0c03ce
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 17 deletions.
3 changes: 3 additions & 0 deletions sat_pred/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def name(self) -> str:
def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Return loss"""

target = target.copy()
target[target==-1] = float('nan')

loss = 0

for scale in self.scales:
Expand Down
23 changes: 6 additions & 17 deletions sat_pred/training_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,7 @@ def __init__(
self.video_plot_t0_times = video_plot_t0_times
self.video_crop_plots = video_crop_plots
self.multi_gpu = multi_gpu

@staticmethod
def _minus_one_to_nan(y: torch.Tensor) -> None:
"""Replace -1 values in tensor with NaNs in-place"""
y[y==-1] = torch.nan


def _calculate_common_losses(
self,
y: torch.Tensor,
Expand All @@ -147,9 +142,11 @@ def _calculate_common_losses(

losses = {}

mse_loss = torch.nanmean(F.mse_loss(y_hat, y, reduction="none"))
mae_loss = torch.nanmean(F.l1_loss(y_hat, y, reduction="none"))
ssim_loss = torch.nanmean(1-self.ssim_func(y_hat, y)) # need to maximise SSIM
mask = y==-1

mse_loss = F.mse_loss(y_hat, y, reduction="none")[~mask].mean()
mae_loss = F.l1_loss(y_hat, y, reduction="none")[~mask].mean()
ssim_loss = (1-self.ssim_func(y_hat, y))[~mask].mean() # need to maximise SSIM

losses = {
"MSE": mse_loss,
Expand Down Expand Up @@ -205,10 +202,6 @@ def training_step(self, batch, batch_idx: int) -> None | torch.Tensor:
y_hat = self.model(X)
del X

# Replace the -1 (filled) values in y with NaNs
# This operation is in-place
self._minus_one_to_nan(y)

losses = self._calculate_common_losses(y, y_hat)
losses = {f"{k}/train": v for k, v in losses.items()}

Expand Down Expand Up @@ -237,10 +230,6 @@ def validation_step(self, batch: dict, batch_idx: int):
X, y = batch
y_hat = self.model(X)
del X

# Replace the -1 (filled) values in y with NaNs
# This operation is in-place
self._minus_one_to_nan(y)

losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
Expand Down

0 comments on commit f0c03ce

Please sign in to comment.