Skip to content

Commit 7cf2ba5

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Add HEIC decoder (#8597)
Reviewed By: ahmadsharif1 Differential Revision: D62032052 fbshipit-source-id: 519d09e72140007a15e8fb1e87a4336ea2d461c5
1 parent 7176a8b commit 7cf2ba5

File tree

10 files changed

+274
-15
lines changed

10 files changed

+274
-15
lines changed

setup.py

+17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
2020
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
2121
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
22+
USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default!
2223
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
2324
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
2425
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
@@ -50,6 +51,7 @@
5051
print(f"{USE_PNG = }")
5152
print(f"{USE_JPEG = }")
5253
print(f"{USE_WEBP = }")
54+
print(f"{USE_HEIC = }")
5355
print(f"{USE_AVIF = }")
5456
print(f"{USE_NVJPEG = }")
5557
print(f"{NVCC_FLAGS = }")
@@ -334,6 +336,21 @@ def make_image_extension():
334336
else:
335337
warnings.warn("Building torchvision without WEBP support")
336338

339+
if USE_HEIC:
340+
heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h")
341+
if heic_found:
342+
print("Building torchvision with HEIC support")
343+
print(f"{heic_include_dir = }")
344+
print(f"{heic_library_dir = }")
345+
if heic_include_dir is not None and heic_library_dir is not None:
346+
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
347+
include_dirs.append(heic_include_dir)
348+
library_dirs.append(heic_library_dir)
349+
libraries.append("heif")
350+
define_macros += [("HEIC_FOUND", 1)]
351+
else:
352+
warnings.warn("Building torchvision without HEIC support")
353+
337354
if USE_AVIF:
338355
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
339356
if avif_found:
Binary file not shown.

test/test_image.py

+48-14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
1616
from torchvision.io.image import (
1717
_decode_avif,
18+
_decode_heic,
1819
decode_gif,
1920
decode_image,
2021
decode_jpeg,
@@ -928,11 +929,10 @@ def test_decode_avif(decode_fun, scripted):
928929
assert img[None].is_contiguous(memory_format=torch.channels_last)
929930

930931

931-
@pytest.mark.xfail(reason="AVIF support not enabled yet.")
932+
@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.")
932933
# Note: decode_image fails because some of these files have a (valid) signature
933934
# we don't recognize. We should probably use libmagic....
934-
# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
935-
@pytest.mark.parametrize("decode_fun", (_decode_avif,))
935+
@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic))
936936
@pytest.mark.parametrize("scripted", (False, True))
937937
@pytest.mark.parametrize(
938938
"mode, pil_mode",
@@ -942,7 +942,9 @@ def test_decode_avif(decode_fun, scripted):
942942
(ImageReadMode.UNCHANGED, None),
943943
),
944944
)
945-
@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"))
945+
@pytest.mark.parametrize(
946+
"filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name
947+
)
946948
def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename):
947949
if "reversed_dimg_order" in str(filename):
948950
# Pillow properly decodes this one, but we don't (order of parts of the
@@ -960,7 +962,14 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
960962
except RuntimeError as e:
961963
if any(
962964
s in str(e)
963-
for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image")
965+
for s in (
966+
"BMFF parsing failed",
967+
"avifDecoderParse failed: ",
968+
"file contains more than one image",
969+
"no 'ispe' property",
970+
"'iref' has double references",
971+
"Invalid image grid",
972+
)
964973
):
965974
pytest.skip(reason="Expected failure, that's OK")
966975
else:
@@ -970,22 +979,47 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
970979
assert img.shape[0] == 3
971980
if mode == ImageReadMode.RGB_ALPHA:
972981
assert img.shape[0] == 4
982+
973983
if img.dtype == torch.uint16:
974984
img = F.to_dtype(img, dtype=torch.uint8, scale=True)
985+
try:
986+
from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
987+
except RuntimeError as e:
988+
if "Invalid image grid" in str(e):
989+
pytest.skip(reason="PIL failure")
990+
else:
991+
raise e
975992

