File tree 1 file changed +41
-0
lines changed
1 file changed +41
-0
lines changed Original file line number Diff line number Diff line change
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import unittest
3
+
4
+ from pythia .common .registry import registry
5
+ from pythia .tasks .base_dataset import BaseDataset
6
+ from pythia .utils .configuration import Configuration
7
+
8
+
9
+ class TestBaseDataset (unittest .TestCase ):
10
+ def test_init_processors (self ):
11
+ configuration = Configuration (
12
+ "../../common/defaults/configs/tasks/vqa/vqa2.yml"
13
+ )
14
+ configuration .freeze ()
15
+
16
+ base_dataset = BaseDataset (
17
+ "vqa" ,
18
+ "vqa2" ,
19
+ configuration .get_config ()["task_attributes" ]["vqa" ]["dataset_attributes" ][
20
+ "vqa2"
21
+ ],
22
+ )
23
+ expected_processors = [
24
+ "text_processor" ,
25
+ "answer_processor" ,
26
+ "context_processor" ,
27
+ "ocr_token_processor" ,
28
+ "bbox_processor" ,
29
+ ]
30
+
31
+ # Check no processors are initialized before init_processors call
32
+ self .assertFalse (any (hasattr (base_dataset , key ) for key in expected_processors ))
33
+
34
+ for processor in expected_processors :
35
+ self .assertIsNone (registry .get ("{}_{}" .format ("vqa" , processor )))
36
+
37
+ # Check processors are initialized after init_processors
38
+ base_dataset .init_processors ()
39
+ self .assertTrue (all (hasattr (base_dataset , key ) for key in expected_processors ))
40
+ for processor in expected_processors :
41
+ self .assertIsNotNone (registry .get ("{}_{}" .format ("vqa" , processor )))
You can’t perform that action at this time.
0 commit comments