Skip to content

Commit 0aa6f71

Browse files
committed
feat: 🚩 add input validation check
1 parent df1defa commit 0aa6f71

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

pyiqa/models/inference_model.py

+10
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def __init__(
6363
def load_weights(self, weights_path, weight_keys='params'):
6464
self.net = load_pretrained_network(self.net, weights_path, weight_keys=weight_keys)
6565

66+
def is_valid_input(self, x):
67+
if x is not None:
68+
assert isinstance(x, torch.Tensor), 'Input must be a torch.Tensor'
69+
assert x.dim() == 4, 'Input must be 4D tensor (B, C, H, W)'
70+
assert x.shape[1] in [1, 3], 'Input must be RGB or gray image'
71+
assert x.min() >= 0 and x.max() <= 1, f'Input must be normalized to [0, 1], but got min={x.min():.4f}, max={x.max():.4f}'
72+
6673
def forward(self, target, ref=None, **kwargs):
6774
device = self.dummy_param.device
6875

@@ -80,6 +87,9 @@ def forward(self, target, ref=None, **kwargs):
8087
assert ref is not None, 'Please specify reference image for Full Reference metric'
8188
ref = imread2tensor(ref, rgb=True)
8289
ref = ref.unsqueeze(0)
90+
91+
self.is_valid_input(target)
92+
self.is_valid_input(ref)
8393

8494
if self.metric_mode == 'FR':
8595
assert ref is not None, 'Please specify reference image for Full Reference metric'

0 commit comments

Comments
 (0)