976-
from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
977-
if False:
993+
if True:
978994
from torchvision.utils import make_grid
979995

980996
g = make_grid([img, from_pil])
981997
F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))
982-
if mode != ImageReadMode.RGB:
983-
# We don't compare against PIL for RGB because results look pretty
984-
# different on RGBA images (other images are fine). The result on
985-
# torchvision basically just plainly ignores the alpha channel, resuting
986-
# in transparent pixels looking dark. PIL seems to be using a sort of
987-
# k-nn thing, looking at the output. Take a look at the resuting images.
988-
torch.testing.assert_close(img, from_pil, rtol=0, atol=3)
998+
999+
is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
1000+
if mode == ImageReadMode.RGB and not is__decode_heic:
1001+
# We don't compare torchvision's AVIF against PIL for RGB because
1002+
# results look pretty different on RGBA images (other images are fine).
1003+
# The result on torchvision basically just plainly ignores the alpha
1004+
# channel, resuting in transparent pixels looking dark. PIL seems to be
1005+
# using a sort of k-nn thing (Take a look at the resuting images)
1006+
return
1007+
if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic:
1008+
return
1009+
1010+
torch.testing.assert_close(img, from_pil, rtol=0, atol=3)
1011+
1012+
1013+
@pytest.mark.xfail(reason="HEIC support not enabled yet.")
1014+
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
1015+
@pytest.mark.parametrize("scripted", (False, True))
1016+
def test_decode_heic(decode_fun, scripted):
1017+
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic")))
1018+
if scripted:
1019+
decode_fun = torch.jit.script(decode_fun)
1020+
img = decode_fun(encoded_bytes)
1021+
assert img.shape == (3, 100, 100)
1022+
assert img[None].is_contiguous(memory_format=torch.channels_last)
9891023

9901024

