Skip to content

Commit 3dfe7e6

Browse files
authored
Merge branch 'main' into please_dont_modify_this_branch_unless_you_are_just_merging_with_main__
2 parents d0b4f8f + 124dfa4 commit 3dfe7e6

39 files changed

+812
-212
lines changed

packaging/pre_build_script.sh

+7-6
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@ if [[ "$(uname)" == Darwin ]]; then
99
brew uninstall --ignore-dependencies --force $pkg || true
1010
done
1111

12-
conda install -yq wget
12+
conda install -y wget
1313
fi
1414

1515
if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then
16-
conda install libpng libwebp -yq
16+
conda install libpng libwebp -y
1717
# Installing webp also installs a non-turbo jpeg, so we uninstall jpeg stuff
1818
# before re-installing them
1919
conda uninstall libjpeg-turbo libjpeg -y
20-
conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch
20+
conda install -y ffmpeg=4.2 -c pytorch
21+
conda install -y libjpeg-turbo -c pytorch
2122

2223
# Copy binaries to be included in the wheel distribution
2324
if [[ "$OSTYPE" == "msys" ]]; then
@@ -28,11 +29,11 @@ if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then
2829
else
2930

3031
if [[ "$ARCH" == "aarch64" ]]; then
31-
conda install libpng -yq
32-
conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch-nightly
32+
conda install libpng -y
33+
conda install -y ffmpeg=4.2 libjpeg-turbo -c pytorch-nightly
3334
fi
3435

35-
conda install libwebp -yq
36+
conda install libwebp -y
3637
conda install libjpeg-turbo -c pytorch
3738
yum install -y freetype gnutls
3839
pip install auditwheel

setup.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import distutils.spawn
33
import glob
44
import os
5+
import shlex
56
import shutil
67
import subprocess
78
import sys
@@ -95,8 +96,14 @@ def get_dist(pkgname):
9596
return None
9697

9798
pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch")
98-
if os.getenv("PYTORCH_VERSION"):
99-
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
99+
if version_pin := os.getenv("PYTORCH_VERSION"):
100+
pytorch_dep += "==" + version_pin
101+
elif (version_pin_ge := os.getenv("PYTORCH_VERSION_GE")) and (version_pin_lt := os.getenv("PYTORCH_VERSION_LT")):
102+
# This branch and the associated env vars exist to help third-party
103+
# builds like in https://github.com/pytorch/vision/pull/8936. This is
104+
# supported on a best-effort basis; we don't guarantee that this won't
105+
# eventually break (and we don't test it).
106+
pytorch_dep += f">={version_pin_ge},<{version_pin_lt}"
100107

101108
requirements = [
102109
"numpy",
@@ -123,7 +130,7 @@ def get_macros_and_flags():
123130
if NVCC_FLAGS is None:
124131
nvcc_flags = []
125132
else:
126-
nvcc_flags = NVCC_FLAGS.split(" ")
133+
nvcc_flags = shlex.split(NVCC_FLAGS)
127134
extra_compile_args["nvcc"] = nvcc_flags
128135

129136
if sys.platform == "win32":

test/common_utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def sample_position(values, max_value):
423423
h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
424424
y = sample_position(h, canvas_size[0])
425425
x = sample_position(w, canvas_size[1])
426+
r = -360 * torch.rand((num_boxes,)) + 180
426427

427428
if format is tv_tensors.BoundingBoxFormat.XYWH:
428429
parts = (x, y, w, h)
@@ -435,6 +436,23 @@ def sample_position(values, max_value):
435436
cx = x + w / 2
436437
cy = y + h / 2
437438
parts = (cx, cy, w, h)
439+
elif format is tv_tensors.BoundingBoxFormat.XYWHR:
440+
parts = (x, y, w, h, r)
441+
elif format is tv_tensors.BoundingBoxFormat.CXCYWHR:
442+
cx = x + w / 2
443+
cy = y + h / 2
444+
parts = (cx, cy, w, h, r)
445+
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
446+
r_rad = r * torch.pi / 180.0
447+
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
448+
x1, y1 = x, y
449+
x3 = x1 + w * cos
450+
y3 = y1 - w * sin
451+
x2 = x3 + h * sin
452+
y2 = y3 + h * cos
453+
x4 = x1 + h * sin
454+
y4 = y1 + h * cos
455+
parts = (x1, y1, x3, y3, x2, y2, x4, y4)
438456
else:
439457
raise ValueError(f"Format {format} is not supported")
440458

test/datasets_utils.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ class ImageDatasetTestCase(DatasetTestCase):
611611
"""
612612

613613
FEATURE_TYPES = (PIL.Image.Image, int)
614+
SUPPORT_TV_IMAGE_DECODE: bool = False
614615

615616
@contextlib.contextmanager
616617
def create_dataset(
@@ -632,22 +633,34 @@ def create_dataset(
632633
# This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types, open an
633634
# image, but never use the underlying data. During normal operation it is reasonable to assume that the
634635
# user wants to work with the image they just opened rather than deleting the underlying file.
635-
with self._force_load_images():
636+
with self._force_load_images(loader=(config or {}).get("loader", None)):
636637
yield dataset, info
637638

638639
@contextlib.contextmanager
639-
def _force_load_images(self):
640-
open = PIL.Image.open
640+
def _force_load_images(self, loader: Optional[Callable[[str], Any]] = None):
641+
open = loader or PIL.Image.open
641642

642643
def new(fp, *args, **kwargs):
643644
image = open(fp, *args, **kwargs)
644-
if isinstance(fp, (str, pathlib.Path)):
645+
if isinstance(fp, (str, pathlib.Path)) and isinstance(image, PIL.Image.Image):
645646
image.load()
646647
return image
647648

648-
with unittest.mock.patch("PIL.Image.open", new=new):
649+
with unittest.mock.patch(open.__module__ + "." + open.__qualname__, new=new):
649650
yield
650651

652+
def test_tv_decode_image_support(self):
653+
if not self.SUPPORT_TV_IMAGE_DECODE:
654+
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
655+
656+
with self.create_dataset(
657+
config=dict(
658+
loader=torchvision.io.decode_image,
659+
)
660+
) as (dataset, _):
661+
image = dataset[0][0]
662+
assert isinstance(image, torch.Tensor)
663+
651664

652665
class VideoDatasetTestCase(DatasetTestCase):
653666
"""Abstract base class for video dataset testcases.

test/test_datasets.py

+42
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch.nn.functional as F
2525
from common_utils import combinations_grid
2626
from torchvision import datasets
27+
from torchvision.io import decode_image
2728
from torchvision.transforms import v2
2829

2930

@@ -405,6 +406,8 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
405406
REQUIRED_PACKAGES = ("scipy",)
406407
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
407408

409+
SUPPORT_TV_IMAGE_DECODE = True
410+
408411
def inject_fake_data(self, tmpdir, config):
409412
tmpdir = pathlib.Path(tmpdir)
410413

@@ -1173,6 +1176,8 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase):
11731176
DATASET_CLASS = datasets.SBU
11741177
FEATURE_TYPES = (PIL.Image.Image, str)
11751178

