Skip to content

Commit 7c63cf7

Browse files
authored
Add generator option for relative imports
2 parents 1592ec8 + 9deab38 commit 7c63cf7

19 files changed

+180
-197
lines changed

tests/codegen/models/test_class.py

+5
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ def test_is_simple_type(self):
137137
obj.extensions.append(ExtensionFactory.create())
138138
self.assertFalse(obj.is_simple_type)
139139

140+
def test_property_is_group(self):
141+
self.assertTrue(ClassFactory.create(tag=Tag.GROUP).is_group)
142+
self.assertTrue(ClassFactory.create(tag=Tag.ATTRIBUTE_GROUP).is_group)
143+
self.assertFalse(ClassFactory.create(tag=Tag.ELEMENT).is_group)
144+
140145
def test_property_should_generate(self):
141146
obj = ClassFactory.create(tag=Tag.ELEMENT)
142147
self.assertTrue(obj.should_generate)

tests/formats/dataclass/test_filters.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import namedtuple
2+
13
from tests.fixtures.datatypes import Telephone
24
from xsdata.codegen.models import Restrictions
35
from xsdata.formats.dataclass.filters import Filters
@@ -27,7 +29,7 @@ class FiltersTests(FactoryTestCase):
2729
def setUp(self) -> None:
2830
super().setUp()
2931
config = GeneratorConfig()
30-
self.filters = Filters.from_config(config)
32+
self.filters = Filters(config)
3133

3234
def test_class_name(self):
3335
self.filters.class_aliases["boom"] = "Bang"
@@ -595,7 +597,27 @@ def test_format_metadata(self):
595597
self.assertEqual(expected, self.filters.format_metadata(data))
596598
self.assertEqual('""', self.filters.format_metadata(""))
597599

598-
def test_from_config(self):
600+
def test_import_module(self):
601+
case = namedtuple("Case", ["module", "from_module", "result"])
602+
cases = [
603+
case("foo.bar", "foo", ".bar"),
604+
case("bar.foo", "foo", "bar.foo"),
605+
case("a.b.e.f", "a.b.c.d", "..e.f"),
606+
case("a.b.c.f", "a.b.c.d", ".f"),
607+
case("a.b.c.f.e", "a.b", ".c.f.e"),
608+
case("a.b.c.f", "", "a.b.c.f"),
609+
]
610+
611+
transform = self.filters.import_module
612+
self.filters.relative_imports = False
613+
for case in cases:
614+
self.assertEqual(case.module, transform(case.module, case.from_module))
615+
616+
self.filters.relative_imports = True
617+
for case in cases:
618+
self.assertEqual(case.result, transform(case.module, case.from_module))
619+
620+
def test__init(self):
599621
config = GeneratorConfig()
600622
config.conventions.package_name.safe_prefix = "safe_package"
601623
config.conventions.package_name.case = NameCase.MIXED
@@ -611,7 +633,7 @@ def test_from_config(self):
611633
config.aliases.package_name.append(GeneratorAlias("g", "h"))
612634
config.aliases.module_name.append(GeneratorAlias("i", "j"))
613635

614-
filters = Filters.from_config(config)
636+
filters = Filters(config)
615637

616638
self.assertEqual("safe_class", filters.class_safe_prefix)
617639
self.assertEqual("safe_field", filters.field_safe_prefix)

tests/formats/dataclass/test_generator.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ def test_render(self, mock_render_module, mock_render_package):
4444
(cwd.joinpath("thug/life/tests.py"), "thug.life.tests", "module"),
4545
]
4646
self.assertEqual(expected, actual)
47-
mock_render_package.assert_has_calls([mock.call([x]) for x in classes])
47+
mock_render_package.assert_has_calls(
48+
[
49+
mock.call([classes[0]], "foo.bar"),
50+
mock.call([classes[1]], "bar.foo"),
51+
mock.call([classes[2]], "thug.life"),
52+
]
53+
)
4854
mock_render_module.assert_has_calls([mock.call(mock.ANY, [x]) for x in classes])
4955

5056
def test_render_package(self):
@@ -57,7 +63,7 @@ def test_render_package(self):
5763

5864
random.shuffle(classes)
5965

60-
actual = self.generator.render_package(classes)
66+
actual = self.generator.render_package(classes, "foo.tests")
6167
expected = "\n".join(
6268
[
6369
"from foo.bar import A as BarA",

tests/integration/test_books.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ def test_books_schema():
2020
str(schema),
2121
"--package",
2222
package,
23-
"--ns-struct",
24-
"--docstring-style",
25-
"Google",
23+
"--structure-style=namespaces",
24+
"--docstring-style=Google",
2625
],
2726
catch_exceptions=False,
2827
)

