Skip to content

Commit f7d9e75

Browse files
authored
Support encoded RLE format for COCO segmentations (#8387)
1 parent 26af015 commit f7d9e75

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

test/test_datasets.py

+36-6
Original file line numberDiff line numberDiff line change
@@ -782,32 +782,46 @@ def inject_fake_data(self, tmpdir, config):
782782

783783
annotation_folder = tmpdir / self._ANNOTATIONS_FOLDER
784784
os.makedirs(annotation_folder)
785+
786+
segmentation_kind = config.pop("segmentation_kind", "list")
785787
info = self._create_annotation_file(
786-
annotation_folder, self._ANNOTATIONS_FILE, file_names, num_annotations_per_image
788+
annotation_folder,
789+
self._ANNOTATIONS_FILE,
790+
file_names,
791+
num_annotations_per_image,
792+
segmentation_kind=segmentation_kind,
787793
)
788794

789795
info["num_examples"] = num_images
790796
return info
791797

792-
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image):
798+
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image, segmentation_kind="list"):
793799
image_ids = [int(file_name.stem) for file_name in file_names]
794800
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]
795801

796-
annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
802+
annotations, info = self._create_annotations(image_ids, num_annotations_per_image, segmentation_kind)
797803
self._create_json(root, name, dict(images=images, annotations=annotations))
798804

799805
return info
800806

801-
def _create_annotations(self, image_ids, num_annotations_per_image):
807+
def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
802808
annotations = []
803809
annotion_id = 0
810+
804811
for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
812+
segmentation = {
813+
"list": [torch.rand(8).tolist()],
814+
"rle": {"size": [10, 10], "counts": [1]},
815+
"rle_encoded": {"size": [2400, 2400], "counts": "PQRQ2[1\\Y2f0gNVNRhMg2"},
816+
"bad": 123,
817+
}[segmentation_kind]
818+
805819
annotations.append(
806820
dict(
807821
image_id=image_id,
808822
id=annotion_id,
809823
bbox=torch.rand(4).tolist(),
810-
segmentation=[torch.rand(8).tolist()],
824+
segmentation=segmentation,
811825
category_id=int(torch.randint(91, ())),
812826
area=float(torch.rand(1)),
813827
iscrowd=int(torch.randint(2, size=(1,))),
@@ -832,11 +846,27 @@ def test_slice_error(self):
832846
with pytest.raises(ValueError, match="Index must be of type integer"):
833847
dataset[:2]
834848

849+
def test_segmentation_kind(self):
850+
if isinstance(self, CocoCaptionsTestCase):
851+
return
852+
853+
for segmentation_kind in ("list", "rle", "rle_encoded"):
854+
config = {"segmentation_kind": segmentation_kind}
855+
with self.create_dataset(config) as (dataset, _):
856+
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
857+
list(dataset)
858+
859+
config = {"segmentation_kind": "bad"}
860+
with self.create_dataset(config) as (dataset, _):
861+
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
862+
with pytest.raises(ValueError, match="COCO segmentation expected to be a dict or a list"):
863+
list(dataset)
864+
835865

836866
class CocoCaptionsTestCase(CocoDetectionTestCase):
837867
DATASET_CLASS = datasets.CocoCaptions
838868

839-
def _create_annotations(self, image_ids, num_annotations_per_image):
869+
def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
840870
captions = [str(idx) for idx in range(num_annotations_per_image)]
841871
annotations = combinations_grid(image_id=image_ids, caption=captions)
842872
for id, annotation in enumerate(annotations):

torchvision/tv_tensors/_dataset_wrapper.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,14 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
359359
def segmentation_to_mask(segmentation, *, canvas_size):
360360
from pycocotools import mask
361361

362-
segmentation = (
363-
mask.frPyObjects(segmentation, *canvas_size)
364-
if isinstance(segmentation, dict)
365-
else mask.merge(mask.frPyObjects(segmentation, *canvas_size))
366-
)
362+
if isinstance(segmentation, dict):
363+
# if counts is a string, it is already an encoded RLE mask
364+
if not isinstance(segmentation["counts"], str):
365+
segmentation = mask.frPyObjects(segmentation, *canvas_size)
366+
elif isinstance(segmentation, list):
367+
segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
368+
else:
369+
raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}")
367370
return torch.from_numpy(mask.decode(segmentation))
368371

369372
def wrapper(idx, sample):

0 commit comments

Comments (0)