1179+
SUPPORT_TV_IMAGE_DECODE = True
1180+
11761181
def inject_fake_data(self, tmpdir, config):
11771182
num_images = 3
11781183

@@ -1411,6 +1416,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
14111416
_IMAGES_FOLDER = "images"
14121417
_ANNOTATIONS_FILE = "captions.html"
14131418

1419+
SUPPORT_TV_IMAGE_DECODE = True
1420+
14141421
def dataset_args(self, tmpdir, config):
14151422
tmpdir = pathlib.Path(tmpdir)
14161423
root = tmpdir / self._IMAGES_FOLDER
@@ -1480,6 +1487,8 @@ class Flickr30kTestCase(Flickr8kTestCase):
14801487

14811488
_ANNOTATIONS_FILE = "captions.token"
14821489

1490+
SUPPORT_TV_IMAGE_DECODE = True
1491+
14831492
def _image_file_name(self, idx):
14841493
return f"{idx}.jpg"
14851494

@@ -1940,6 +1949,8 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
19401949
_IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
19411950
_file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
19421951

1952+
SUPPORT_TV_IMAGE_DECODE = True
1953+
19431954
def inject_fake_data(self, tmpdir, config):
19441955
tmpdir = pathlib.Path(tmpdir) / "lfw-py"
19451956
os.makedirs(tmpdir, exist_ok=True)
@@ -1976,6 +1987,18 @@ def _create_random_id(self):
19761987
part2 = datasets_utils.create_random_string(random.randint(4, 7))
19771988
return f"{part1}_{part2}"
19781989

1990+
def test_tv_decode_image_support(self):
1991+
if not self.SUPPORT_TV_IMAGE_DECODE:
1992+
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
1993+
1994+
with self.create_dataset(
1995+
config=dict(
1996+
loader=decode_image,
1997+
)
1998+
) as (dataset, _):
1999+
image = dataset[0][0]
2000+
assert isinstance(image, torch.Tensor)
2001+
19792002

19802003
class LFWPairsTestCase(LFWPeopleTestCase):
19812004
DATASET_CLASS = datasets.LFWPairs
@@ -2308,6 +2331,7 @@ def inject_fake_data(self, tmpdir, config):
23082331
class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
23092332
DATASET_CLASS = datasets.EuroSAT
23102333
FEATURE_TYPES = (PIL.Image.Image, int)
2334+
SUPPORT_TV_IMAGE_DECODE = True
23112335

23122336
def inject_fake_data(self, tmpdir, config):
23132337
data_folder = os.path.join(tmpdir, "eurosat", "2750")
@@ -2332,6 +2356,8 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
23322356

23332357
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
23342358