tests/models/test_config.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,16 @@
11
import tempfile
2-
import warnings
32
from pathlib import Path
43
from unittest import TestCase
54

65
from xsdata import __version__
76
from xsdata.exceptions import ParserError
87
from xsdata.models.config import GeneratorConfig
9-
from xsdata.models.config import GeneratorOutput
108

119

1210
class GeneratorConfigTests(TestCase):
1311
def setUp(self) -> None:
1412
self.maxDiff = None
1513

16-
def test_deprecation_warning(self):
17-
with warnings.catch_warnings(record=True) as w:
18-
output = GeneratorOutput(format="pydata")
19-
20-
self.assertEqual(
21-
"Output format 'pydata' renamed to 'dataclasses'", str(w[-1].message)
22-
)
23-
self.assertEqual("dataclasses", output.format)
24-
2514
def test_create(self):
2615
file_path = Path(tempfile.mktemp())
2716
obj = GeneratorConfig.create()
@@ -33,7 +22,7 @@ def test_create(self):
3322
f'<Config xmlns="http://pypi.org/project/xsdata" version="{__version__}">\n'
3423
' <Output maxLineLength="79">\n'
3524
" <Package>generated</Package>\n"
36-
" <Format>dataclasses</Format>\n"
25+
' <Format relativeImports="false">dataclasses</Format>\n'
3726
" <Structure>filenames</Structure>\n"
3827
" <DocstringStyle>reStructuredText</DocstringStyle>\n"
3928
" <CompoundFields>false</CompoundFields>\n"
@@ -81,7 +70,7 @@ def test_read(self):
8170
f'<Config xmlns="http://pypi.org/project/xsdata" version="{__version__}">\n'
8271
' <Output maxLineLength="79">\n'
8372
" <Package>foo.bar</Package>\n"
84-
" <Format>dataclasses</Format>\n"
73+
' <Format relativeImports="false">dataclasses</Format>\n'
8574
" <Structure>filenames</Structure>\n"
8675
" <DocstringStyle>reStructuredText</DocstringStyle>\n"
8776
" <CompoundFields>false</CompoundFields>\n"

tests/models/test_mixins.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def test_property_xs_prefix(self):
173173
element = ElementBase()
174174
self.assertIsNone(element.xs_prefix)
175175

176-
element.ns_map["foo"] = Namespace.XS.uri
176+
element.ns_map = {"a": "a", "foo": Namespace.XS.uri}
177177
self.assertEqual("foo", element.xs_prefix)
178178

179179
def test_children(self):

tests/test_cli.py

+9-24
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def test_generate_with_default_output(self, mock_init, mock_process):
3232
self.assertIsNone(result.exception)
3333
self.assertFalse(mock_init.call_args[1]["print"])
3434
self.assertEqual("foo", config.output.package)
35-
self.assertEqual("dataclasses", config.output.format)
35+
self.assertEqual("dataclasses", config.output.format.value)
36+
self.assertFalse(config.output.format.relative_imports)
3637
self.assertEqual(StructureStyle.FILENAMES, config.output.structure)
3738
self.assertEqual([source.as_uri()], mock_process.call_args[0][0])
3839

@@ -47,22 +48,6 @@ def test_generate_with_print_mode(self, mock_init, mock_process):
4748
self.assertEqual(logging.ERROR, logger.getEffectiveLevel())
4849
self.assertTrue(mock_init.call_args[1]["print"])
4950

50-
@mock.patch.object(SchemaTransformer, "process")
51-
@mock.patch.object(SchemaTransformer, "__init__", return_value=None)
52-
def test_generate_with_ns_struct_mode(self, mock_init, mock_process):
53-
source = fixtures_dir.joinpath("defxmlschema/chapter03.xsd")
54-
result = self.runner.invoke(
55-
cli, [str(source), "--package", "foo", "--ns-struct"]
56-
)
57-
config = mock_init.call_args[1]["config"]
58-
59-
self.assertIsNone(result.exception)
60-
self.assertEqual([source.as_uri()], mock_process.call_args[0][0])
61-
self.assertFalse(mock_init.call_args[1]["print"])
62-
self.assertEqual("foo", config.output.package)
63-
self.assertEqual("dataclasses", config.output.format)
64-
self.assertEqual(StructureStyle.NAMESPACES, config.output.structure)
65-
6651
@mock.patch.object(SchemaTransformer, "process")
6752
@mock.patch.object(SchemaTransformer, "__init__", return_value=None)
6853
def test_generate_with_structure_style_mode(self, mock_init, mock_process):
@@ -77,7 +62,7 @@ def test_generate_with_structure_style_mode(self, mock_init, mock_process):
7762
self.assertEqual([source.as_uri()], mock_process.call_args[0][0])
7863
self.assertFalse(mock_init.call_args[1]["print"])
7964
self.assertEqual("foo", config.output.package)
80-
self.assertEqual("dataclasses", config.output.format)
65+
self.assertEqual("dataclasses", config.output.format.value)
8166
self.assertEqual(StructureStyle.SINGLE_PACKAGE, config.output.structure)
8267

