Skip to content

Commit f709766

Browse files
authored
Fix focal loss tests (#8920)
1 parent b5c7443 commit f709766

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchvision/ops/focal_loss.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def sigmoid_focal_loss(
3333
"""
3434
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
3535

36-
if not (0 <= alpha <= 1) or alpha != -1:
36+
if not (0 <= alpha <= 1) and alpha != -1:
3737
raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
3838

3939
if not torch.jit.is_scripting() and not torch.jit.is_tracing():

0 commit comments

Comments
 (0)