Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transparency support to webp decoder #8610

Merged
merged 4 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def test_decode_gif_webp_errors(decode_fun):
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
elif decode_fun is decode_webp:
expected_match = "WebPDecodeRGB failed."
expected_match = "WebPGetFeatures failed."
with pytest.raises(RuntimeError, match=expected_match):
decode_fun(encoded_data)

Expand All @@ -891,6 +891,31 @@ def test_decode_webp(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


# Skipped: relies on webp sample images that are not checked into the repo.
# The samples come from the various pages of
# https://developers.google.com/speed/webp/gallery
# Keep in mind that an RGBA -> RGB conversion yields meaningless values for
# the transparent pixels: those aren't guaranteed to be "black" or "white",
# they may hold arbitrary data. PIL behaves the same way, so results match.
@pytest.mark.skip(reason="Need to download test images first")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
    "mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None))
)
@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp"))
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
    raw_bytes = read_file(filename)
    fn = torch.jit.script(decode_fun) if scripted else decode_fun
    decoded = fn(raw_bytes, mode=mode)
    assert decoded[None].is_contiguous(memory_format=torch.channels_last)

    # PIL serves as the reference decoder for every requested mode.
    reference = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
    assert_equal(decoded, reference)


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ torch::Tensor decode_image(
TORCH_CHECK(data.numel() >= 15, err_msg);
if ((memcmp(webp_signature_begin, datap, 4) == 0) &&
(memcmp(webp_signature_end, datap + 8, 7) == 0)) {
return decode_webp(data);
return decode_webp(data, mode);
}

TORCH_CHECK(false, err_msg);
Expand Down
46 changes: 39 additions & 7 deletions torchvision/csrc/io/image/cpu/decode_webp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ namespace vision {
namespace image {

#if !WEBP_FOUND
// Stub used when torchvision is built without libwebp support: keeps the
// symbol (and its mode parameter) available but always raises at runtime.
// NOTE: the unified-diff rendering had left the superseded single-argument
// signature interleaved here; only the two-argument version is kept.
torch::Tensor decode_webp(
    const torch::Tensor& encoded_data,
    ImageReadMode mode) {
  TORCH_CHECK(
      false, "decode_webp: torchvision not compiled with libwebp support");
}
#else

torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
torch::Tensor decode_webp(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
Expand All @@ -26,13 +30,41 @@ torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
encoded_data.dim(),
" dims.");

auto encoded_data_p = encoded_data.data_ptr<uint8_t>();
auto encoded_data_size = encoded_data.numel();

WebPBitstreamFeatures features;
auto res = WebPGetFeatures(encoded_data_p, encoded_data_size, &features);
TORCH_CHECK(
res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res);
TORCH_CHECK(
!features.has_animation, "Animated webp files are not supported.");

auto decoding_func = WebPDecodeRGB;
int num_channels = 0;
if (mode == IMAGE_READ_MODE_RGB) {
decoding_func = WebPDecodeRGB;
num_channels = 3;
} else if (mode == IMAGE_READ_MODE_RGB_ALPHA) {
decoding_func = WebPDecodeRGBA;
num_channels = 4;
} else {
// Assume mode is "unchanged"
decoding_func = features.has_alpha ? WebPDecodeRGBA : WebPDecodeRGB;
num_channels = features.has_alpha ? 4 : 3;
}

int width = 0;
int height = 0;
auto decoded_data = WebPDecodeRGB(
encoded_data.data_ptr<uint8_t>(), encoded_data.numel(), &width, &height);
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed.");
auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8);
return out.permute({2, 0, 1}); // return CHW, channels-last

auto decoded_data =
decoding_func(encoded_data_p, encoded_data_size, &width, &height);
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed.");

auto out = torch::from_blob(
decoded_data, {height, width, num_channels}, torch::kUInt8);

return out.permute({2, 0, 1});
}
#endif // WEBP_FOUND

Expand Down
5 changes: 4 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_webp.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

// Decode WEBP-encoded bytes held in a 1-D uint8 tensor. `mode` selects the
// output color space (RGB, RGB_ALPHA, or unchanged — see decode_webp.cpp);
// it defaults to IMAGE_READ_MODE_UNCHANGED so existing one-argument callers
// keep working.
// NOTE: the unified-diff rendering had left the superseded single-argument
// declaration interleaved here; only the two-argument version is kept.
C10_EXPORT torch::Tensor decode_webp(
    const torch::Tensor& encoded_data,
    ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ static auto registry =
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_jpeg)
.op("image::decode_webp", &decode_webp)
.op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
&decode_webp)
.op("image::decode_avif", &decode_avif)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
Expand Down
16 changes: 12 additions & 4 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class ImageReadMode(Enum):
``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
RGB with transparency.

.. note::

Some decoders won't support all possible values, e.g. a decoder may only
support "RGB" and "RGBA" mode.
"""

UNCHANGED = 0
Expand Down Expand Up @@ -365,23 +370,26 @@ def decode_gif(input: torch.Tensor) -> torch.Tensor:

def decode_webp(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode a WEBP image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the WEBP image.
        mode (ImageReadMode): The read mode used for optionally
            converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
            Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_webp)
    # Pass the enum's integer value: the C++ op is registered with schema
    # "image::decode_webp(Tensor encoded_data, int mode) -> Tensor".
    return torch.ops.image.decode_webp(input, mode.value)


def _decode_avif(
Expand Down
Loading