Skip to content

Commit f969577

Browse files
vedanujapsdehal
authored andcommitted
Add tests for init processor in base_dataset class (#48)
1 parent 358d034 commit f969577

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

pythia/tests/tasks/base_dataset.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import unittest
3+
4+
from pythia.common.registry import registry
5+
from pythia.tasks.base_dataset import BaseDataset
6+
from pythia.utils.configuration import Configuration
7+
8+
9+
class TestBaseDataset(unittest.TestCase):
10+
def test_init_processors(self):
11+
configuration = Configuration(
12+
"../../common/defaults/configs/tasks/vqa/vqa2.yml"
13+
)
14+
configuration.freeze()
15+
16+
base_dataset = BaseDataset(
17+
"vqa",
18+
"vqa2",
19+
configuration.get_config()["task_attributes"]["vqa"]["dataset_attributes"][
20+
"vqa2"
21+
],
22+
)
23+
expected_processors = [
24+
"text_processor",
25+
"answer_processor",
26+
"context_processor",
27+
"ocr_token_processor",
28+
"bbox_processor",
29+
]
30+
31+
# Check no processors are initialized before init_processors call
32+
self.assertFalse(any(hasattr(base_dataset, key) for key in expected_processors))
33+
34+
for processor in expected_processors:
35+
self.assertIsNone(registry.get("{}_{}".format("vqa", processor)))
36+
37+
# Check processors are initialized after init_processors
38+
base_dataset.init_processors()
39+
self.assertTrue(all(hasattr(base_dataset, key) for key in expected_processors))
40+
for processor in expected_processors:
41+
self.assertIsNotNone(registry.get("{}_{}".format("vqa", processor)))

0 commit comments

Comments
 (0)