9911025
if __name__ == "__main__":
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#include "decode_heic.h"
2+
3+
#if HEIC_FOUND
4+
#include "libheif/heif_cxx.h"
5+
#endif // HEIC_FOUND
6+
7+
namespace vision {
8+
namespace image {
9+
10+
#if !HEIC_FOUND
11+
torch::Tensor decode_heic(
12+
const torch::Tensor& encoded_data,
13+
ImageReadMode mode) {
14+
TORCH_CHECK(
15+
false, "decode_heic: torchvision not compiled with libheif support");
16+
}
17+
#else
18+
19+
torch::Tensor decode_heic(
20+
const torch::Tensor& encoded_data,
21+
ImageReadMode mode) {
22+
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
23+
TORCH_CHECK(
24+
encoded_data.dtype() == torch::kU8,
25+
"Input tensor must have uint8 data type, got ",
26+
encoded_data.dtype());
27+
TORCH_CHECK(
28+
encoded_data.dim() == 1,
29+
"Input tensor must be 1-dimensional, got ",
30+
encoded_data.dim(),
31+
" dims.");
32+
33+
if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
34+
mode != IMAGE_READ_MODE_RGB_ALPHA) {
35+
// Other modes aren't supported, but we don't error or even warn because we
36+
// have generic entry points like decode_image which may support all modes,
37+
// it just depends on the underlying decoder.
38+
mode = IMAGE_READ_MODE_UNCHANGED;
39+
}
40+
41+
// If return_rgb is false it means we return rgba - nothing else.
42+
auto return_rgb = true;
43+
44+
int height = 0;
45+
int width = 0;
46+
int num_channels = 0;
47+
int stride = 0;
48+
uint8_t* decoded_data = nullptr;
49+
heif::Image img;
50+
int bit_depth = 0;
51+
52+
try {
53+
heif::Context ctx;
54+
ctx.read_from_memory_without_copy(
55+
encoded_data.data_ptr<uint8_t>(), encoded_data.numel());
56+
57+
// TODO properly error on (or support) image sequences. Right now, I think
58+
// this function will always return the first image in a sequence, which is
59+
// inconsistent with decode_gif (which returns a batch) and with decode_avif
60+
// (which errors loudly).
61+
// Why? I'm struggling to make sense of
62+
// ctx.get_number_of_top_level_images(). It disagrees with libavif's
63+
// imageCount. For example on some of the libavif test images:
64+
//
65+
// - colors-animated-12bpc-keyframes-0-2-3.avif
66+
// avif num images = 5
67+
// heif num images = 1 // Why is this 1 when clearly this is supposed to
68+
// be a sequence?
69+
// - sofa_grid1x5_420.avif
70+
// avif num images = 1
71+
// heif num images = 6 // If we were to error here we won't be able to
72+
// decode this image which is otherwise properly
73+
// decoded by libavif.
74+
// I can't find a libheif function that does what we need here, or at least
75+
// that agrees with libavif.
76+
77+
// TORCH_CHECK(
78+
// ctx.get_number_of_top_level_images() == 1,
79+
// "heic file contains more than one image");
80+
81+
heif::ImageHandle handle = ctx.get_primary_image_handle();
82+
bit_depth = handle.get_luma_bits_per_pixel();
83+
84+
return_rgb =
85+
(mode == IMAGE_READ_MODE_RGB ||
86+
(mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel()));
87+
88+
height = handle.get_height();
89+
width = handle.get_width();
90+
91+
num_channels = return_rgb ? 3 : 4;
92+
heif_chroma chroma;
93+
if (bit_depth == 8) {
94+
chroma = return_rgb ? heif_chroma_interleaved_RGB
95+
: heif_chroma_interleaved_RGBA;
96+
} else {
97+
// TODO: This, along with our 10bits -> 16bits range mapping down below,
98+
// may not work on BE platforms
99+
chroma = return_rgb ? heif_chroma_interleaved_RRGGBB_LE
100+
: heif_chroma_interleaved_RRGGBBAA_LE;
101+
}
102+
103+
img = handle.decode_image(heif_colorspace_RGB, chroma);
104+
105+
decoded_data = img.get_plane(heif_channel_interleaved, &stride);
106+
} catch (const heif::Error& err) {
107+
// We need this try/catch block and call TORCH_CHECK, because libheif may
108+
// otherwise throw heif::Error that would just be reported as "An unknown
109+
// exception occurred" when we move back to Python.
110+
TORCH_CHECK(false, "decode_heif failed: ", err.get_message());
111+
}
112+
TORCH_CHECK(decoded_data != nullptr, "Something went wrong during decoding.");
113+
114+
auto dtype = (bit_depth == 8) ? torch::kUInt8 : at::kUInt16;
115+
auto out = torch::empty({height, width, num_channels}, dtype);
116+
uint8_t* out_ptr = (uint8_t*)out.data_ptr();
117+
118+
// decoded_data is *almost* the raw decoded data, but not quite: for some
119+
// images, there may be some padding at the end of each row, i.e. when stride
120+
// != row_size_in_bytes. So we can't copy decoded_data into the tensor's
121+
// memory directly, we have to copy row by row. Oh, and if you think you can
122+
// take a shortcut when stride == row_size_in_bytes and just do:
123+
// out = torch::from_blob(decoded_data, ...)
124+
// you can't, because decoded_data is owned by the heif::Image object and it
125+
// gets freed when it gets out of scope!
126+
auto row_size_in_bytes = width * num_channels * ((bit_depth == 8) ? 1 : 2);
127+
for (auto h = 0; h < height; h++) {
128+
memcpy(
129+
out_ptr + h * row_size_in_bytes,
130+
decoded_data + h * stride,
131+
row_size_in_bytes);
132+
}
133+
if (bit_depth > 8) {
134+
// Say bit depth is 10. decodec_data and out_ptr contain 10bits values
135+
// over 2 bytes, stored into uint16_t. In torchvision a uint16 value is
136+
// expected to be in [0, 2**16), so we have to map the 10bits value to that
137+
// range. Note that other libraries like libavif do that mapping
138+
// automatically.
139+
// TODO: It's possible to avoid the memcpy call above in this case, and do
140+
// the copy at the same time as the conversation. Whether it's worth it
141+
// should be benchmarked.
142+
auto out_ptr_16 = (uint16_t*)out_ptr;
143+
for (auto p = 0; p < height * width * num_channels; p++) {
144+
out_ptr_16[p] <<= (16 - bit_depth);
145+
}
146+
}
147+
return out.permute({2, 0, 1});
148+
}
149+
#endif // HEIC_FOUND
150+
151+
} // namespace image
152+
} // namespace vision
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once
2+
3+
#include <torch/types.h>
4+
#include "../image_read_mode.h"
5+
6+
namespace vision {
7+
namespace image {
8+
9+
C10_EXPORT torch::Tensor decode_heic(
10+
const torch::Tensor& data,
11+
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
12+
13+
} // namespace image
14+
} // namespace vision

torchvision/csrc/io/image/cpu/decode_image.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "decode_avif.h"
44
#include "decode_gif.h"
5+
#include "decode_heic.h"
56
#include "decode_jpeg.h"
67
#include "decode_png.h"
78
#include "decode_webp.h"
@@ -61,6 +62,17 @@ torch::Tensor decode_image(
6162
return decode_avif(data, mode);
6263
}
6364

65+
// Similarly for heic we assume the signature is "ftypeheic" but some files
66+
// may come as "ftypmif1" where the "heic" part is defined later in the file.
67+
// We can't be re-inventing libmagic here. We might need to start relying on
68+
// it though...
69+
const uint8_t heic_signature[8] = {
70+
0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63}; // == "ftypheic"
71+
TORCH_CHECK(data.numel() >= 12, err_msg);
72+
if ((memcmp(heic_signature, datap + 4, 8) == 0)) {
73+
return decode_heic(data, mode);
74+
}
75+
6476
const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
6577
const uint8_t webp_signature_end[7] = {
6678
0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8"

torchvision/csrc/io/image/image.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ static auto registry =
2323
&decode_jpeg)
2424
.op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
2525
&decode_webp)
26+
.op("image::decode_heic(Tensor encoded_data, int mode) -> Tensor",
27+
&decode_heic)
2628
.op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor",
2729
&decode_avif)
2830
.op("image::encode_jpeg", &encode_jpeg)

