
Commit f15d1d3

NicolasHug authored and facebook-github-bot committed
[fbsync] Cleanup/refactor of decoders and related tests (#8617)
Reviewed By: ahmadsharif1
Differential Revision: D62032043
fbshipit-source-id: c09f44f1855852095a326d4270c817a0ecc1bd3d
1 parent 106fc31 commit f15d1d3

18 files changed: +141 −133 lines

test/test_image.py

+61 −39
@@ -42,6 +42,19 @@
 IS_WINDOWS = sys.platform in ("win32", "cygwin")
 IS_MACOS = sys.platform == "darwin"
 PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
+WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")
+
+# Hacky way of figuring out whether we compiled with libavif/libheif (those are
+# currently disabled by default)
+try:
+    _decode_avif(torch.arange(10, dtype=torch.uint8))
+except Exception as e:
+    DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e)
+
+try:
+    _decode_heic(torch.arange(10, dtype=torch.uint8))
+except Exception as e:
+    DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e)
 
 
 def _get_safe_image_name(name):
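The probe above keys off the exact error message raised by a build without libavif/libheif, and treats any other failure (such as a parse error on the garbage input) as proof that the decoder is compiled in. A minimal sketch of the same pattern factored into a helper; the helper name is made up here and is not part of the commit:

import torch

def _decoder_compiled_in(decode_fn, missing_lib_marker):
    # Feed the decoder bytes that are not a valid image. A build without the
    # optional library fails with `missing_lib_marker` in the message; any other
    # error (or, improbably, success) means the decoder is available.
    try:
        decode_fn(torch.arange(10, dtype=torch.uint8))
    except Exception as e:
        return missing_lib_marker not in str(e)
    return True

# e.g. DECODE_AVIF_ENABLED = _decoder_compiled_in(
#     _decode_avif, "torchvision not compiled with libavif support")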
@@ -149,17 +162,6 @@ def test_invalid_exif(tmpdir, size):
     torch.testing.assert_close(expected, output)
 
 
-def test_decode_jpeg_errors():
-    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
-        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
-
-    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
-        decode_jpeg(torch.empty((100,), dtype=torch.float16))
-
-    with pytest.raises(RuntimeError, match="Not a JPEG file"):
-        decode_jpeg(torch.empty((100), dtype=torch.uint8))
-
-
 def test_decode_bad_huffman_images():
     # sanity check: make sure we can decode the bad Huffman encoding
     bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
@@ -235,10 +237,6 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
 
 
 def test_decode_png_errors():
-    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
-        decode_png(torch.empty((), dtype=torch.uint8))
-    with pytest.raises(RuntimeError, match="Content is not png"):
-        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
     with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
         decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
     with pytest.raises(RuntimeError, match="Content is too small for png"):
@@ -864,20 +862,28 @@ def test_decode_gif(tmpdir, name, scripted):
         torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)
 
 
-@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp))
-def test_decode_gif_webp_errors(decode_fun):
+decode_fun_and_match = [
+    (decode_png, "Content is not png"),
+    (decode_jpeg, "Not a JPEG file"),
+    (decode_gif, re.escape("DGifOpenFileName() failed - 103")),
+    (decode_webp, "WebPGetFeatures failed."),
+]
+if DECODE_AVIF_ENABLED:
+    decode_fun_and_match.append((_decode_avif, "BMFF parsing failed"))
+if DECODE_HEIC_ENABLED:
+    decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box"))
+
+
+@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match)
+def test_decode_bad_encoded_data(decode_fun, match):
     encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
     with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"):
         decode_fun(encoded_data[None])
     with pytest.raises(RuntimeError, match="Input tensor must have uint8 data type"):
         decode_fun(encoded_data.float())
     with pytest.raises(RuntimeError, match="Input tensor must be contiguous"):
         decode_fun(encoded_data[::2])
-    if decode_fun is decode_gif:
-        expected_match = re.escape("DGifOpenFileName() failed - 103")
-    elif decode_fun is decode_webp:
-        expected_match = "WebPGetFeatures failed."
-    with pytest.raises(RuntimeError, match=expected_match):
+    with pytest.raises(RuntimeError, match=match):
         decode_fun(encoded_data)
 
 
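The consolidated test relies on every decoder reporting malformed inputs with the same messages, which come from the shared validate_encoded_data() check added in common.cpp further down. A small sketch of the user-visible behavior this asserts, using decode_png as the example:

import torch
from torchvision.io import decode_png

data = torch.randint(0, 256, (100,), dtype=torch.uint8)
for bad_input in (data[None], data.float(), data[::2], data):
    try:
        decode_png(bad_input)
    except RuntimeError as e:
        # Raises, in order: "Input tensor must be 1-dimensional ...",
        # "Input tensor must have uint8 data type ...",
        # "Input tensor must be contiguous ...", and "Content is not png".
        print(e)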

@@ -890,21 +896,27 @@ def test_decode_webp(decode_fun, scripted):
     img = decode_fun(encoded_bytes)
     assert img.shape == (3, 100, 100)
     assert img[None].is_contiguous(memory_format=torch.channels_last)
+    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
 
 
-# This test is skipped because it requires webp images that we're not including
-# within the repo. The test images were downloaded from the different pages of
-# https://developers.google.com/speed/webp/gallery
-# Note that converting an RGBA image to RGB leads to bad results because the
-# transparent pixels aren't necessarily set to "black" or "white", they can be
-# random stuff. This is consistent with PIL results.
-@pytest.mark.skip(reason="Need to download test images first")
+# This test is skipped by default because it requires webp images that we're not
+# including within the repo. The test images were downloaded manually from the
+# different pages of https://developers.google.com/speed/webp/gallery
+@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set")
 @pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
 @pytest.mark.parametrize("scripted", (False, True))
 @pytest.mark.parametrize(
-    "mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None))
+    "mode, pil_mode",
+    (
+        # Note that converting an RGBA image to RGB leads to bad results because the
+        # transparent pixels aren't necessarily set to "black" or "white", they can be
+        # random stuff. This is consistent with PIL results.
+        (ImageReadMode.RGB, "RGB"),
+        (ImageReadMode.RGB_ALPHA, "RGBA"),
+        (ImageReadMode.UNCHANGED, None),
+    ),
 )
-@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp"))
+@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name)
 def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
     encoded_bytes = read_file(filename)
     if scripted:
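The `img += 123` writes added in this and the following hunks are a cheap liveness check: the decoded tensor must own (or keep alive) its buffer, so writing to it after the underlying decoder has been torn down has to be safe. A tiny helper expressing the same idea (illustrative, not part of the commit):

import torch

def check_decoded_buffer_alive(img: torch.Tensor) -> torch.Tensor:
    # If the decoder had returned a view into memory freed by the C library,
    # this in-place write would touch released memory (and show up under tools
    # like ASAN or valgrind) instead of being a harmless mutation.
    img += 123
    return img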
@@ -915,9 +927,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename)
     pil_img = Image.open(filename).convert(pil_mode)
     from_pil = F.pil_to_tensor(pil_img)
     assert_equal(img, from_pil)
+    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
 
 
-@pytest.mark.xfail(reason="AVIF support not enabled yet.")
+@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.")
 @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
 @pytest.mark.parametrize("scripted", (False, True))
 def test_decode_avif(decode_fun, scripted):
@@ -927,12 +940,20 @@ def test_decode_avif(decode_fun, scripted):
     img = decode_fun(encoded_bytes)
     assert img.shape == (3, 100, 100)
     assert img[None].is_contiguous(memory_format=torch.channels_last)
+    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
 
 
-@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.")
 # Note: decode_image fails because some of these files have a (valid) signature
 # we don't recognize. We should probably use libmagic....
-@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic))
+decode_funs = []
+if DECODE_AVIF_ENABLED:
+    decode_funs.append(_decode_avif)
+if DECODE_HEIC_ENABLED:
+    decode_funs.append(_decode_heic)
+
+
+@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.")
+@pytest.mark.parametrize("decode_fun", decode_funs)
 @pytest.mark.parametrize("scripted", (False, True))
 @pytest.mark.parametrize(
     "mode, pil_mode",
@@ -945,7 +966,7 @@ def test_decode_avif(decode_fun, scripted):
 @pytest.mark.parametrize(
     "filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name
 )
-def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename):
+def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename):
     if "reversed_dimg_order" in str(filename):
         # Pillow properly decodes this one, but we don't (order of parts of the
         # image is wrong). This is due to a bug that was recently fixed in
@@ -996,21 +1017,21 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
         g = make_grid([img, from_pil])
         F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))
 
-    is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
-    if mode == ImageReadMode.RGB and not is__decode_heic:
+    is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
+    if mode == ImageReadMode.RGB and not is_decode_heic:
         # We don't compare torchvision's AVIF against PIL for RGB because
         # results look pretty different on RGBA images (other images are fine).
         # The result on torchvision basically just plainly ignores the alpha
         # channel, resulting in transparent pixels looking dark. PIL seems to be
         # using a sort of k-nn thing (Take a look at the resulting images)
         return
-    if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic:
+    if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic:
         return
 
     torch.testing.assert_close(img, from_pil, rtol=0, atol=3)
 
 
-@pytest.mark.xfail(reason="HEIC support not enabled yet.")
+@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.")
 @pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
 @pytest.mark.parametrize("scripted", (False, True))
 def test_decode_heic(decode_fun, scripted):
@@ -1020,6 +1041,7 @@ def test_decode_heic(decode_fun, scripted):
     img = decode_fun(encoded_bytes)
    assert img.shape == (3, 100, 100)
     assert img[None].is_contiguous(memory_format=torch.channels_last)
+    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
 
 
 if __name__ == "__main__":

torchvision/csrc/io/image/common.cpp

+43
@@ -0,0 +1,43 @@
+#include "common.h"
+
+#include <torch/torch.h>
+
+namespace vision {
+namespace image {
+
+void validate_encoded_data(const torch::Tensor& encoded_data) {
+  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
+  TORCH_CHECK(
+      encoded_data.dtype() == torch::kU8,
+      "Input tensor must have uint8 data type, got ",
+      encoded_data.dtype());
+  TORCH_CHECK(
+      encoded_data.dim() == 1 && encoded_data.numel() > 0,
+      "Input tensor must be 1-dimensional and non-empty, got ",
+      encoded_data.dim(),
+      " dims and ",
+      encoded_data.numel(),
+      " numels.");
+}
+
+bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
+    ImageReadMode mode,
+    bool has_alpha) {
+  // Return true if the calling decoding function should return a 3D RGB tensor,
+  // and false if it should return a 4D RGBA tensor.
+  // This function ignores the requested "grayscale" modes and treats them as
+  // "unchanged", so it should only be used by decoders that don't support
+  // grayscale outputs.
+
+  if (mode == IMAGE_READ_MODE_RGB) {
+    return true;
+  }
+  if (mode == IMAGE_READ_MODE_RGB_ALPHA) {
+    return false;
+  }
+  // From here we assume mode is "unchanged", even for grayscale ones.
+  return !has_alpha;
+}
+
+} // namespace image
+} // namespace vision
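For reference, the decision rule implemented by the second helper, rendered in Python (illustrative only; the C++ above is the actual implementation):

from torchvision.io import ImageReadMode

def returns_rgb(mode: ImageReadMode, has_alpha: bool) -> bool:
    # True: the decoder should produce a 3-channel RGB tensor.
    # False: it should produce a 4-channel RGBA tensor.
    if mode == ImageReadMode.RGB:
        return True
    if mode == ImageReadMode.RGB_ALPHA:
        return False
    # Anything else (including grayscale requests) is treated as UNCHANGED:
    # keep the alpha channel only if the file actually has one.
    return not has_alpha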

torchvision/csrc/io/image/image_read_mode.h → torchvision/csrc/io/image/common.h (renamed)

+7
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <stdint.h>
+#include <torch/torch.h>
 
 namespace vision {
 namespace image {
@@ -13,5 +14,11 @@ const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
 const ImageReadMode IMAGE_READ_MODE_RGB = 3;
 const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;
 
+void validate_encoded_data(const torch::Tensor& encoded_data);
+
+bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
+    ImageReadMode mode,
+    bool has_alpha);
+
 } // namespace image
 } // namespace vision

torchvision/csrc/io/image/cpu/decode_avif.cpp

+5 −21
@@ -1,4 +1,5 @@
 #include "decode_avif.h"
+#include "../common.h"
 
 #if AVIF_FOUND
 #include "avif/avif.h"
@@ -33,16 +34,7 @@ torch::Tensor decode_avif(
   // Refer there for more detail about what each function does, and which
   // structure/data is available after which call.
 
-  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
-  TORCH_CHECK(
-      encoded_data.dtype() == torch::kU8,
-      "Input tensor must have uint8 data type, got ",
-      encoded_data.dtype());
-  TORCH_CHECK(
-      encoded_data.dim() == 1,
-      "Input tensor must be 1-dimensional, got ",
-      encoded_data.dim(),
-      " dims.");
+  validate_encoded_data(encoded_data);
 
   DecoderPtr decoder(avifDecoderCreate());
   TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder.");
@@ -60,6 +52,7 @@
       result == AVIF_RESULT_OK,
       "avifDecoderParse failed: ",
       avifResultToString(result));
+  printf("avif num images = %d\n", decoder->imageCount);
   TORCH_CHECK(
       decoder->imageCount == 1, "Avif file contains more than one image");
 
@@ -78,18 +71,9 @@
   auto use_uint8 = (decoder->image->depth <= 8);
   rgb.depth = use_uint8 ? 8 : 16;
 
-  if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
-      mode != IMAGE_READ_MODE_RGB_ALPHA) {
-    // Other modes aren't supported, but we don't error or even warn because we
-    // have generic entry points like decode_image which may support all modes,
-    // it just depends on the underlying decoder.
-    mode = IMAGE_READ_MODE_UNCHANGED;
-  }
-
-  // If return_rgb is false it means we return rgba - nothing else.
   auto return_rgb =
-      (mode == IMAGE_READ_MODE_RGB ||
-       (mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent));
+      should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
+          mode, decoder->alphaPresent);
 
   auto num_channels = return_rgb ? 3 : 4;
   rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA;
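With the helper in place, the AVIF decoder's output layout depends only on the requested mode and on whether the file carries alpha. A hedged usage sketch from the Python side; it assumes a build with libavif, that the private _decode_avif wrapper is importable from torchvision.io.image as in the tests, and that it forwards a mode argument like its C++ counterpart. sample.avif is a placeholder:

from torchvision.io import ImageReadMode, read_file
from torchvision.io.image import _decode_avif  # private entry point, also used by the tests

data = read_file("sample.avif")                           # placeholder path
rgb = _decode_avif(data, mode=ImageReadMode.RGB)          # always 3 channels
rgba = _decode_avif(data, mode=ImageReadMode.RGB_ALPHA)   # always 4 channels
auto = _decode_avif(data)                                 # 3 or 4 channels, depending on the file's alpha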

torchvision/csrc/io/image/cpu/decode_avif.h

+1 −1
@@ -1,7 +1,7 @@
 #pragma once
 
 #include <torch/types.h>
-#include "../image_read_mode.h"
+#include "../common.h"
 
 namespace vision {
 namespace image {

torchvision/csrc/io/image/cpu/decode_gif.cpp

+2 −10
@@ -1,5 +1,6 @@
 #include "decode_gif.h"
 #include <cstring>
+#include "../common.h"
 #include "giflib/gif_lib.h"
 
 namespace vision {
@@ -34,16 +35,7 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) {
   // Refer over there for more details on the libgif API, API ref, and a
   // detailed description of the GIF format.
 
-  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
-  TORCH_CHECK(
-      encoded_data.dtype() == torch::kU8,
-      "Input tensor must have uint8 data type, got ",
-      encoded_data.dtype());
-  TORCH_CHECK(
-      encoded_data.dim() == 1,
-      "Input tensor must be 1-dimensional, got ",
-      encoded_data.dim(),
-      " dims.");
+  validate_encoded_data(encoded_data);
 
   int error = D_GIF_SUCCEEDED;
 
torchvision/csrc/io/image/cpu/decode_heic.cpp

+4 −21
@@ -1,4 +1,5 @@
 #include "decode_heic.h"
+#include "../common.h"
 
 #if HEIC_FOUND
 #include "libheif/heif_cxx.h"
@@ -19,26 +20,8 @@ torch::Tensor decode_heic(
 torch::Tensor decode_heic(
     const torch::Tensor& encoded_data,
     ImageReadMode mode) {
-  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
-  TORCH_CHECK(
-      encoded_data.dtype() == torch::kU8,
-      "Input tensor must have uint8 data type, got ",
-      encoded_data.dtype());
-  TORCH_CHECK(
-      encoded_data.dim() == 1,
-      "Input tensor must be 1-dimensional, got ",
-      encoded_data.dim(),
-      " dims.");
-
-  if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
-      mode != IMAGE_READ_MODE_RGB_ALPHA) {
-    // Other modes aren't supported, but we don't error or even warn because we
-    // have generic entry points like decode_image which may support all modes,
-    // it just depends on the underlying decoder.
-    mode = IMAGE_READ_MODE_UNCHANGED;
-  }
+  validate_encoded_data(encoded_data);
 
-  // If return_rgb is false it means we return rgba - nothing else.
   auto return_rgb = true;
 
   int height = 0;
@@ -82,8 +65,8 @@
   bit_depth = handle.get_luma_bits_per_pixel();
 
   return_rgb =
-      (mode == IMAGE_READ_MODE_RGB ||
-       (mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel()));
+      should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
+          mode, handle.has_alpha_channel());
 
   height = handle.get_height();
   width = handle.get_width();
