42
42
IS_WINDOWS = sys .platform in ("win32" , "cygwin" )
43
43
IS_MACOS = sys .platform == "darwin"
44
44
PILLOW_VERSION = tuple (int (x ) for x in PILLOW_VERSION .split ("." ))
45
+ WEBP_TEST_IMAGES_DIR = os .environ .get ("WEBP_TEST_IMAGES_DIR" , "" )
46
+
47
+ # Hacky way of figuring out whether we compiled with libavif/libheif (those are
48
+ # currenlty disabled by default)
49
+ try :
50
+ _decode_avif (torch .arange (10 , dtype = torch .uint8 ))
51
+ except Exception as e :
52
+ DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str (e )
53
+
54
+ try :
55
+ _decode_heic (torch .arange (10 , dtype = torch .uint8 ))
56
+ except Exception as e :
57
+ DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str (e )
45
58
46
59
47
60
def _get_safe_image_name (name ):
@@ -149,17 +162,6 @@ def test_invalid_exif(tmpdir, size):
149
162
torch .testing .assert_close (expected , output )
150
163
151
164
152
- def test_decode_jpeg_errors ():
153
- with pytest .raises (RuntimeError , match = "Expected a non empty 1-dimensional tensor" ):
154
- decode_jpeg (torch .empty ((100 , 1 ), dtype = torch .uint8 ))
155
-
156
- with pytest .raises (RuntimeError , match = "Expected a torch.uint8 tensor" ):
157
- decode_jpeg (torch .empty ((100 ,), dtype = torch .float16 ))
158
-
159
- with pytest .raises (RuntimeError , match = "Not a JPEG file" ):
160
- decode_jpeg (torch .empty ((100 ), dtype = torch .uint8 ))
161
-
162
-
163
165
def test_decode_bad_huffman_images ():
164
166
# sanity check: make sure we can decode the bad Huffman encoding
165
167
bad_huff = read_file (os .path .join (DAMAGED_JPEG , "bad_huffman.jpg" ))
@@ -235,10 +237,6 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
235
237
236
238
237
239
def test_decode_png_errors ():
238
- with pytest .raises (RuntimeError , match = "Expected a non empty 1-dimensional tensor" ):
239
- decode_png (torch .empty ((), dtype = torch .uint8 ))
240
- with pytest .raises (RuntimeError , match = "Content is not png" ):
241
- decode_png (torch .randint (3 , 5 , (300 ,), dtype = torch .uint8 ))
242
240
with pytest .raises (RuntimeError , match = "Out of bound read in decode_png" ):
243
241
decode_png (read_file (os .path .join (DAMAGED_PNG , "sigsegv.png" )))
244
242
with pytest .raises (RuntimeError , match = "Content is too small for png" ):
@@ -864,20 +862,28 @@ def test_decode_gif(tmpdir, name, scripted):
864
862
torch .testing .assert_close (tv_frame , pil_frame , atol = 0 , rtol = 0 )
865
863
866
864
867
- @pytest .mark .parametrize ("decode_fun" , (decode_gif , decode_webp ))
868
- def test_decode_gif_webp_errors (decode_fun ):
865
+ decode_fun_and_match = [
866
+ (decode_png , "Content is not png" ),
867
+ (decode_jpeg , "Not a JPEG file" ),
868
+ (decode_gif , re .escape ("DGifOpenFileName() failed - 103" )),
869
+ (decode_webp , "WebPGetFeatures failed." ),
870
+ ]
871
+ if DECODE_AVIF_ENABLED :
872
+ decode_fun_and_match .append ((_decode_avif , "BMFF parsing failed" ))
873
+ if DECODE_HEIC_ENABLED :
874
+ decode_fun_and_match .append ((_decode_heic , "Invalid input: No 'ftyp' box" ))
875
+
876
+
877
+ @pytest .mark .parametrize ("decode_fun, match" , decode_fun_and_match )
878
+ def test_decode_bad_encoded_data (decode_fun , match ):
869
879
encoded_data = torch .randint (0 , 256 , (100 ,), dtype = torch .uint8 )
870
880
with pytest .raises (RuntimeError , match = "Input tensor must be 1-dimensional" ):
871
881
decode_fun (encoded_data [None ])
872
882
with pytest .raises (RuntimeError , match = "Input tensor must have uint8 data type" ):
873
883
decode_fun (encoded_data .float ())
874
884
with pytest .raises (RuntimeError , match = "Input tensor must be contiguous" ):
875
885
decode_fun (encoded_data [::2 ])
876
- if decode_fun is decode_gif :
877
- expected_match = re .escape ("DGifOpenFileName() failed - 103" )
878
- elif decode_fun is decode_webp :
879
- expected_match = "WebPGetFeatures failed."
880
- with pytest .raises (RuntimeError , match = expected_match ):
886
+ with pytest .raises (RuntimeError , match = match ):
881
887
decode_fun (encoded_data )
882
888
883
889
@@ -890,21 +896,27 @@ def test_decode_webp(decode_fun, scripted):
890
896
img = decode_fun (encoded_bytes )
891
897
assert img .shape == (3 , 100 , 100 )
892
898
assert img [None ].is_contiguous (memory_format = torch .channels_last )
899
+ img += 123 # make sure image buffer wasn't freed by underlying decoding lib
893
900
894
901
895
- # This test is skipped because it requires webp images that we're not including
896
- # within the repo. The test images were downloaded from the different pages of
897
- # https://developers.google.com/speed/webp/gallery
898
- # Note that converting an RGBA image to RGB leads to bad results because the
899
- # transparent pixels aren't necessarily set to "black" or "white", they can be
900
- # random stuff. This is consistent with PIL results.
901
- @pytest .mark .skip (reason = "Need to download test images first" )
902
+ # This test is skipped by default because it requires webp images that we're not
903
+ # including within the repo. The test images were downloaded manually from the
904
+ # different pages of https://developers.google.com/speed/webp/gallery
905
+ @pytest .mark .skipif (not WEBP_TEST_IMAGES_DIR , reason = "WEBP_TEST_IMAGES_DIR is not set" )
902
906
@pytest .mark .parametrize ("decode_fun" , (decode_webp , decode_image ))
903
907
@pytest .mark .parametrize ("scripted" , (False , True ))
904
908
@pytest .mark .parametrize (
905
- "mode, pil_mode" , ((ImageReadMode .RGB , "RGB" ), (ImageReadMode .RGB_ALPHA , "RGBA" ), (ImageReadMode .UNCHANGED , None ))
909
+ "mode, pil_mode" ,
910
+ (
911
+ # Note that converting an RGBA image to RGB leads to bad results because the
912
+ # transparent pixels aren't necessarily set to "black" or "white", they can be
913
+ # random stuff. This is consistent with PIL results.
914
+ (ImageReadMode .RGB , "RGB" ),
915
+ (ImageReadMode .RGB_ALPHA , "RGBA" ),
916
+ (ImageReadMode .UNCHANGED , None ),
917
+ ),
906
918
)
907
- @pytest .mark .parametrize ("filename" , Path ("/home/nicolashug/webp_samples" ).glob ("*.webp" ))
919
+ @pytest .mark .parametrize ("filename" , Path (WEBP_TEST_IMAGES_DIR ).glob ("*.webp" ), ids = lambda p : p . name )
908
920
def test_decode_webp_against_pil (decode_fun , scripted , mode , pil_mode , filename ):
909
921
encoded_bytes = read_file (filename )
910
922
if scripted :
@@ -915,9 +927,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename)
915
927
pil_img = Image .open (filename ).convert (pil_mode )
916
928
from_pil = F .pil_to_tensor (pil_img )
917
929
assert_equal (img , from_pil )
930
+ img += 123 # make sure image buffer wasn't freed by underlying decoding lib
918
931
919
932
920
- @pytest .mark .xfail ( reason = "AVIF support not enabled yet ." )
933
+ @pytest .mark .skipif ( not DECODE_AVIF_ENABLED , reason = "AVIF support not enabled." )
921
934
@pytest .mark .parametrize ("decode_fun" , (_decode_avif , decode_image ))
922
935
@pytest .mark .parametrize ("scripted" , (False , True ))
923
936
def test_decode_avif (decode_fun , scripted ):
@@ -927,12 +940,20 @@ def test_decode_avif(decode_fun, scripted):
927
940
img = decode_fun (encoded_bytes )
928
941
assert img .shape == (3 , 100 , 100 )
929
942
assert img [None ].is_contiguous (memory_format = torch .channels_last )
943
+ img += 123 # make sure image buffer wasn't freed by underlying decoding lib
930
944
931
945
932
- @pytest .mark .xfail (reason = "AVIF and HEIC support not enabled yet." )
933
946
# Note: decode_image fails because some of these files have a (valid) signature
934
947
# we don't recognize. We should probably use libmagic....
935
- @pytest .mark .parametrize ("decode_fun" , (_decode_avif , _decode_heic ))
948
+ decode_funs = []
949
+ if DECODE_AVIF_ENABLED :
950
+ decode_funs .append (_decode_avif )
951
+ if DECODE_HEIC_ENABLED :
952
+ decode_funs .append (_decode_heic )
953
+
954
+
955
+ @pytest .mark .skipif (not decode_funs , reason = "Built without avif and heic support." )
956
+ @pytest .mark .parametrize ("decode_fun" , decode_funs )
936
957
@pytest .mark .parametrize ("scripted" , (False , True ))
937
958
@pytest .mark .parametrize (
938
959
"mode, pil_mode" ,
@@ -945,7 +966,7 @@ def test_decode_avif(decode_fun, scripted):
945
966
@pytest .mark .parametrize (
946
967
"filename" , Path ("/home/nicolashug/dev/libavif/tests/data/" ).glob ("*.avif" ), ids = lambda p : p .name
947
968
)
948
- def test_decode_avif_against_pil (decode_fun , scripted , mode , pil_mode , filename ):
969
+ def test_decode_avif_heic_against_pil (decode_fun , scripted , mode , pil_mode , filename ):
949
970
if "reversed_dimg_order" in str (filename ):
950
971
# Pillow properly decodes this one, but we don't (order of parts of the
951
972
# image is wrong). This is due to a bug that was recently fixed in
@@ -996,21 +1017,21 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
996
1017
g = make_grid ([img , from_pil ])
997
1018
F .to_pil_image (g ).save ((f"/home/nicolashug/out_images/{ filename .name } .{ pil_mode } .png" ))
998
1019
999
- is__decode_heic = getattr (decode_fun , "__name__" , getattr (decode_fun , "name" , None )) == "_decode_heic"
1000
- if mode == ImageReadMode .RGB and not is__decode_heic :
1020
+ is_decode_heic = getattr (decode_fun , "__name__" , getattr (decode_fun , "name" , None )) == "_decode_heic"
1021
+ if mode == ImageReadMode .RGB and not is_decode_heic :
1001
1022
# We don't compare torchvision's AVIF against PIL for RGB because
1002
1023
# results look pretty different on RGBA images (other images are fine).
1003
1024
# The result on torchvision basically just plainly ignores the alpha
1004
1025
# channel, resuting in transparent pixels looking dark. PIL seems to be
1005
1026
# using a sort of k-nn thing (Take a look at the resuting images)
1006
1027
return
1007
- if filename .name == "sofa_grid1x5_420.avif" and is__decode_heic :
1028
+ if filename .name == "sofa_grid1x5_420.avif" and is_decode_heic :
1008
1029
return
1009
1030
1010
1031
torch .testing .assert_close (img , from_pil , rtol = 0 , atol = 3 )
1011
1032
1012
1033
1013
- @pytest .mark .xfail ( reason = "HEIC support not enabled yet." )
1034
+ @pytest .mark .skipif ( not DECODE_HEIC_ENABLED , reason = "HEIC support not enabled yet." )
1014
1035
@pytest .mark .parametrize ("decode_fun" , (_decode_heic , decode_image ))
1015
1036
@pytest .mark .parametrize ("scripted" , (False , True ))
1016
1037
def test_decode_heic (decode_fun , scripted ):
@@ -1020,6 +1041,7 @@ def test_decode_heic(decode_fun, scripted):
1020
1041
img = decode_fun (encoded_bytes )
1021
1042
assert img .shape == (3 , 100 , 100 )
1022
1043
assert img [None ].is_contiguous (memory_format = torch .channels_last )
1044
+ img += 123 # make sure image buffer wasn't freed by underlying decoding lib
1023
1045
1024
1046
1025
1047
if __name__ == "__main__" :
0 commit comments