torchvision/csrc/io/image/image.h

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "cpu/decode_avif.h"
44
#include "cpu/decode_gif.h"
5+
#include "cpu/decode_heic.h"
56
#include "cpu/decode_image.h"
67
#include "cpu/decode_jpeg.h"
78
#include "cpu/decode_png.h"

torchvision/io/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"decode_image",
6262
"decode_jpeg",
6363
"decode_png",
64+
"decode_heic",
6465
"decode_webp",
6566
"decode_gif",
6667
"encode_jpeg",

torchvision/io/image.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -417,5 +417,31 @@ def _decode_avif(
417417
Decoded image (Tensor[image_channels, image_height, image_width])
418418
"""
419419
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
420-
_log_api_usage_once(decode_webp)
420+
_log_api_usage_once(_decode_avif)
421421
return torch.ops.image.decode_avif(input, mode.value)
422+
423+
424+
def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
425+
"""
426+
Decode an HEIC image into a 3 dimensional RGB[A] Tensor.
427+
428+
The values of the output tensor are in uint8 in [0, 255] for most images. If
429+
the image has a bit-depth of more than 8, then the output tensor is uint16
430+
in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
431+
calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
432+
``scale=True`` after this function to convert the decoded image into a uint8
433+
or float tensor.
434+
435+
Args:
436+
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
437+
the raw bytes of the HEIC image.
438+
mode (ImageReadMode): The read mode used for optionally
439+
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
440+
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
441+
442+
Returns:
443+
Decoded image (Tensor[image_channels, image_height, image_width])
444+
"""
445+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
446+
_log_api_usage_once(_decode_heic)
447+
return torch.ops.image.decode_heic(input, mode.value)

0 commit comments

Comments
 (0)