diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 41d020183..b11355bad 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -3,9 +3,12 @@ from typing import Dict from typing import List from typing import Optional +from typing import Type from typing import Union from xml.etree.ElementTree import QName +from xsdata.utils.constants import return_true + @dataclass class TypeA: @@ -79,7 +82,13 @@ class ChoiceType: {"name": "int2", "type": int, "nillable": True}, {"name": "float", "type": float}, {"name": "qname", "type": QName}, - {"name": "tokens", "type": List[int], "tokens": True}, + {"name": "tokens", "type": List[int], "tokens": True, "default_factory": return_true}, + {"name": "union", "type": Type["UnionType"], "namespace": "foo"}, + {"name": "p", "type": float, "fixed": True, "default": 1.1}, + {"wildcard": True, + "type": object, + "namespace": "http://www.w3.org/1999/xhtml", + }, ), } ) diff --git a/tests/fixtures/submodels.py b/tests/fixtures/submodels.py new file mode 100644 index 000000000..568af28fb --- /dev/null +++ b/tests/fixtures/submodels.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import List +from typing import Optional +from typing import Union +from xml.etree.ElementTree import QName + +from tests.fixtures.models import ChoiceType + + +@dataclass +class ChoiceTypeChild(ChoiceType): + pass \ No newline at end of file diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index 6af09f958..58885c189 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -8,15 +8,19 @@ from typing import get_type_hints from typing import Iterator from typing import List -from typing import Type from typing import Union from unittest import mock from unittest import TestCase +from xml.etree.ElementTree import QName from tests.fixtures.artists import Artist from tests.fixtures.books import BookForm +from tests.fixtures.models import ChoiceType +from tests.fixtures.models import TypeA from tests.fixtures.models import TypeB +from tests.fixtures.models import UnionType from tests.fixtures.series import Country +from tests.fixtures.submodels import ChoiceTypeChild from xsdata.exceptions import XmlContextError from xsdata.formats.dataclass.compat import class_types from xsdata.formats.dataclass.models.builders import XmlMetaBuilder @@ -103,6 +107,12 @@ def test_build_with_no_dataclass_raises_exception(self, *args): self.assertEqual(f"Type '{int}' is not a dataclass.", str(cm.exception)) + def test_build_locates_globalns_per_field(self): + actual = self.builder.build(ChoiceTypeChild, None) + self.assertEqual(1, len(actual.choices)) + self.assertEqual(9, len(actual.choices[0].elements)) + self.assertIsNone(self.builder.find_globalns(object, "foo")) + def test_target_namespace(self): class Meta: namespace = "bar" @@ -234,17 +244,18 @@ def setUp(self) -> None: ) super().setUp() + self.maxDiff = None def test_build_with_choice_field(self): - globalns = sys.modules[CompoundFieldExample.__module__].__dict__ - type_hints = get_type_hints(CompoundFieldExample) - class_field = fields(CompoundFieldExample)[0] + globalns = sys.modules[ChoiceType.__module__].__dict__ + type_hints = get_type_hints(ChoiceType) + class_field = fields(ChoiceType)[0] self.builder.parent_ns = "bar" actual = self.builder.build( 66, - "compound", - type_hints["compound"], + "choice", + type_hints["choice"], class_field.metadata, True, list, @@ -252,96 +263,110 @@ def test_build_with_choice_field(self): ) expected = XmlVarFactory.create( index=67, - xml_type=XmlType.ELEMENTS, - name="compound", - qname="compound", + name="choice", + types=(object,), list_element=True, any_type=True, default=list, + xml_type=XmlType.ELEMENTS, elements={ - "{foo}node": XmlVarFactory.create( + "{bar}a": XmlVarFactory.create( index=1, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{foo}node", + name="choice", + qname="{bar}a", + types=(TypeA,), + clazz=TypeA, list_element=True, - types=(CompoundFieldExample,), - namespaces=("foo",), - derived=False, + namespaces=("bar",), ), - "{bar}x": XmlVarFactory.create( + "{bar}b": XmlVarFactory.create( index=2, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}x", - tokens=True, + name="choice", + qname="{bar}b", + types=(TypeB,), + clazz=TypeB, list_element=True, - types=(str,), namespaces=("bar",), - derived=False, - default=return_true, - format="Nope", ), - "{bar}y": XmlVarFactory.create( + "{bar}int": XmlVarFactory.create( index=3, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}y", - nillable=True, - list_element=True, + name="choice", + qname="{bar}int", types=(int,), + list_element=True, namespaces=("bar",), - derived=False, ), - "{bar}z": XmlVarFactory.create( + "{bar}int2": XmlVarFactory.create( index=4, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}z", - nillable=False, - list_element=True, + name="choice", + qname="{bar}int2", types=(int,), - namespaces=("bar",), derived=True, + nillable=True, + list_element=True, + namespaces=("bar",), ), - "{bar}o": XmlVarFactory.create( + "{bar}float": XmlVarFactory.create( index=5, - xml_type=XmlType.ELEMENT, - name="compound", - qname="{bar}o", - nillable=False, + name="choice", + qname="{bar}float", + types=(float,), list_element=True, - types=(object,), namespaces=("bar",), + ), + "{bar}qname": XmlVarFactory.create( + index=6, + name="choice", + qname="{bar}qname", + types=(QName,), + list_element=True, + namespaces=("bar",), + ), + "{bar}tokens": XmlVarFactory.create( + index=7, + name="choice", + qname="{bar}tokens", + types=(int,), + tokens=True, derived=True, - any_type=True, + list_element=True, + default=return_true, + namespaces=("bar",), + ), + "{foo}union": XmlVarFactory.create( + index=8, + name="choice", + qname="{foo}union", + types=(UnionType,), + clazz=UnionType, + list_element=True, + namespaces=("foo",), ), "{bar}p": XmlVarFactory.create( - index=6, - xml_type=XmlType.ELEMENT, - name="compound", + index=9, + name="choice", qname="{bar}p", types=(float,), + derived=True, list_element=True, - namespaces=("bar",), default=1.1, + namespaces=("bar",), ), }, wildcards=[ XmlVarFactory.create( - index=7, + index=10, + name="choice", xml_type=XmlType.WILDCARD, - name="compound", qname="{http://www.w3.org/1999/xhtml}any", types=(object,), - namespaces=("http://www.w3.org/1999/xhtml",), - derived=True, - any_type=False, list_element=True, - ) + default=None, + namespaces=("http://www.w3.org/1999/xhtml",), + ), ], - types=(object,), ) + self.assertEqual(expected, actual) def test_build_validates_result(self): @@ -455,37 +480,3 @@ def test_is_valid(self): XmlType.TEXT, None, None, (int, uuid.UUID), False, False ) ) - - -@dataclass -class CompoundFieldExample: - - compound: List[object] = field( - default_factory=list, - metadata={ - "type": "Elements", - "choices": ( - { - "name": "node", - "type": Type["CompoundFieldExample"], - "namespace": "foo", - }, - { - "name": "x", - "type": List[str], - "tokens": True, - "default_factory": return_true, - "format": "Nope", - }, - {"name": "y", "type": List[int], "nillable": True}, - {"name": "z", "type": List[int]}, - {"name": "o", "type": object}, - {"name": "p", "type": float, "fixed": True, "default": 1.1}, - { - "wildcard": True, - "type": object, - "namespace": "http://www.w3.org/1999/xhtml", - }, - ), - }, - ) diff --git a/tests/formats/dataclass/test_context.py b/tests/formats/dataclass/test_context.py index 460ef0805..5b8ff04f2 100644 --- a/tests/formats/dataclass/test_context.py +++ b/tests/formats/dataclass/test_context.py @@ -6,8 +6,10 @@ from tests.fixtures.artists import BeginArea from tests.fixtures.books import BookForm from tests.fixtures.books import BooksForm +from tests.fixtures.models import BaseType from tests.fixtures.models import ChoiceType from tests.fixtures.models import TypeA +from tests.fixtures.models import TypeC from tests.fixtures.models import UnionType from xsdata.formats.dataclass.context import XmlContext from xsdata.models.enums import DataType @@ -107,10 +109,10 @@ def test_is_derived(self): def test_build_recursive(self): self.ctx.build_recursive(ChoiceType) - self.assertEqual(3, len(self.ctx.cache)) + self.assertEqual(6, len(self.ctx.cache)) - self.ctx.build_recursive(TypeA) - self.assertEqual(3, len(self.ctx.cache)) + self.ctx.build_recursive(BaseType) + self.assertEqual(8, len(self.ctx.cache)) self.ctx.build_recursive(UnionType) - self.assertEqual(6, len(self.ctx.cache)) + self.assertEqual(8, len(self.ctx.cache)) diff --git a/tests/formats/dataclass/test_elements.py b/tests/formats/dataclass/test_elements.py index 040d50c31..063cf22bf 100644 --- a/tests/formats/dataclass/test_elements.py +++ b/tests/formats/dataclass/test_elements.py @@ -43,7 +43,9 @@ def test_property_is_clazz_union(self): def test_property_element_types(self): meta = self.context.build(ChoiceType) var = meta.choices[0] - self.assertEqual({TypeA, TypeB, int, float, QName}, var.element_types) + self.assertEqual( + {TypeA, TypeB, int, float, QName, UnionType}, var.element_types + ) def test_find_choice(self): var = XmlVarFactory.create( diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index 1b44b476c..4c14d41af 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -114,7 +114,6 @@ def build_vars( ): """Build the binding metadata for the given dataclass fields.""" type_hints = get_type_hints(clazz) - globalns = sys.modules[clazz.__module__].__dict__ builder = XmlVarBuilder( class_type=self.class_type, parent_ns=parent_ns, @@ -123,19 +122,28 @@ def build_vars( attribute_name_generator=attribute_name_generator, ) - for index, _field in enumerate(self.class_type.get_fields(clazz)): + for index, field in enumerate(self.class_type.get_fields(clazz)): var = builder.build( index, - _field.name, - type_hints[_field.name], - _field.metadata, - _field.init, - self.class_type.default_value(_field), - globalns, + field.name, + type_hints[field.name], + field.metadata, + field.init, + self.class_type.default_value(field), + self.find_globalns(clazz, field.name), ) if var is not None: yield var + @classmethod + def find_globalns(cls, clazz: Type, name: str) -> Optional[Dict]: + for base in clazz.__mro__: + ann = base.__dict__.get("__annotations__") + if ann and name in ann: + return sys.modules[base.__module__].__dict__ + + return None + @classmethod def build_target_qname(cls, clazz: Type, element_name_generator: Callable) -> str: """Build the source qualified name of a model based on the module