Commit 0e8f840

Refactor load_image to return torch.Tensor instead of PIL.Image (pytorch#2366)
1 parent fe17fad

6 files changed: +92 -35 lines changed
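
Note: the gist of the change, as a minimal before/after sketch (the file path is hypothetical; shapes mirror the updated tests below, and load_image is assumed importable from torchtune.data):

    from torchtune.data import load_image

    # Before this commit: load_image returned a PIL.Image.Image,
    # whose .size reports (width, height), e.g. (580, 403).
    # After this commit: it returns a uint8 torch.Tensor in
    # channels-first (C, H, W) layout.
    image = load_image("bird.jpg")  # hypothetical local file
    print(image.shape)  # torch.Size([3, 403, 580])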

tests/torchtune/data/test_data_utils.py (+11 -10)

@@ -7,6 +7,7 @@
 import os
 
 import pytest
+import torch
 from PIL import Image
 
 from tests.common import ASSETS
@@ -107,8 +108,8 @@ def test_load_image(monkeypatch, tmp_path):
 
     # Test loading from local file
     image = load_image(tmp_image)
-    assert isinstance(image, Image.Image)
-    assert image.size == (580, 403)
+    assert isinstance(image, torch.Tensor)
+    assert image.size() == (3, 403, 580)
 
     # Test loading from remote file
     # Mock the urlopen function to return a BytesIO object
@@ -117,11 +118,11 @@ def mock_urlopen(url):
 
     monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
     image = load_image("http://example.com/test_image.jpg")
-    assert isinstance(image, Image.Image)
-    assert image.size == (580, 403)
+    assert isinstance(image, torch.Tensor)
+    assert image.size() == (3, 403, 580)
 
     # Test that a ValueError is raised when the image path is invalid
-    with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
+    with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
         load_image("invalid_path")
 
     # Test a temporary file with invalid image data
@@ -130,16 +131,16 @@ def mock_urlopen(url):
         f.write("Invalid image data")
 
     # Test that a ValueError is raised when the image data is invalid
-    with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
+    with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
         load_image(str(image_path))
 
     # Test that a ValueError is raised when there is an HTTP error
     # Mock the urlopen function to raise an exception
     def mock_urlopen(url):
-        raise Exception("Failed to load image")
+        raise Exception("Failed to load remote image as torch.Tensor")
 
     monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
-    with pytest.raises(ValueError, match="Failed to load image"):
+    with pytest.raises(ValueError, match="Failed to load remote image as torch.Tensor"):
         load_image("http://example.com/test_image.jpg")
 
     # Test that a ValueError is raised when there is an IO error
@@ -148,7 +149,7 @@ def mock_urlopen(url):
     with open(image_path, "w") as f:
         f.write("Test data")
     os.chmod(image_path, 0o000)  # Remove read permissions
-    with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
+    with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
         load_image(str(image_path))
     os.chmod(image_path, 0o644)  # Restore read permissions
 
@@ -157,5 +158,5 @@ def mock_urlopen(url):
     image_path = tmp_path / "test_image.jpg"
     with open(image_path, "wb") as f:
         f.write(b"Invalid image data")
-    with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
+    with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
         load_image(str(image_path))
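
Note: the updated assertions also flip the axis order. PIL's Image.size reports (width, height), while a decoded tensor is (channels, height, width); a small self-contained sketch:

    import torch
    from PIL import Image

    pil_img = Image.new("RGB", (580, 403))  # PIL size is (width, height)
    assert pil_img.size == (580, 403)

    tensor_img = torch.zeros(3, 403, 580, dtype=torch.uint8)
    assert tensor_img.size() == (3, 403, 580)  # tensor shape is (C, H, W)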

tests/torchtune/datasets/multimodal/test_vqa_dataset.py (+3 -3)

@@ -5,10 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 import pytest
-from PIL.PngImagePlugin import PngImageFile
+
+import torch
 from tests.common import ASSETS
 from tests.test_utils import DummyTokenizer
-
 from torchtune.datasets.multimodal import vqa_dataset
 
 
@@ -46,7 +46,7 @@ def test_get_item(self, tokenizer):
             )
             assert prompt == expected_tokens[i]
             assert label == expected_labels[i]
-            assert isinstance(image[0], PngImageFile)
+            assert isinstance(image[0], torch.Tensor)
 
     def test_dataset_fails_with_packed(self, tokenizer):
         with pytest.raises(

tests/torchtune/models/clip/test_clip_image_transform.py (+49 -1)

@@ -31,6 +31,7 @@ class TestCLIPImageTransform:
         [
             {
                 "image_size": (100, 400, 3),
+                "image_type": "PIL.Image",
                 "expected_shape": torch.Size([2, 3, 224, 224]),
                 "resize_to_max_canvas": False,
                 "expected_tile_means": [0.2230, 0.1763],
@@ -40,6 +41,7 @@ class TestCLIPImageTransform:
             },
             {
                 "image_size": (1000, 300, 3),
+                "image_type": "PIL.Image",
                 "expected_shape": torch.Size([4, 3, 224, 224]),
                 "resize_to_max_canvas": True,
                 "expected_tile_means": [0.5007, 0.4995, 0.5003, 0.1651],
@@ -49,6 +51,7 @@ class TestCLIPImageTransform:
             },
             {
                 "image_size": (200, 200, 3),
+                "image_type": "PIL.Image",
                 "expected_shape": torch.Size([4, 3, 224, 224]),
                 "resize_to_max_canvas": True,
                 "expected_tile_means": [0.5012, 0.5020, 0.5011, 0.4991],
@@ -59,6 +62,48 @@ class TestCLIPImageTransform:
             },
             {
                 "image_size": (600, 200, 3),
+                "image_type": "torch.Tensor",
+                "expected_shape": torch.Size([3, 3, 224, 224]),
+                "resize_to_max_canvas": False,
+                "expected_tile_means": [0.4473, 0.4469, 0.3032],
+                "expected_tile_max": [1.0, 1.0, 1.0],
+                "expected_tile_min": [0.0, 0.0, 0.0],
+                "expected_aspect_ratio": [3, 1],
+            },
+            {
+                "image_size": (100, 400, 3),
+                "image_type": "torch.Tensor",
+                "expected_shape": torch.Size([2, 3, 224, 224]),
+                "resize_to_max_canvas": False,
+                "expected_tile_means": [0.2230, 0.1763],
+                "expected_tile_max": [1.0, 1.0],
+                "expected_tile_min": [0.0, 0.0],
+                "expected_aspect_ratio": [1, 2],
+            },
+            {
+                "image_size": (1000, 300, 3),
+                "image_type": "torch.Tensor",
+                "expected_shape": torch.Size([4, 3, 224, 224]),
+                "resize_to_max_canvas": True,
+                "expected_tile_means": [0.5007, 0.4995, 0.5003, 0.1651],
+                "expected_tile_max": [0.9705, 0.9694, 0.9521, 0.9314],
+                "expected_tile_min": [0.0353, 0.0435, 0.0528, 0.0],
+                "expected_aspect_ratio": [4, 1],
+            },
+            {
+                "image_size": (200, 200, 3),
+                "image_type": "torch.Tensor",
+                "expected_shape": torch.Size([4, 3, 224, 224]),
+                "resize_to_max_canvas": True,
+                "expected_tile_means": [0.5012, 0.5020, 0.5011, 0.4991],
+                "expected_tile_max": [0.9922, 0.9926, 0.9970, 0.9908],
+                "expected_tile_min": [0.0056, 0.0069, 0.0059, 0.0033],
+                "expected_aspect_ratio": [2, 2],
+                "pad_tiles": 1,
+            },
+            {
+                "image_size": (600, 200, 3),
+                "image_type": "torch.Tensor",
                 "expected_shape": torch.Size([3, 3, 224, 224]),
                 "resize_to_max_canvas": False,
                 "expected_tile_means": [0.4473, 0.4469, 0.3032],
@@ -99,7 +144,10 @@ def test_clip_image_transform(self, params):
             .reshape(image_size)
             .astype(np.uint8)
         )
-        image = PIL.Image.fromarray(image)
+        if params["image_type"] == "PIL.Image":
+            image = PIL.Image.fromarray(image)
+        elif params["image_type"] == "torch.Tensor":
+            image = torch.from_numpy(image).permute(2, 0, 1)
 
         # Apply the transformation
         output = image_transform({"image": image})
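
Note: the two parametrized branches feed numerically identical inputs once normalized, since torchvision's F.to_image converts a PIL image to the same CHW uint8 tensor that permute(2, 0, 1) builds directly. A quick sketch of that equivalence (assuming the v2 functional API already used by the transform):

    import numpy as np
    import PIL.Image
    import torch
    from torchvision.transforms.v2 import functional as F

    arr = np.random.randint(0, 256, size=(100, 400, 3), dtype=np.uint8)  # HWC

    as_pil = PIL.Image.fromarray(arr)                    # "PIL.Image" branch
    as_tensor = torch.from_numpy(arr).permute(2, 0, 1)   # "torch.Tensor" branch

    # Both normalize to the same (3, 100, 400) uint8 tensor.
    assert torch.equal(F.to_image(as_pil), F.to_image(as_tensor))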

torchtune/data/_utils.py (+18 -16)

@@ -8,6 +8,8 @@
 from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union
 from urllib import request
 
+import torch
+import torchvision
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
 from torch.utils.data import default_collate, DistributedSampler
@@ -44,9 +46,9 @@ def truncate(
     return tokens_truncated
 
 
-def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image":
+def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
     """
-    Convenience method to load an image in PIL format from a local file path or remote source.
+    Convenience method to load an image in torch.Tensor format from a local file path or remote source.
 
     Args:
         image_loc (Union[Path, str]): Local file path or remote source pointing to the image
@@ -59,7 +61,7 @@ def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image":
     Raises:
         ValueError:
             If the image cannot be loaded from remote source, **or**
-            if the image cannot be opened as a :class:`~PIL.Image.Image`.
+            if the image cannot be opened as a :class:`~torch.Tensor`.
 
     Examples:
         >>> # Load from remote source
@@ -69,25 +71,25 @@ def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image":
         >>> image = load_image(Path("/home/user/bird.jpg"))
 
     Returns:
-        PIL.Image.Image: The loaded image.
+        torch.Tensor: The loaded image.
     """
-    # Hackily import PIL to avoid burdensome import in the main module
-    # TODO: Fix this
-    from PIL import Image
-
     # If pointing to remote source, try to load to local
     if isinstance(image_loc, str) and image_loc.startswith("http"):
         try:
-            image_loc = request.urlopen(image_loc)
+            image_loc = request.urlopen(image_loc).read()
+            image = torchvision.io.decode_image(
+                torch.frombuffer(image_loc, dtype=torch.uint8),
+                mode="RGB",
+            )
         except Exception as e:
-            raise ValueError(f"Failed to load image from {image_loc}") from e
-
-    # Open the local image as a PIL image
-    try:
-        image = Image.open(image_loc)
-    except Exception as e:
-        raise ValueError(f"Failed to open image as PIL Image from {image_loc}") from e
+            raise ValueError("Failed to load remote image as torch.Tensor") from e
 
+    # Open the local image as a Tensor image
+    else:
+        try:
+            image = torchvision.io.decode_image(image_loc, mode="RGB")
+        except Exception as e:
+            raise ValueError("Failed to load local image as torch.Tensor") from e
     return image
 
 
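
Note: for reference, a sketch of what the two branches now produce, using only the torchvision calls visible in this diff (the path and URL are hypothetical):

    import torch
    import torchvision
    from urllib import request

    # Local branch: decode straight to a uint8 (3, H, W) tensor;
    # mode="RGB" folds grayscale/RGBA inputs into three channels.
    local = torchvision.io.decode_image("/home/user/bird.jpg", mode="RGB")

    # Remote branch: fetch raw bytes, view them as a uint8 buffer, then decode.
    raw = request.urlopen("http://example.com/bird.jpg").read()
    remote = torchvision.io.decode_image(
        torch.frombuffer(raw, dtype=torch.uint8), mode="RGB"
    )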

torchtune/models/clip/_transform.py (+4 -2)

@@ -156,10 +156,12 @@ def __call__(
             "aspect_ratio" field.
         """
         image = sample["image"]
-        assert isinstance(image, Image.Image), "Input image must be a PIL image."
+        assert isinstance(
+            image, (Image.Image, torch.Tensor)
+        ), "Input image must be a PIL image or a torch.Tensor."
 
         # Make image torch.tensor((3, H, W), dtype=dtype), 0<=values<=1
-        if image.mode != "RGB":
+        if isinstance(image, Image.Image) and image.mode != "RGB":
             image = image.convert("RGB")
         image = F.to_image(image)
         image = F.to_dtype(image, dtype=self.dtype, scale=True)
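
Note: only the PIL-specific .mode check needed a guard here; F.to_image already accepts both PIL images and tensors, so the rest of the pipeline is unchanged. A minimal sketch of the shared path (float32 dtype assumed for illustration):

    import torch
    from PIL import Image
    from torchvision.transforms.v2 import functional as F

    for image in (Image.new("RGB", (8, 8)), torch.zeros(3, 8, 8, dtype=torch.uint8)):
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        out = F.to_dtype(F.to_image(image), dtype=torch.float32, scale=True)
        assert out.shape == (3, 8, 8) and out.dtype == torch.float32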

torchtune/models/clip/inference/_transform.py (+7 -3)

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Any, List, Mapping, Optional, Tuple
+from typing import Any, List, Mapping, Optional, Tuple, Union
 
 import torch
 import torchvision
@@ -252,9 +252,13 @@ def __init__(
             antialias=self.antialias,
         )
 
-    def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]:
+    def __call__(
+        self, *, image: Union[Image.Image, torch.Tensor], **kwargs
+    ) -> Mapping[str, Any]:
 
-        assert isinstance(image, Image.Image), "Input image must be a PIL image."
+        assert isinstance(
+            image, (Image.Image, torch.Tensor)
+        ), "Input image must be a PIL image or torch.Tensor."
 
         # Make image torch.tensor((3, H, W), dtype='float32'), 0<=values<=1.
         image_tensor = F.to_dtype(
