From a153b39231642426a8376e3df93121041bcb03ed Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Sat, 19 Sep 2020 13:36:02 +0900 Subject: [PATCH] Fix expander --- atcoder/__main__.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/atcoder/__main__.py b/atcoder/__main__.py index da301d2..9ce7b14 100644 --- a/atcoder/__main__.py +++ b/atcoder/__main__.py @@ -68,6 +68,7 @@ def import_module(self, import_from: Optional[str], name: str, imports = iter_child_nodes(ast.parse(source)) import_lines = [] + import_list = [] for import_info in imports: result += self.import_module( import_info.import_from, import_info.name, @@ -76,6 +77,11 @@ def import_module(self, import_from: Optional[str], name: str, import_info.end_lineno): import_lines.append(line) + if import_info.import_from is None: + import_list.append(import_info.name) + else: + import_list.append(import_info.import_from) + for lineno, line in enumerate(lines): if lineno not in import_lines: continue @@ -92,6 +98,17 @@ def import_module(self, import_from: Optional[str], name: str, result += f"{module_name} = types.ModuleType('{module_name}')\n" result += f'exec({code}, {module_name}.__dict__)\n' + imported = [] + for import_ in import_list: + modules = import_.split('.') + for i in range(len(modules)): + import_name = '.'.join(modules[:i + 1]) + if import_name in imported: + continue + imported.append(import_name) + result += f"{module_name}.__dict__['{import_name}']" \ + f" = {import_name}\n" + if import_from is None: if asname is None: if name != module_name: