forked from pytorch/vision
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecode_jpeg_cuda.cpp
208 lines (172 loc) · 5.7 KB
/
decode_jpeg_cuda.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#include "encode_decode_jpegs_cuda.h"
#include <ATen/ATen.h>
#if NVJPEG_FOUND
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvjpeg.h>
#endif
#include <mutex>
#include <string>
namespace vision {
namespace image {
#if !NVJPEG_FOUND
// Fallback used when torchvision is built without nvJPEG (NVJPEG_FOUND is
// false). The arguments are ignored and the call always fails through
// TORCH_CHECK with a message explaining that nvJPEG support is missing.
torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& data,
    ImageReadMode mode,
    torch::Device device) {
  // Unconditional failure: there is no decoder to dispatch to in this build.
  TORCH_CHECK(
      false,
      "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
}
#else
namespace {
// File-local nvJPEG library handle shared by every decode_jpeg_cuda call.
// It starts out null; decode_jpeg_cuda assigns it through nvjpegCreateSimple.
// The unnamed namespace already gives it internal linkage.
nvjpegHandle_t nvjpeg_handle = nullptr;
} // namespace
// Decodes a JPEG byte stream with nvJPEG and returns the pixels as a uint8
// tensor of shape (channels, height, width) allocated on `device`.
//
// Args:
//   data:   non-empty 1-dimensional uint8 tensor that is not on a CUDA device
//           and holds the encoded JPEG bytes.
//   mode:   IMAGE_READ_MODE_UNCHANGED (output has 1 or 3 channels, following
//           the input), IMAGE_READ_MODE_GRAY (1 channel) or
//           IMAGE_READ_MODE_RGB (3 channels).
//   device: CUDA device on which the output tensor is allocated and on whose
//           current stream the decode is submitted.
//
// Fails through TORCH_CHECK when an argument is invalid, when the mode is not
// supported, or when any nvJPEG call returns a non-success status.
torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& data,
    ImageReadMode mode,
    torch::Device device) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.image.cuda.decode_jpeg_cuda.decode_jpeg_cuda");

  // Argument validation.
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  TORCH_CHECK(
      !data.is_cuda(),
      "The input tensor must be on CPU when decoding with nvjpeg");
  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non empty 1-dimensional tensor");
  TORCH_CHECK(device.is_cuda(), "Expected a cuda device");

  // Query the nvJPEG library version so users of affected releases can be
  // warned about the known memory leak.
  int major_version;
  int minor_version;
  const nvjpegStatus_t get_major_property_status =
      nvjpegGetProperty(MAJOR_VERSION, &major_version);
  const nvjpegStatus_t get_minor_property_status =
      nvjpegGetProperty(MINOR_VERSION, &minor_version);

  TORCH_CHECK(
      get_major_property_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegGetProperty failed: ",
      get_major_property_status);
  TORCH_CHECK(
      get_minor_property_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegGetProperty failed: ",
      get_minor_property_status);