8368
@mock.patch.object(SchemaTransformer, "process")
@@ -112,7 +97,7 @@ def test_generate_with_configuration_file(self, mock_init, mock_process):
11297
self.assertIsNone(result.exception)
11398
self.assertFalse(mock_init.call_args[1]["print"])
11499
self.assertEqual("foo.bar", config.output.package)
115-
self.assertEqual("dataclasses", config.output.format)
100+
self.assertEqual("dataclasses", config.output.format.value)
116101
self.assertEqual(StructureStyle.NAMESPACES, config.output.structure)
117102
self.assertEqual([source.as_uri()], mock_process.call_args[0][0])
118103
file_path.unlink()
@@ -132,17 +117,17 @@ def test_generate_with_configuration_file_and_overriding_args(self, mock_init, _
132117
cli,
133118
[
134119
str(source),
135-
"--config",
136-
str(file_path),
137-
"--package",
138-
"foo",
139-
"--ns-struct",
120+
f"--config={file_path}",
121+
"--package=foo",
122+
"--structure-style=namespaces",
123+
"--relative-imports",
140124
],
141125
)
142126
config = mock_init.call_args[1]["config"]
143127

144128
self.assertIsNone(result.exception)
145129
self.assertEqual("foo", config.output.package)
130+
self.assertTrue(config.output.format.relative_imports)
146131
self.assertEqual(StructureStyle.NAMESPACES, config.output.structure)
147132
file_path.unlink()
148133

tests/utils/test_collections.py

-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@ def test_find(self):
1111
self.assertEqual(-1, collections.find([0, 1], 2))
1212
self.assertEqual(1, collections.find([0, 1], 1))
1313

14-
def test_map_key(self):
15-
dictionary = {"a": "b"}
16-
17-
self.assertIsNone(collections.map_key(dictionary, "x"))
18-
self.assertEqual("a", collections.map_key(dictionary, "b"))
19-
2014
def test_prepend(self):
2115
target = [1, 2, 3]
2216
prepend_values = [4, 5, 6]

xsdata/cli.py

