Hi,
I have run into the loss = nan problem too. Here is my solution.
The loss function is (sqrt(g_ground) - sqrt(g_hat))^2, and the derivative of sqrt(x) is 0.5/sqrt(x), which is infinite at x = 0. So the gradient turns into inf/nan as soon as g_hat hits 0. The code further below may fix the problem; first, a minimal repro of the failing gradient.
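A short sketch (my own illustration, not from the project) that shows the infinite gradient directly:

import torch

# g_hat = 0 stands in for a zero network output
g_hat = torch.zeros(1, requires_grad=True)
g_ground = torch.tensor([0.25])

# the loss from the comment above: (sqrt(g_ground) - sqrt(g_hat))^2
loss = torch.pow(torch.sqrt(g_ground) - torch.sqrt(g_hat), 2).mean()
loss.backward()
print(g_hat.grad)  # tensor([-inf]); the next optimizer step turns the weights into nan

The fix keeps the loss shape but replaces pow(x, gamma) with a linear ramp below a small threshold, so the gradient stays finite: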
class CustomLoss(nn.Module):
    ...
    def forward(...):
        ....
        rb = targets[:, :, 34:68]
        # try to avoid nan: below a small threshold, replace pow(x, gamma)
        # with a linear ramp so the gradient is finite at x = 0
        gamma_gb_hat = torch.FloatTensor(gb_hat.size()).type_as(gb_hat)
        mask = gb_hat < 0.0003
        gamma_gb_hat[mask] = 1290 * gb_hat[mask]
        mask = gb_hat >= 0.0003
        gamma_gb_hat[mask] = torch.pow(gb_hat[mask], gamma)
        # same treatment for the (1 - rb_hat) term
        gamma_rb_hat = torch.FloatTensor(rb_hat.size()).type_as(rb_hat)
        mask = (1 - rb_hat) < 0.0003
        gamma_rb_hat[mask] = 1290 * (1 - rb_hat[mask])
        mask = (1 - rb_hat) >= 0.0003
        gamma_rb_hat[mask] = torch.pow(1 - rb_hat[mask], gamma)
        return (torch.mean(torch.pow(torch.pow(gb, gamma) - gamma_gb_hat, 2))
                + C4 * torch.mean(torch.pow(torch.pow(gb, gamma) - gamma_gb_hat, 4))
                + torch.mean(torch.pow(torch.pow(1 - rb, gamma) - gamma_rb_hat, 2))
                # the original snippet was truncated here; a symmetric C4 term
                # for the rb part is assumed to follow
                + C4 * torch.mean(torch.pow(torch.pow(1 - rb, gamma) - gamma_rb_hat, 4)))
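If you prefer to avoid boolean-mask indexing, the same piecewise function can be written with torch.where. One subtlety: torch.where backpropagates through both branches, so the pow branch has to be clamped, otherwise pow(0, gamma) with gamma < 1 still injects an inf gradient into the backward pass (0 * inf = nan). A sketch, reusing the threshold 0.0003 and slope 1290 from the code above (safe_pow is my name, not from the project):

import torch

def safe_pow(x, gamma, eps=0.0003, slope=1290.0):
    # Piecewise x**gamma: linear ramp below eps so the gradient is finite at 0.
    # clamp(min=eps) matters: torch.where differentiates BOTH branches, and an
    # unclamped torch.pow(0, gamma) would still put inf into the backward pass.
    return torch.where(x < eps, slope * x, torch.pow(x.clamp(min=eps), gamma))

x = torch.tensor([0.0, 0.0001, 0.5], requires_grad=True)
safe_pow(x, gamma=0.5).sum().backward()
print(x.grad)  # finite everywhere, no inf/nan

For the two branches to meet continuously at the threshold, slope should equal eps**(gamma - 1); with eps = 0.0003, the value 1290 corresponds to gamma of roughly 0.12, so adjust the slope if your gamma differs.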