From 89f0f87aecaf3b1ad7f06f6bd29326c4b74098cf Mon Sep 17 00:00:00 2001 From: Richie Bendall Date: Wed, 17 Apr 2024 17:46:35 +1200 Subject: [PATCH 1/5] Add GaussianNoise --- torchvision/transforms/_functional_pil.py | 19 ++++++++++++++ torchvision/transforms/transforms.py | 1 + torchvision/transforms/v2/_color.py | 21 ++++++++++++++++ .../transforms/v2/functional/_color.py | 25 +++++++++++++++++++ 4 files changed, 66 insertions(+) diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py index 277848224ac..62eaf4d72af 100644 --- a/torchvision/transforms/_functional_pil.py +++ b/torchvision/transforms/_functional_pil.py @@ -391,3 +391,22 @@ def equalize(img: Image.Image) -> Image.Image: if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. Got {type(img)}") return ImageOps.equalize(img) + +@torch.jit.unused +def gaussian_noise(img: Image.Image, mean: float = 0., var: float = 1.0) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + if var < 0: + raise ValueError(f"var shouldn't be negative. Got {var}") + + z = np.random.normal( + loc=mean, + scale=var, + size=( + *get_image_size(img), + get_image_num_channels(img), + ), + ) + + return img + z diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2a6e0ce12c0..420513609cf 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -54,6 +54,7 @@ "RandomAutocontrast", "RandomEqualize", "ElasticTransform", + "GaussianNoise", ] diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 49b4a8d8b10..e154b0f2684 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -374,3 +374,24 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor) + + +class GaussianNoise(Transform): + """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default. + + If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + mean (float): Mean of the sampled gaussian distribution. Default is 0. + var (float): Variance of the sampled gaussian distribution. Default is 1. + """ + + def __init__(self, mean: float = 0., var: float = 1.) -> None: + super().__init__() + self.mean = mean + self.var = var + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, var=self.var) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 3025f876dff..b3a4b20f7c1 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -737,3 +737,28 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) @_register_kernel_internal(permute_channels, tv_tensors.Video) def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: return permute_channels_image(video, permutation=permutation) + +def gaussian_noise(inpt: torch.Tensor, mean: float = 0., var: float = 1.) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.GaussianNoise`""" + if torch.jit.is_scripting(): + return gaussian_noise_image(inpt, mean=mean, var=var) + + _log_api_usage_once(gaussian_noise) + + kernel = _get_kernel(gaussian_noise, type(inpt)) + return kernel(inpt, mean=mean, var=var) + +@_register_kernel_internal(gaussian_noise, torch.Tensor) +@_register_kernel_internal(gaussian_noise, tv_tensors.Image) +def gaussian_noise_image(image: torch.Tensor, mean: float = 0., var: float = 1.) -> torch.Tensor: + if var < 0: + raise ValueError(f"var shouldn't be negative. Got {var}") + + if image.numel() == 0: + return image + + z = mean + torch.randn_like(image) * var + + return image + z + +_gaussian_noise_pil = _register_kernel_internal(gaussian_noise, PIL.Image.Image)(_FP.gaussian_noise) From 6ccd9ad549dff09cd73d50089abfa45b9b4065ac Mon Sep 17 00:00:00 2001 From: Richie Bendall Date: Sat, 20 Apr 2024 22:21:04 +1200 Subject: [PATCH 2/5] Update --- docs/source/transforms.rst | 1 + torchvision/transforms/functional.py | 19 +++++++++++++++++++ torchvision/transforms/v2/__init__.py | 1 + 3 files changed, 21 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 54ed18394cd..ef7a17500e3 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -555,6 +555,7 @@ Color RandomAdjustSharpness RandomAutocontrast RandomEqualize + GaussianNoise Composition ^^^^^^^^^^^ diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 124d1da5f4f..01b3550620d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1506,6 +1506,25 @@ def equalize(img: Tensor) -> Tensor: return F_t.equalize(img) +def gaussian_noise(img: Tensor, mean: float = 0., var: float = 1.) -> Tensor: + """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default. + + Args: + img (PIL Image or Tensor): Image on which equalize is applied. + mean (float): Mean of the sampled gaussian distribution. Default is 0. + var (float): Variance of the sampled gaussian distribution. Default is 1. + + Returns: + PIL Image or Tensor: An image that was equalized. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(gaussian_noise) + if not isinstance(img, torch.Tensor): + F_pil.gaussian_noise(img, mean, var) + + return F_t.gaussian_noise(img, mean, var) + + def elastic_transform( img: Tensor, displacement: Tensor, diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 6dccb8a5b78..9c0bdadeb5f 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -18,6 +18,7 @@ RandomPhotometricDistort, RandomPosterize, RandomSolarize, + GaussianNoise, RGB, ) from ._container import Compose, RandomApply, RandomChoice, RandomOrder From 7d1e5645b7a3264d64fc55f0179604023b6b243a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 30 May 2024 14:35:26 +0100 Subject: [PATCH 3/5] Cleanups, remove v1, enforce float32, and tests --- docs/source/transforms.rst | 3 +- test/test_transforms_v2.py | 78 ++++++++++++++++++- torchvision/transforms/_functional_pil.py | 19 ----- torchvision/transforms/functional.py | 4 +- torchvision/transforms/transforms.py | 1 - torchvision/transforms/v2/__init__.py | 2 +- torchvision/transforms/v2/_color.py | 21 ----- torchvision/transforms/v2/_misc.py | 25 ++++++ .../transforms/v2/functional/__init__.py | 3 + .../transforms/v2/functional/_color.py | 25 ------ torchvision/transforms/v2/functional/_misc.py | 39 ++++++++++ 11 files changed, 148 insertions(+), 72 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index ef7a17500e3..4bb18cf6b48 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -350,6 +350,7 @@ Color v2.RGB v2.RandomGrayscale v2.GaussianBlur + v2.GaussianNoise v2.RandomInvert v2.RandomPosterize v2.RandomSolarize @@ -368,6 +369,7 @@ Functionals v2.functional.grayscale_to_rgb v2.functional.to_grayscale v2.functional.gaussian_blur + v2.functional.gaussian_noise v2.functional.invert v2.functional.posterize v2.functional.solarize @@ -555,7 +557,6 @@ Color RandomAdjustSharpness RandomAutocontrast RandomEqualize - GaussianNoise Composition ^^^^^^^^^^^ diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index b0c1659f253..8a47a589508 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -111,8 +111,10 @@ def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs): input = input.as_subclass(torch.Tensor) with ignore_jit_no_profile_information_warning(): - actual = kernel_scripted(input, *args, **kwargs) - expected = kernel(input, *args, **kwargs) + with freeze_rng_state(): + actual = kernel_scripted(input, *args, **kwargs) + with freeze_rng_state(): + expected = kernel(input, *args, **kwargs) assert_close(actual, expected, rtol=rtol, atol=atol) @@ -3238,6 +3240,78 @@ def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtyp torch.testing.assert_close(actual, expected, rtol=0, atol=1) +class TestGaussianNoise: + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_video], + ) + def test_kernel(self, make_input): + check_kernel( + F.gaussian_noise, + make_input(dtype=torch.float32), + # This cannot pass because the noise on a batch in not per-image + check_batched_vs_unbatched=False, + ) + + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_video], + ) + def test_functional(self, make_input): + check_functional(F.gaussian_noise, make_input(dtype=torch.float32)) + + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.gaussian_noise, torch.Tensor), + (F.gaussian_noise_image, tv_tensors.Image), + (F.gaussian_noise_video, tv_tensors.Video), + ], + ) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type) + + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_video], + ) + def test_transform(self, make_input): + def adapter(_, input, __): + # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32 + # Same for PIL images + for key, value in input.items(): + if isinstance(value, torch.Tensor) and not value.is_floating_point(): + input[key] = value.to(torch.float32) + if isinstance(value, PIL.Image.Image): + input[key] = F.pil_to_tensor(value).to(torch.float32) + return input + + check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."): + F.gaussian_noise(make_image_pil()) + with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"): + F.gaussian_noise(make_image(dtype=torch.uint8)) + with pytest.raises(ValueError, match="sigma shouldn't be negative"): + F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1) + + def test_clip(self): + img = make_image(dtype=torch.float32) + + out = F.gaussian_noise(img, mean=100, clip=False) + assert out.min() > 50 + + out = F.gaussian_noise(img, mean=100, clip=True) + assert (out == 1).all() + + out = F.gaussian_noise(img, mean=-100, clip=False) + assert out.min() < -50 + + out = F.gaussian_noise(img, mean=-100, clip=True) + assert (out == 0).all() + + class TestAutoAugmentTransforms: # These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling. # It's typically very hard to test the effect on some parameters without heavy mocking logic. diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py index 62eaf4d72af..277848224ac 100644 --- a/torchvision/transforms/_functional_pil.py +++ b/torchvision/transforms/_functional_pil.py @@ -391,22 +391,3 @@ def equalize(img: Image.Image) -> Image.Image: if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. Got {type(img)}") return ImageOps.equalize(img) - -@torch.jit.unused -def gaussian_noise(img: Image.Image, mean: float = 0., var: float = 1.0) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - if var < 0: - raise ValueError(f"var shouldn't be negative. Got {var}") - - z = np.random.normal( - loc=mean, - scale=var, - size=( - *get_image_size(img), - get_image_num_channels(img), - ), - ) - - return img + z diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 01b3550620d..01a47843565 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1506,7 +1506,7 @@ def equalize(img: Tensor) -> Tensor: return F_t.equalize(img) -def gaussian_noise(img: Tensor, mean: float = 0., var: float = 1.) -> Tensor: +def gaussian_noise(img: Tensor, mean: float = 0.0, var: float = 1.0) -> Tensor: """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default. Args: @@ -1521,7 +1521,7 @@ def gaussian_noise(img: Tensor, mean: float = 0., var: float = 1.) -> Tensor: _log_api_usage_once(gaussian_noise) if not isinstance(img, torch.Tensor): F_pil.gaussian_noise(img, mean, var) - + return F_t.gaussian_noise(img, mean, var) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 420513609cf..2a6e0ce12c0 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -54,7 +54,6 @@ "RandomAutocontrast", "RandomEqualize", "ElasticTransform", - "GaussianNoise", ] diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 9c0bdadeb5f..33d83f1fe3f 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -18,7 +18,6 @@ RandomPhotometricDistort, RandomPosterize, RandomSolarize, - GaussianNoise, RGB, ) from ._container import Compose, RandomApply, RandomChoice, RandomOrder @@ -46,6 +45,7 @@ from ._misc import ( ConvertImageDtype, GaussianBlur, + GaussianNoise, Identity, Lambda, LinearTransformation, diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index e154b0f2684..49b4a8d8b10 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -374,24 +374,3 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor) - - -class GaussianNoise(Transform): - """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default. - - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, - where ... means it can have an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - mean (float): Mean of the sampled gaussian distribution. Default is 0. - var (float): Variance of the sampled gaussian distribution. Default is 1. - """ - - def __init__(self, mean: float = 0., var: float = 1.) -> None: - super().__init__() - self.mean = mean - self.var = var - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, var=self.var) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index ad2c08150cc..241ff68bf9d 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -205,6 +205,31 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params) +class GaussianNoise(Transform): + """Add gaussian noise to the image. + + The input tensor is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + + The input tensor is also expected to be of float dtype in ``[0, 1]``. + This transform does not support PIL images. + + Args: + mean (float): Mean of the sampled normal distribution. Default is 0. + sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1. + clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True. + """ + + def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None: + super().__init__() + self.mean = mean + self.sigma = sigma + self.clip = clip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip) + + class ToDtype(Transform): """Converts the input to a specific dtype, optionally scaling the values for images or videos. diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 4d4bbf2e86d..d5705d55c4b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -136,6 +136,9 @@ gaussian_blur, gaussian_blur_image, gaussian_blur_video, + gaussian_noise, + gaussian_noise_image, + gaussian_noise_video, normalize, normalize_image, normalize_video, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 9efec7a1da4..34d1e101dbd 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -737,28 +737,3 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) @_register_kernel_internal(permute_channels, tv_tensors.Video) def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: return permute_channels_image(video, permutation=permutation) - -def gaussian_noise(inpt: torch.Tensor, mean: float = 0., var: float = 1.) -> torch.Tensor: - """See :class:`~torchvision.transforms.v2.GaussianNoise`""" - if torch.jit.is_scripting(): - return gaussian_noise_image(inpt, mean=mean, var=var) - - _log_api_usage_once(gaussian_noise) - - kernel = _get_kernel(gaussian_noise, type(inpt)) - return kernel(inpt, mean=mean, var=var) - -@_register_kernel_internal(gaussian_noise, torch.Tensor) -@_register_kernel_internal(gaussian_noise, tv_tensors.Image) -def gaussian_noise_image(image: torch.Tensor, mean: float = 0., var: float = 1.) -> torch.Tensor: - if var < 0: - raise ValueError(f"var shouldn't be negative. Got {var}") - - if image.numel() == 0: - return image - - z = mean + torch.randn_like(image) * var - - return image + z - -_gaussian_noise_pil = _register_kernel_internal(gaussian_noise, PIL.Image.Image)(_FP.gaussian_noise) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 12d064f6638..d0b8a413aec 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -6,6 +6,7 @@ from torch.nn.functional import conv2d, pad as torch_pad from torchvision import tv_tensors +from torchvision.transforms import _functional_pil as _FP from torchvision.transforms._functional_tensor import _max_value from torchvision.transforms.functional import pil_to_tensor, to_pil_image @@ -181,6 +182,44 @@ def gaussian_blur_video( return gaussian_blur_image(video, kernel_size, sigma) +def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.GaussianNoise`""" + if torch.jit.is_scripting(): + return gaussian_noise_image(inpt, mean=mean, sigma=sigma) + + _log_api_usage_once(gaussian_noise) + + kernel = _get_kernel(gaussian_noise, type(inpt)) + return kernel(inpt, mean=mean, sigma=sigma, clip=clip) + + +@_register_kernel_internal(gaussian_noise, torch.Tensor) +@_register_kernel_internal(gaussian_noise, tv_tensors.Image) +def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor: + if not image.is_floating_point(): + raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}") + if sigma < 0: + raise ValueError(f"sigma shouldn't be negative. Got {sigma}") + + noise = mean + torch.randn_like(image) * sigma + out = image + noise + if clip: + out = torch.clamp(out, 0, 1) + return out + + +@_register_kernel_internal(gaussian_noise, tv_tensors.Video) +def gaussian_noise_video(video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor: + return gaussian_noise_image(video, mean=mean, sigma=sigma, clip=clip) + + +@_register_kernel_internal(gaussian_noise, PIL.Image.Image) +def _gaussian_noise_pil( + video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True +) -> PIL.Image.Image: + raise ValueError("Gaussian Noise is not implemented for PIL images.") + + def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: """See :func:`~torchvision.transforms.v2.ToDtype` for details.""" if torch.jit.is_scripting(): From 2adf89477f2ee6d685da67eb03b8279e778b3cf5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 30 May 2024 14:36:44 +0100 Subject: [PATCH 4/5] Remove functional v1 --- torchvision/transforms/functional.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 01a47843565..124d1da5f4f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1506,25 +1506,6 @@ def equalize(img: Tensor) -> Tensor: return F_t.equalize(img) -def gaussian_noise(img: Tensor, mean: float = 0.0, var: float = 1.0) -> Tensor: - """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default. - - Args: - img (PIL Image or Tensor): Image on which equalize is applied. - mean (float): Mean of the sampled gaussian distribution. Default is 0. - var (float): Variance of the sampled gaussian distribution. Default is 1. - - Returns: - PIL Image or Tensor: An image that was equalized. - """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(gaussian_noise) - if not isinstance(img, torch.Tensor): - F_pil.gaussian_noise(img, mean, var) - - return F_t.gaussian_noise(img, mean, var) - - def elastic_transform( img: Tensor, displacement: Tensor, From d926a8342772895c963563d12de01240d2f79b7f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 31 May 2024 10:36:46 +0100 Subject: [PATCH 5/5] lint + doc --- torchvision/transforms/v2/_misc.py | 4 +++- torchvision/transforms/v2/functional/_misc.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 241ff68bf9d..6d62539ccd7 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -206,10 +206,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class GaussianNoise(Transform): - """Add gaussian noise to the image. + """Add gaussian noise to images or videos. The input tensor is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. + Each image or frame in a batch will be transformed independently i.e. the + noise added to each image will be different. The input tensor is also expected to be of float dtype in ``[0, 1]``. This transform does not support PIL images. diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index d0b8a413aec..84b686d50f9 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -6,7 +6,6 @@ from torch.nn.functional import conv2d, pad as torch_pad from torchvision import tv_tensors -from torchvision.transforms import _functional_pil as _FP from torchvision.transforms._functional_tensor import _max_value from torchvision.transforms.functional import pil_to_tensor, to_pil_image