Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for EXIF orientation transform in read_image for JPEG #8279

Merged
merged 8 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps
from torchvision.io.image import (
_read_png_16,
decode_image,
Expand Down Expand Up @@ -100,6 +100,25 @@ def test_decode_jpeg(img_path, pil_mode, mode):
assert abs_mean_diff < 2


@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_jpeg_with_exif_orientation(tmpdir, orientation):
    """Check that EXIF-aware JPEG decoding agrees with Pillow.

    A random image is saved as JPEG with EXIF tag 274 (orientation) set to
    ``orientation``. It is then decoded with
    ``decode_image(..., apply_exif_orientation=True)`` and compared against
    the same file opened with Pillow and passed through
    ``ImageOps.exif_transpose``.

    Args:
        tmpdir: pytest temporary-directory fixture that receives the JPEG.
        orientation: value written to the EXIF orientation tag.
    """
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.jpg")
    # Height and width differ (256 vs 257), so a transform that swaps the two
    # axes produces a different shape and cannot pass unnoticed.
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
    exif[274] = orientation  # set exif orientation
    im.save(fp, "JPEG", exif=exif.tobytes())

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Open the reference image in a context manager so the file handle is
    # released; the tensor is built inside the block, while the file is
    # still open.
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))

    torch.testing.assert_close(expected, output)


def test_decode_jpeg_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
Expand Down
10 changes: 8 additions & 2 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
namespace vision {
namespace image {

torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
// Check that tensor is a CPU tensor
TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
// Check that the input tensor dtype is uint8
Expand All @@ -22,8 +25,11 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode);
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
TORCH_CHECK(
!apply_exif_orientation,
"Unsupported option apply_exif_orientation=true for PNG")
return decode_png(data, mode);
} else {
TORCH_CHECK(
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace image {

C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);

} // namespace image
} // namespace vision
64 changes: 62 additions & 2 deletions torchvision/csrc/io/image/cpu/decode_jpeg.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "decode_jpeg.h"
#include "common_jpeg.h"
#include "exif.h"