+15-23
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from xsdata.logger import logger
1515
from xsdata.models.config import DocstringStyle
1616
from xsdata.models.config import GeneratorConfig
17+
from xsdata.models.config import OutputFormat
1718
from xsdata.models.config import StructureStyle
1819
from xsdata.utils.downloader import Downloader
1920
from xsdata.utils.hooks import load_entry_points
@@ -87,11 +88,6 @@ def download(source: str, output: str):
8788
help=(
8889
"Specify the target package to be created inside the current working directory "
8990
"Default: generated"
90-
"\n\n"
91-
"The generated module structure relies on the common input source path"
92-
"\n\n"
93-
"Use the --ns-struct option for a more flat structure and to avoid circular "
94-
"import errors."
9591
),
9692
default="generated",
9793
)
@@ -115,17 +111,6 @@ def download(source: str, output: str):
115111
),
116112
default="reStructuredText",
117113
)
118-
@click.option(
119-
"-ns",
120-
"--ns-struct",
121-
is_flag=True,
122-
default=False,
123-
help=(
124-
"Use namespaces to group classes in modules. "
125-
"Useful against circular import errors. "
126-
"Deprecated use '--structure-style namespaces'"
127-
),
128-
)
129114
@click.option(
130115
"-ss",
131116
"--structure-style",
@@ -154,6 +139,13 @@ def download(source: str, output: str):
154139
"ordering between data binding operations."
155140
),
156141
)
142+
@click.option(
143+
"-ri",
144+
"--relative-imports",
145+
is_flag=True,
146+
default=False,
147+
help="Enable relative imports",
148+
)
157149
@click.option(
158150
"-pp",
159151
"--print",
@@ -179,19 +171,19 @@ def generate(**kwargs: Any):
179171
config.output.package = kwargs["package"]
180172
else:
181173
config = GeneratorConfig()
182-
config.output.format = kwargs["output"]
174+
config.output.format = OutputFormat(
175+
value=kwargs["output"], relative_imports=kwargs["relative_imports"]
176+
)
183177
config.output.package = kwargs["package"]
184178
config.output.compound_fields = kwargs["compound_fields"]
185179
config.output.docstring_style = DocstringStyle(kwargs["docstring_style"])
186180

187-
if kwargs["ns_struct"]:
188-
config.output.structure = StructureStyle.NAMESPACES
189-
logger.warning(
190-
"--ns-struct is deprecated switch to '--structure-style namespaces'"
191-
)
192-
elif kwargs["structure_style"] != StructureStyle.FILENAMES.value:
181+
if kwargs["structure_style"] != StructureStyle.FILENAMES.value:
193182
config.output.structure = StructureStyle(kwargs["structure_style"])
194183

184+
if kwargs["relative_imports"]:
185+
config.output.format.relative_imports = True
186+
195187
uris = resolve_source(kwargs["source"])
196188
transformer = SchemaTransformer(config=config, print=kwargs["print"])
197189
transformer.process(list(uris))

xsdata/codegen/handlers/attribute_compound_choice.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from operator import attrgetter
12
from typing import List
23

34
from xsdata.codegen.mixins import HandlerInterface
@@ -17,7 +18,7 @@ class AttributeCompoundChoiceHandler(HandlerInterface):
1718
__slots__ = ()
1819

1920
def process(self, target: Class):
20-
groups = group_by(target.attrs, lambda x: x.restrictions.choice)
21+
groups = group_by(target.attrs, attrgetter("restrictions.choice"))
2122
for choice, attrs in groups.items():
2223
if choice and len(attrs) > 1 and any(attr.is_list for attr in attrs):
2324
self.group_fields(target, attrs)

xsdata/codegen/handlers/attribute_group.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from operator import attrgetter
2+
13
from xsdata.codegen.mixins import RelativeHandlerInterface
24
from xsdata.codegen.models import Attr
35
from xsdata.codegen.models import Class
46
from xsdata.codegen.utils import ClassUtils
57
from xsdata.exceptions import AnalyzerValueError
6-
from xsdata.models.enums import Tag
78

89

910
class AttributeGroupHandler(RelativeHandlerInterface):
@@ -36,9 +37,7 @@ def process_attribute(self, target: Class, attr: Attr):
3637
:raises AnalyzerValueError: if source class is not found.
3738
"""
3839
qname = attr.types[0].qname # group attributes have one type only.
39-
source = self.container.find(
40-
qname, condition=lambda x: x.tag in (Tag.ATTRIBUTE_GROUP, Tag.GROUP)
41-
)
40+
source = self.container.find(qname, condition=attrgetter("is_group"))
4241

4342
if not source:
4443
raise AnalyzerValueError(f"Group attribute not found: `{qname}`")

xsdata/codegen/models.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ class AttrType:
209209
forward: bool = field(default=False)
210210
circular: bool = field(default=False)
211211

212+
@property
213+
def datatype(self) -> Optional[DataType]:
214+
return DataType.from_qname(self.qname) if self.native else None
215+
212216
@property
213217
def name(self) -> str:
214218
"""Shortcut for qname local name."""
@@ -222,10 +226,6 @@ def is_dependency(self, allow_circular: bool) -> bool:
222226
self.forward or self.native or (not allow_circular and self.circular)
223227
)
224228

225-
@property
226-
def datatype(self) -> Optional[DataType]:
227-
return DataType.from_qname(self.qname) if self.native else None
228-
229229
def clone(self) -> "AttrType":
230230
"""Return a deep cloned instance."""
231231
return replace(self)
@@ -465,6 +465,12 @@ def is_element(self) -> bool:
465465
xs:element."""
466466
return self.tag == Tag.ELEMENT
467467

468+
@property
469+
def is_group(self) -> bool:
470+
"""Return whether this attribute is derived from an xs:group or
471+
xs:attributeGroup."""
472+
return self.tag in (Tag.ATTRIBUTE_GROUP, Tag.GROUP)
473+
468474
@property
469475
def is_enumeration(self) -> bool:
470476
"""Return whether all attributes are derived from xs:enumeration."""

0 commit comments

Comments
 (0)