  if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) {
    TORCH_WARN_ONCE(
        "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. "
        "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda').");
  }

  // Make `device` the current CUDA device for the remainder of this call.
  at::cuda::CUDAGuard device_guard(device);

  // Create the global nvJPEG handle. std::call_once does not mark the flag as
  // done when the callable throws, so a failed creation is retried on the
  // next call.
  static std::once_flag nvjpeg_handle_creation_flag;
  std::call_once(nvjpeg_handle_creation_flag, []() {
    if (nvjpeg_handle != nullptr) {
      return;
    }
    const nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);
    if (create_status != NVJPEG_STATUS_SUCCESS) {
      // Only reset the pointer: the handle is owned by nvJPEG and was never
      // allocated with malloc, so it must not be passed to free().
      nvjpeg_handle = nullptr;
    }
    TORCH_CHECK(
        create_status == NVJPEG_STATUS_SUCCESS,
        "nvjpegCreateSimple failed: ",
        create_status);
  });

  // Scope guard that destroys the per-call decoder state on every exit path,
  // including exceptions raised by TORCH_CHECK or by tensor allocation.
  struct JpegStateGuard {
    explicit JpegStateGuard(nvjpegJpegState_t s) : state(s) {}
    JpegStateGuard(const JpegStateGuard&) = delete;
    JpegStateGuard& operator=(const JpegStateGuard&) = delete;
    ~JpegStateGuard() {
      nvjpegJpegStateDestroy(state);
    }
    nvjpegJpegState_t state;
  };

  // Create the jpeg state
  nvjpegJpegState_t jpeg_state = nullptr;
  const nvjpegStatus_t state_status =
      nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state);
  TORCH_CHECK(
      state_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegJpegStateCreate failed: ",
      state_status);
  const JpegStateGuard jpeg_state_guard(jpeg_state);

  const uint8_t* datap = data.data_ptr<uint8_t>();

  // Get the image information
  int num_channels;
  nvjpegChromaSubsampling_t subsampling;
  int widths[NVJPEG_MAX_COMPONENT];
  int heights[NVJPEG_MAX_COMPONENT];
  const nvjpegStatus_t info_status = nvjpegGetImageInfo(
      nvjpeg_handle,
      datap,
      data.numel(),
      &num_channels,
      &subsampling,
      widths,
      heights);

  TORCH_CHECK(
      info_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegGetImageInfo failed: ",
      info_status);
  TORCH_CHECK(
      subsampling != NVJPEG_CSS_UNKNOWN, "Unknown NVJPEG chroma subsampling");

  // The output size is taken from the first component's dimensions.
  const int width = widths[0];
  const int height = heights[0];

  nvjpegOutputFormat_t output_format;
  int num_channels_output;

  switch (mode) {
    case IMAGE_READ_MODE_UNCHANGED:
      num_channels_output = num_channels;
      // For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will
      // not properly decode RGB images (it's fine for grayscale), so we set
      // output_format manually here
      if (num_channels == 1) {
        output_format = NVJPEG_OUTPUT_Y;
      } else if (num_channels == 3) {
        output_format = NVJPEG_OUTPUT_RGB;
      } else {
        TORCH_CHECK(
            false,
            "When mode is UNCHANGED, only 1 or 3 input channels are allowed.");
      }
      break;
    case IMAGE_READ_MODE_GRAY:
      output_format = NVJPEG_OUTPUT_Y;
      num_channels_output = 1;
      break;
    case IMAGE_READ_MODE_RGB:
      output_format = NVJPEG_OUTPUT_RGB;
      num_channels_output = 3;
      break;
    default:
      TORCH_CHECK(
          false, "The provided mode is not supported for JPEG decoding on GPU");
  }

  auto out_tensor = torch::empty(
      {static_cast<int64_t>(num_channels_output),
       static_cast<int64_t>(height),
       static_cast<int64_t>(width)},
      torch::dtype(torch::kU8).device(device));

  // nvjpegImage_t is a struct with
  // - an array of pointers to each channel
  // - the pitch for each channel
  // which must be filled in manually
  nvjpegImage_t out_image;

  for (int c = 0; c < num_channels_output; c++) {
    out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();
    out_image.pitch[c] = width;
  }
  // Unused channel slots are explicitly cleared.
  for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) {
    out_image.channel[c] = nullptr;
    out_image.pitch[c] = 0;
  }

  // NOTE(review): the decode is submitted on the current stream of `device`
  // and no explicit stream synchronization happens before the decoder state
  // is destroyed or the tensor is returned -- confirm this matches the
  // nvjpegDecode contract.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());

  const nvjpegStatus_t decode_status = nvjpegDecode(
      nvjpeg_handle,
      jpeg_state,
      datap,
      data.numel(),
      output_format,
      &out_image,
      stream);

  TORCH_CHECK(
      decode_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegDecode failed: ",
      decode_status);

  return out_tensor;
}
#endif // NVJPEG_FOUND
} // namespace image
} // namespace vision