@@ -63,6 +63,13 @@ def __init__(
63
63
def load_weights (self , weights_path , weight_keys = 'params' ):
64
64
self .net = load_pretrained_network (self .net , weights_path , weight_keys = weight_keys )
65
65
66
+ def is_valid_input (self , x ):
67
+ if x is not None :
68
+ assert isinstance (x , torch .Tensor ), 'Input must be a torch.Tensor'
69
+ assert x .dim () == 4 , 'Input must be 4D tensor (B, C, H, W)'
70
+ assert x .shape [1 ] in [1 , 3 ], 'Input must be RGB or gray image'
71
+ assert x .min () >= 0 and x .max () <= 1 , f'Input must be normalized to [0, 1], but got min={ x .min ():.4f} , max={ x .max ():.4f} '
72
+
66
73
def forward (self , target , ref = None , ** kwargs ):
67
74
device = self .dummy_param .device
68
75
@@ -80,6 +87,9 @@ def forward(self, target, ref=None, **kwargs):
80
87
assert ref is not None , 'Please specify reference image for Full Reference metric'
81
88
ref = imread2tensor (ref , rgb = True )
82
89
ref = ref .unsqueeze (0 )
90
+
91
+ self .is_valid_input (target )
92
+ self .is_valid_input (ref )
83
93
84
94
if self .metric_mode == 'FR' :
85
95
assert ref is not None , 'Please specify reference image for Full Reference metric'
0 commit comments