2359+
SUPPORT_TV_IMAGE_DECODE = True
2360+
23352361
def inject_fake_data(self, tmpdir: str, config):
23362362
root_folder = pathlib.Path(tmpdir) / "food-101"
23372363
image_folder = root_folder / "images"
@@ -2368,6 +2394,7 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
23682394
ADDITIONAL_CONFIGS = combinations_grid(
23692395
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
23702396
)
2397+
SUPPORT_TV_IMAGE_DECODE = True
23712398

23722399
def inject_fake_data(self, tmpdir: str, config):
23732400
split = config["split"]
@@ -2417,6 +2444,8 @@ def inject_fake_data(self, tmpdir: str, config):
24172444
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
24182445
DATASET_CLASS = datasets.SUN397
24192446

2447+
SUPPORT_TV_IMAGE_DECODE = True
2448+
24202449
def inject_fake_data(self, tmpdir: str, config):
24212450
data_dir = pathlib.Path(tmpdir) / "SUN397"
24222451
data_dir.mkdir()
@@ -2448,6 +2477,8 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
24482477
DATASET_CLASS = datasets.DTD
24492478
FEATURE_TYPES = (PIL.Image.Image, int)
24502479

2480+
SUPPORT_TV_IMAGE_DECODE = True
2481+
24512482
ADDITIONAL_CONFIGS = combinations_grid(
24522483
split=("train", "test", "val"),
24532484
# There is no need to test the whole matrix here, since each fold is treated exactly the same
@@ -2608,6 +2639,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
26082639
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
26092640

26102641
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
2642+
SUPPORT_TV_IMAGE_DECODE = True
26112643

26122644
def inject_fake_data(self, tmpdir, config):
26132645
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
@@ -2705,6 +2737,8 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
27052737
REQUIRED_PACKAGES = ("scipy",)
27062738
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
27072739

2740+
SUPPORT_TV_IMAGE_DECODE = True
2741+
27082742
def inject_fake_data(self, tmpdir, config):
27092743
import scipy.io as io
27102744
from numpy.core.records import fromarrays
@@ -2749,6 +2783,8 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):
27492783

27502784
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))
27512785

2786+
SUPPORT_TV_IMAGE_DECODE = True
2787+
27522788
def inject_fake_data(self, tmpdir: str, config):
27532789
split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
27542790
split_folder.mkdir(parents=True, exist_ok=True)
@@ -2777,6 +2813,8 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
27772813
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
27782814
REQUIRED_PACKAGES = ("scipy",)
27792815

2816+
SUPPORT_TV_IMAGE_DECODE = True
2817+
27802818
def inject_fake_data(self, tmpdir: str, config):
27812819
base_folder = pathlib.Path(tmpdir) / "flowers-102"
27822820

@@ -2835,6 +2873,8 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
28352873
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
28362874
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}
28372875

2876+
SUPPORT_TV_IMAGE_DECODE = True
2877+
28382878
def inject_fake_data(self, tmpdir: str, config):
28392879
root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
28402880
image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]]
@@ -3495,6 +3535,8 @@ class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
34953535
DATASET_CLASS = datasets.Imagenette
34963536
ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])
34973537

3538+
SUPPORT_TV_IMAGE_DECODE = True
3539+
34983540
_WNIDS = [
34993541
"n01440764",
35003542
"n02102040",

test/test_image.py

+36
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,42 @@ def test_encode_jpeg_cuda(img_path, scripted, contiguous):
623623
assert abs_mean_diff < 3
624624

625625

626+
@needs_cuda
627+
def test_encode_jpeg_cuda_sync():
628+
"""
629+
Non-regression test for https://github.com/pytorch/vision/issues/8587.
630+
Attempts to reproduce an intermittent CUDA stream synchronization bug
631+
by randomly creating images and round-tripping them via encode_jpeg
632+
and decode_jpeg on the GPU. Fails if the mean difference in uint8 range
633+
exceeds 5.
634+
"""
635+
torch.manual_seed(42)
636+
637+
# manual testing shows this bug appearing often in iterations between 50 and 100
638+
# Because this is a synchronization bug, it can't be reliably reproduced.
639+
max_iterations = 100
640+
threshold = 5.0 # in [0..255]
641+
642+
device = torch.device("cuda")
643+
644+
for iteration in range(max_iterations):
645+
height, width = torch.randint(4000, 5000, size=(2,))
646+
647+
image = torch.linspace(0, 1, steps=height * width, device=device)
648+
image = image.view(1, height, width).expand(3, -1, -1)
649+
650+
image = (image * 255).clamp(0, 255).to(torch.uint8)
651+
jpeg_bytes = encode_jpeg(image, quality=100)
652+
653+
decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device)
654+
mean_difference = (image.float() - decoded_image.float()).abs().mean().item()
655+
656+
assert mean_difference <= threshold, (
657+
f"Encode/decode mismatch at iteration={iteration}, "
658+
f"size={height}x{width}, mean diff={mean_difference:.2f}"
659+
)
660+
661+
626662
@pytest.mark.parametrize("device", cpu_and_cuda())
627663
@pytest.mark.parametrize("scripted", (True, False))
628664
@pytest.mark.parametrize("contiguous", (True, False))

0 commit comments

Comments
 (0)