@@ -782,32 +782,46 @@ def inject_fake_data(self, tmpdir, config):
782
782
783
783
annotation_folder = tmpdir / self ._ANNOTATIONS_FOLDER
784
784
os .makedirs (annotation_folder )
785
+
786
+ segmentation_kind = config .pop ("segmentation_kind" , "list" )
785
787
info = self ._create_annotation_file (
786
- annotation_folder , self ._ANNOTATIONS_FILE , file_names , num_annotations_per_image
788
+ annotation_folder ,
789
+ self ._ANNOTATIONS_FILE ,
790
+ file_names ,
791
+ num_annotations_per_image ,
792
+ segmentation_kind = segmentation_kind ,
787
793
)
788
794
789
795
info ["num_examples" ] = num_images
790
796
return info
791
797
792
- def _create_annotation_file (self , root , name , file_names , num_annotations_per_image ):
798
+ def _create_annotation_file (self , root , name , file_names , num_annotations_per_image , segmentation_kind = "list" ):
793
799
image_ids = [int (file_name .stem ) for file_name in file_names ]
794
800
images = [dict (file_name = str (file_name ), id = id ) for file_name , id in zip (file_names , image_ids )]
795
801
796
- annotations , info = self ._create_annotations (image_ids , num_annotations_per_image )
802
+ annotations , info = self ._create_annotations (image_ids , num_annotations_per_image , segmentation_kind )
797
803
self ._create_json (root , name , dict (images = images , annotations = annotations ))
798
804
799
805
return info
800
806
801
- def _create_annotations (self , image_ids , num_annotations_per_image ):
807
+ def _create_annotations (self , image_ids , num_annotations_per_image , segmentation_kind = "list" ):
802
808
annotations = []
803
809
annotion_id = 0
810
+
804
811
for image_id in itertools .islice (itertools .cycle (image_ids ), len (image_ids ) * num_annotations_per_image ):
812
+ segmentation = {
813
+ "list" : [torch .rand (8 ).tolist ()],
814
+ "rle" : {"size" : [10 , 10 ], "counts" : [1 ]},
815
+ "rle_encoded" : {"size" : [2400 , 2400 ], "counts" : "PQRQ2[1\\ Y2f0gNVNRhMg2" },
816
+ "bad" : 123 ,
817
+ }[segmentation_kind ]
818
+
805
819
annotations .append (
806
820
dict (
807
821
image_id = image_id ,
808
822
id = annotion_id ,
809
823
bbox = torch .rand (4 ).tolist (),
810
- segmentation = [ torch . rand ( 8 ). tolist ()] ,
824
+ segmentation = segmentation ,
811
825
category_id = int (torch .randint (91 , ())),
812
826
area = float (torch .rand (1 )),
813
827
iscrowd = int (torch .randint (2 , size = (1 ,))),
@@ -832,11 +846,27 @@ def test_slice_error(self):
832
846
with pytest .raises (ValueError , match = "Index must be of type integer" ):
833
847
dataset [:2 ]
834
848
849
+ def test_segmentation_kind (self ):
850
+ if isinstance (self , CocoCaptionsTestCase ):
851
+ return
852
+
853
+ for segmentation_kind in ("list" , "rle" , "rle_encoded" ):
854
+ config = {"segmentation_kind" : segmentation_kind }
855
+ with self .create_dataset (config ) as (dataset , _ ):
856
+ dataset = datasets .wrap_dataset_for_transforms_v2 (dataset , target_keys = "all" )
857
+ list (dataset )
858
+
859
+ config = {"segmentation_kind" : "bad" }
860
+ with self .create_dataset (config ) as (dataset , _ ):
861
+ dataset = datasets .wrap_dataset_for_transforms_v2 (dataset , target_keys = "all" )
862
+ with pytest .raises (ValueError , match = "COCO segmentation expected to be a dict or a list" ):
863
+ list (dataset )
864
+
835
865
836
866
class CocoCaptionsTestCase (CocoDetectionTestCase ):
837
867
DATASET_CLASS = datasets .CocoCaptions
838
868
839
- def _create_annotations (self , image_ids , num_annotations_per_image ):
869
+ def _create_annotations (self , image_ids , num_annotations_per_image , segmentation_kind = "list" ):
840
870
captions = [str (idx ) for idx in range (num_annotations_per_image )]
841
871
annotations = combinations_grid (image_id = image_ids , caption = captions )
842
872
for id , annotation in enumerate (annotations ):
0 commit comments