@@ -8,13 +8,17 @@ namespace vision {
8
8
namespace image {
9
9
10
10
#if !WEBP_FOUND
11
- torch::Tensor decode_webp (const torch::Tensor& data) {
11
+ torch::Tensor decode_webp (
12
+ const torch::Tensor& encoded_data,
13
+ ImageReadMode mode) {
12
14
TORCH_CHECK (
13
15
false , " decode_webp: torchvision not compiled with libwebp support" );
14
16
}
15
17
#else
16
18
17
- torch::Tensor decode_webp (const torch::Tensor& encoded_data) {
19
+ torch::Tensor decode_webp (
20
+ const torch::Tensor& encoded_data,
21
+ ImageReadMode mode) {
18
22
TORCH_CHECK (encoded_data.is_contiguous (), " Input tensor must be contiguous." );
19
23
TORCH_CHECK (
20
24
encoded_data.dtype () == torch::kU8 ,
@@ -26,13 +30,41 @@ torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
26
30
encoded_data.dim (),
27
31
" dims." );
28
32
33
+ auto encoded_data_p = encoded_data.data_ptr <uint8_t >();
34
+ auto encoded_data_size = encoded_data.numel ();
35
+
36
+ WebPBitstreamFeatures features;
37
+ auto res = WebPGetFeatures (encoded_data_p, encoded_data_size, &features);
38
+ TORCH_CHECK (
39
+ res == VP8_STATUS_OK, " WebPGetFeatures failed with error code " , res);
40
+ TORCH_CHECK (
41
+ !features.has_animation , " Animated webp files are not supported." );
42
+
43
+ auto decoding_func = WebPDecodeRGB;
44
+ int num_channels = 0 ;
45
+ if (mode == IMAGE_READ_MODE_RGB) {
46
+ decoding_func = WebPDecodeRGB;
47
+ num_channels = 3 ;
48
+ } else if (mode == IMAGE_READ_MODE_RGB_ALPHA) {
49
+ decoding_func = WebPDecodeRGBA;
50
+ num_channels = 4 ;
51
+ } else {
52
+ // Assume mode is "unchanged"
53
+ decoding_func = features.has_alpha ? WebPDecodeRGBA : WebPDecodeRGB;
54
+ num_channels = features.has_alpha ? 4 : 3 ;
55
+ }
56
+
29
57
int width = 0 ;
30
58
int height = 0 ;
31
- auto decoded_data = WebPDecodeRGB (
32
- encoded_data.data_ptr <uint8_t >(), encoded_data.numel (), &width, &height);
33
- TORCH_CHECK (decoded_data != nullptr , " WebPDecodeRGB failed." );
34
- auto out = torch::from_blob (decoded_data, {height, width, 3 }, torch::kUInt8 );
35
- return out.permute ({2 , 0 , 1 }); // return CHW, channels-last
59
+
60
+ auto decoded_data =
61
+ decoding_func (encoded_data_p, encoded_data_size, &width, &height);
62
+ TORCH_CHECK (decoded_data != nullptr , " WebPDecodeRGB[A] failed." );
63
+
64
+ auto out = torch::from_blob (
65
+ decoded_data, {height, width, num_channels}, torch::kUInt8 );
66
+
67
+ return out.permute ({2 , 0 , 1 });
36
68
}
37
69
#endif // WEBP_FOUND
38
70
0 commit comments