namespace vision {
namespace image {
Expand All @@ -12,6 +13,7 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
#else

using namespace detail;
using namespace exif_private;

namespace {

Expand Down Expand Up @@ -65,6 +67,8 @@ static void torch_jpeg_set_source_mgr(
src->len = len;
src->pub.bytes_in_buffer = len;
src->pub.next_input_byte = src->data;

jpeg_save_markers(cinfo, APP1, 0xffff);
}

inline unsigned char clamped_cmyk_rgb_convert(
Expand Down Expand Up @@ -121,7 +125,10 @@ void convert_line_cmyk_to_gray(

} // namespace

torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE(
"torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
// Check that the input tensor dtype is uint8
Expand Down Expand Up @@ -191,6 +198,54 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_calc_output_dimensions(&cinfo);
}

int exif_orientation = 0;
if (apply_exif_orientation) {
// Check for Exif marker APP1
jpeg_saved_marker_ptr exif_marker = 0;
jpeg_saved_marker_ptr cmarker = cinfo.marker_list;
while (cmarker && exif_marker == 0) {
if (cmarker->marker == APP1) {
exif_marker = cmarker;
}
cmarker = cmarker->next;
}

if (exif_marker) {
// Code below is inspired from OpenCV
// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp

// Bytes from Exif size field to the first TIFF header
constexpr size_t start_offset = 6;
if (exif_marker->data_length > start_offset) {
auto* exif_data_ptr = exif_marker->data + start_offset;
auto size = exif_marker->data_length - start_offset;
std::vector<unsigned char> exif_data_vec(
exif_data_ptr, exif_data_ptr + size);

auto endianness = get_endianness(exif_data_vec);

// Checking whether Tag Mark (0x002A) correspond to one contained in the
// Jpeg file
uint16_t tag_mark = get_uint16(exif_data_vec, endianness, 2);
if (tag_mark == REQ_EXIF_TAG_MARK) {
auto offset = get_uint32(exif_data_vec, endianness, 4);
size_t num_entry = get_uint16(exif_data_vec, endianness, offset);
offset += 2; // go to start of tag fields
constexpr size_t tiff_field_size = 12;
for (size_t entry = 0; entry < num_entry; entry++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any risk that this loop can segfault/overflow if the entries are malformed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_uint16 returns -1 when the offset plus the size required for a uint16 goes outside the actual buffer, so we can catch that here and return early.
I hope the code is safe, but I can't say so with 100% certainty. I need to investigate a bit more.

// Here we just search for orientation tag and parse it
auto tag_num = get_uint16(exif_data_vec, endianness, offset);
if (tag_num == ORIENTATION_EXIF_TAG) {
exif_orientation =
get_uint16(exif_data_vec, endianness, offset + 8);
}
offset += tiff_field_size;
}
}
}
}
}

jpeg_start_decompress(&cinfo);

int height = cinfo.output_height;
Expand Down Expand Up @@ -227,7 +282,12 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {

jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
return tensor.permute({2, 0, 1});
auto output = tensor.permute({2, 0, 1});

if (apply_exif_orientation) {
return exif_orientation_transform(output, exif_orientation);
}
return output;
}
#endif // #if !JPEG_FOUND

Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_jpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace image {

C10_EXPORT torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);

C10_EXPORT int64_t _jpeg_version();
C10_EXPORT bool _is_compiled_against_turbo();
Expand Down
99 changes: 99 additions & 0 deletions torchvision/csrc/io/image/cpu/exif.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#pragma once
#include <torch/types.h>

namespace vision {
namespace image {
namespace exif_private {

// Marker code of the JPEG APP1 segment. The JPEG decoder saves APP1 markers
// and scans the saved marker list for this code to locate EXIF data.
constexpr uint16_t APP1 = 0xe1;
// Byte-order identifiers returned by get_endianness(): 0x49 is ASCII 'I' and
// 0x4d is ASCII 'M'. get_uint16()/get_uint32() read the low byte first for
// ENDIANNESS_INTEL and the high byte first otherwise.
constexpr uint16_t ENDIANNESS_INTEL = 0x49;
constexpr uint16_t ENDIANNESS_MOTO = 0x4d;
// Tag mark (0x002A) that the decoder requires in the EXIF header before it
// parses any tag fields.
constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a;
// Tag number of the EXIF orientation field (274 in decimal).
constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112;
// Sentinel returned by get_uint16()/get_uint32() when the requested read
// would go past the end of the buffer. -1 converts to 0xffff in uint16_t.
constexpr uint16_t INCORRECT_TAG = -1;

// Functions in this module are taken from OpenCV
// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp
// Determines the byte order announced by the first bytes of an EXIF buffer.
//
// Returns ENDIANNESS_INTEL when the buffer starts with 'I', ENDIANNESS_MOTO
// when it starts with 'M', and 0 when the buffer is empty, starts with any
// other byte, or has a second byte that differs from the first. A buffer
// holding a single byte is judged on that byte alone.
inline uint16_t get_endianness(const std::vector<unsigned char>& exif_data) {
  if (exif_data.empty()) {
    return 0;
  }

  const unsigned char first_byte = exif_data[0];
  // The byte-order mark is the same character written twice ("II" or "MM").
  const bool second_byte_differs =
      exif_data.size() > 1 && exif_data[1] != first_byte;
  if (second_byte_differs) {
    return 0;
  }

  switch (first_byte) {
    case 'I':
      return ENDIANNESS_INTEL;
    case 'M':
      return ENDIANNESS_MOTO;
    default:
      return 0;
  }
}

// Reads a 16-bit unsigned integer from `exif_data` starting at `offset`.
//
// Parameters:
//   exif_data  - raw EXIF bytes.
//   endianness - ENDIANNESS_INTEL reads the low byte first; any other value
//                reads the high byte first.
//   offset     - index of the first of the two bytes to read.
//
// Returns the decoded value, or INCORRECT_TAG when the two bytes do not fit
// inside the buffer.
inline uint16_t get_uint16(
    const std::vector<unsigned char>& exif_data,
    uint16_t endianness,
    const size_t offset) {
  // The bounds check is written as a subtraction on the buffer size: the
  // offset comes from untrusted file contents, and `offset + 1` would wrap
  // around to 0 for the largest size_t value and slip past the check.
  const size_t size = exif_data.size();
  if (size < 2 || offset > size - 2) {
    return INCORRECT_TAG;
  }

  const unsigned int first = exif_data[offset];
  const unsigned int second = exif_data[offset + 1];
  if (endianness == ENDIANNESS_INTEL) {
    return static_cast<uint16_t>(first | (second << 8));
  }
  return static_cast<uint16_t>((first << 8) | second);
}

// Reads a 32-bit unsigned integer from `exif_data` starting at `offset`.
//
// Parameters:
//   exif_data  - raw EXIF bytes.
//   endianness - ENDIANNESS_INTEL reads the lowest byte first; any other
//                value reads the highest byte first.
//   offset     - index of the first of the four bytes to read.
//
// Returns the decoded value, or INCORRECT_TAG (widened to uint32_t, i.e.
// 0xffff) when the four bytes do not fit inside the buffer.
inline uint32_t get_uint32(
    const std::vector<unsigned char>& exif_data,
    uint16_t endianness,
    const size_t offset) {
  // The bounds check is written as a subtraction on the buffer size: the
  // offset comes from untrusted file contents, and `offset + 3` would wrap
  // around for offsets near the largest size_t value and slip past the check.
  const size_t size = exif_data.size();
  if (size < 4 || offset > size - 4) {
    return INCORRECT_TAG;
  }

  // Widen every byte to uint32_t before shifting. Shifting the promoted
  // `int` left by 24 overflows the signed type for bytes >= 0x80.
  const uint32_t b0 = exif_data[offset];
  const uint32_t b1 = exif_data[offset + 1];
  const uint32_t b2 = exif_data[offset + 2];
  const uint32_t b3 = exif_data[offset + 3];

  if (endianness == ENDIANNESS_INTEL) {
    return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24);
  }
  return (b0 << 24) | (b1 << 16) | (b2 << 8) | b3;
}

// Values of the EXIF orientation tag recognised by
// exif_orientation_transform(). The comment on each value names the
// operation associated with it; any other value leaves the image unchanged.
constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation
constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip
constexpr uint16_t IMAGE_ORIENTATION_BR = 3; // needs 180 rotation
constexpr uint16_t IMAGE_ORIENTATION_BL = 4; // needs vertical flip
constexpr uint16_t IMAGE_ORIENTATION_LT =
5; // mirrored horizontal & rotate 270 CW
constexpr uint16_t IMAGE_ORIENTATION_RT = 6; // rotate 90 CW
constexpr uint16_t IMAGE_ORIENTATION_RB =
7; // mirrored horizontal & rotate 90 CW
constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation

// Applies the flip/transpose associated with an EXIF orientation value to
// the last two dimensions of `image`.
//
// Parameters:
//   image       - tensor whose last two dimensions are transformed.
//   orientation - EXIF orientation value; the IMAGE_ORIENTATION_* constants
//                 list the recognised ones.
//
// Returns the transformed tensor. IMAGE_ORIENTATION_TL and every
// unrecognised value return `image` as is.
inline torch::Tensor exif_orientation_transform(
    const torch::Tensor& image,
    int orientation) {
  switch (orientation) {
    case IMAGE_ORIENTATION_TR:
      // horizontal flip
      return image.flip(-1);
    case IMAGE_ORIENTATION_BR:
      // a 180 rotation is a horizontal flip plus a vertical flip
      return image.flip({-2, -1});
    case IMAGE_ORIENTATION_BL:
      // vertical flip
      return image.flip(-2);
    case IMAGE_ORIENTATION_LT:
      return image.transpose(-1, -2);
    case IMAGE_ORIENTATION_RT:
      return image.transpose(-1, -2).flip(-1);
    case IMAGE_ORIENTATION_RB:
      return image.transpose(-1, -2).flip({-2, -1});
    case IMAGE_ORIENTATION_LB:
      return image.transpose(-1, -2).flip(-2);
    case IMAGE_ORIENTATION_TL:
    default:
      // normal orientation, or a value outside the known set
      return image;
  }
}

} // namespace exif_private
} // namespace image
} // namespace vision
4 changes: 2 additions & 2 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ static auto registry =
torch::RegisterOperators()
.op("image::decode_png", &decode_png)
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg", &decode_jpeg)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_jpeg)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This keeps the op backward compatible; otherwise apply_exif_orientation becomes a required argument and breaks existing code.

.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image)
.op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_image)
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda)
.op("image::_jpeg_version", &_jpeg_version)
.op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo);
Expand Down
25 changes: 19 additions & 6 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):


def decode_jpeg(
input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
input: torch.Tensor,
mode: ImageReadMode = ImageReadMode.UNCHANGED,
device: str = "cpu",
apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
Expand All @@ -157,6 +160,8 @@ def decode_jpeg(
.. warning::
There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Default: False. Only implemented for JPEG format on CPU.

Returns:
output (Tensor[image_channels, image_height, image_width])
Expand All @@ -167,7 +172,7 @@ def decode_jpeg(
if device.type == "cuda":
output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
else:
output = torch.ops.image.decode_jpeg(input, mode.value)
output = torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
return output


Expand Down Expand Up @@ -212,7 +217,9 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
write_file(filename, output)


def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
def decode_image(
input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
"""
Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Expand All @@ -227,17 +234,21 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Default: False. Only implemented for JPEG format

Returns:
output (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_image)
output = torch.ops.image.decode_image(input, mode.value)
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
return output


def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
def read_image(
path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
"""
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
Expand All @@ -249,14 +260,16 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Default: False. Only implemented for JPEG format

Returns:
output (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_image)
data = read_file(path)
return decode_image(data, mode)
return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
Expand Down
Loading