diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1326ffa1..5dfd4af19 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: end-of-file-fixer - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.1 + rev: v0.3.0 hooks: - id: ruff args: [ --fix, --show-fixes ] diff --git a/docs/codegen/architecture.md b/docs/codegen/architecture.md index 6fd644368..71db0ed43 100644 --- a/docs/codegen/architecture.md +++ b/docs/codegen/architecture.md @@ -139,6 +139,7 @@ pass through each step before next one starts. The order of the steps is very im - [VacuumInnerClasses][xsdata.codegen.handlers.VacuumInnerClasses] - [CreateCompoundFields][xsdata.codegen.handlers.CreateCompoundFields] +- [DisambiguateChoices][xsdata.codegen.handlers.DisambiguateChoices] - [ResetAttributeSequenceNumbers][xsdata.codegen.handlers.ResetAttributeSequenceNumbers] ### Step: Designate diff --git a/docs/models/fields.md b/docs/models/fields.md index b3b5e89ed..fd7271580 100644 --- a/docs/models/fields.md +++ b/docs/models/fields.md @@ -256,9 +256,12 @@ Elements type represents repeatable choice elements. It's more commonly referred !!! Warning - If a compound field includes ambiguous types, you need to use - `~xsdata.formats.dataclass.models.generics.DerivedElement` to wrap - your values, otherwise your object can be assigned to the wrong element. + A compound field can not contain ambigous types because it's impossible to infer the + element from the actual value. + + The xml contenxt will raise an error. The solution is to introduce intermediate + simple types or subclasses per element. This will resolve xml roundtrips but + it will not work for certain json roundtrips. #### Wildcard diff --git a/tests/codegen/handlers/test_disambiguate_choices.py b/tests/codegen/handlers/test_disambiguate_choices.py new file mode 100644 index 000000000..1a99e0afe --- /dev/null +++ b/tests/codegen/handlers/test_disambiguate_choices.py @@ -0,0 +1,251 @@ +from dataclasses import replace + +from xsdata.codegen.container import ClassContainer +from xsdata.codegen.handlers import DisambiguateChoices +from xsdata.codegen.models import Restrictions, Status +from xsdata.models.config import GeneratorConfig +from xsdata.models.enums import DataType, Tag +from xsdata.utils.testing import ( + AttrFactory, + AttrTypeFactory, + ClassFactory, + FactoryTestCase, +) + + +class DisambiguateChoicesTest(FactoryTestCase): + maxDiff = None + + def setUp(self): + super().setUp() + + self.container = ClassContainer(config=GeneratorConfig()) + self.handler = DisambiguateChoices(self.container) + + def test_process_with_duplicate_wildcards(self): + compound = AttrFactory.create(tag=Tag.CHOICE, types=[]) + target = ClassFactory.create() + target.attrs.append(compound) + compound.choices.append(AttrFactory.native(DataType.STRING)) + compound.choices.append(AttrFactory.any(namespace="foo")) + compound.choices.append( + AttrFactory.any( + namespace="bar", restrictions=Restrictions(min_occurs=1, max_occurs=1) + ) + ) + compound.choices.append( + AttrFactory.any( + namespace="bar", restrictions=Restrictions(max_occurs=3, min_occurs=0) + ) + ) + self.container.add(target) + self.handler.process(target) + + self.assertEqual(2, len(compound.choices)) + + wildcard = compound.choices[-1] + self.assertEqual("content", wildcard.name) + self.assertEqual([AttrTypeFactory.native(DataType.ANY_TYPE)], wildcard.types) + self.assertEqual("foo bar", wildcard.namespace) + self.assertEqual(1, wildcard.restrictions.min_occurs) + self.assertEqual(4, wildcard.restrictions.max_occurs) + + def test_process_with_duplicate_simple_types(self): + compound = AttrFactory.create(tag=Tag.CHOICE, types=[]) + target = ClassFactory.create() + target.attrs.append(compound) + compound.choices.append(AttrFactory.native(DataType.STRING, name="a")) + compound.choices.append( + AttrFactory.native(DataType.STRING, name="b", namespace="xs") + ) + self.container.add(target) + + self.handler.process(target) + self.assertEqual(2, len(compound.choices)) + + self.assertEqual("a", compound.choices[0].types[0].qname) + self.assertEqual("{xs}b", compound.choices[1].types[0].qname) + + self.assertEqual(2, len(target.inner)) + self.assertEqual("a", target.inner[0].qname) + self.assertEqual("{xs}b", target.inner[1].qname) + + self.assertEqual(["a", "{xs}b"], [x.qname for x in compound.types]) + + def test_process_with_duplicate_any_types(self): + compound = AttrFactory.create(tag=Tag.CHOICE, types=[]) + target = ClassFactory.create() + target.attrs.append(compound) + compound.choices.append(AttrFactory.native(DataType.ANY_TYPE, name="a")) + compound.choices.append( + AttrFactory.native(DataType.ANY_TYPE, name="b", namespace="xs") + ) + self.container.add(target) + + self.handler.process(target) + self.assertEqual(2, len(compound.choices)) + + self.assertEqual("a", compound.choices[0].types[0].qname) + self.assertEqual("{xs}b", compound.choices[1].types[0].qname) + + self.assertEqual(2, len(target.inner)) + self.assertEqual("a", target.inner[0].qname) + self.assertEqual("{xs}b", target.inner[1].qname) + + def test_process_with_duplicate_complex_types(self): + compound = AttrFactory.any() + target = ClassFactory.create() + target.attrs.append(compound) + compound.choices.append(AttrFactory.reference(name="a", qname="myint")) + compound.choices.append(AttrFactory.reference(name="b", qname="myint")) + self.container.add(target) + + self.handler.process(target) + self.assertEqual(2, len(compound.choices)) + + self.assertEqual("attr_C", compound.choices[0].types[0].qname) + self.assertEqual("attr_D", compound.choices[1].types[0].qname) + + self.assertEqual(2, len(target.inner)) + self.assertEqual("attr_C", target.inner[0].qname) + self.assertEqual("attr_D", target.inner[1].qname) + + for inner in target.inner: + self.assertEqual("myint", inner.extensions[0].type.qname) + self.assertEqual("myint", inner.extensions[0].type.qname) + + self.assertEqual(DataType.ANY_TYPE, compound.types[0].datatype) + + def test_disambiguate_choice_with_unnest_true(self): + target = ClassFactory.create() + attr = AttrFactory.reference(qname="a") + + config = GeneratorConfig() + config.output.unnest_classes = True + container = ClassContainer(config=config) + handler = DisambiguateChoices(container) + + container.add(target) + handler.disambiguate_choice(target, attr) + + self.assertIsNotNone(container.find(attr.qname)) + + def test_disambiguate_choice_with_circular_ref(self): + target = ClassFactory.create() + attr = AttrFactory.reference(qname="a") + attr.types[0].circular = True + + self.container.add(target) + self.handler.disambiguate_choice(target, attr) + + self.assertTrue(attr.types[0].circular) + self.assertIsNotNone(self.container.find(attr.qname)) + + def test_find_ambiguous_choices_ignore_wildcards(self): + """Wildcards are merged.""" + + attr = AttrFactory.create() + attr.choices.append(AttrFactory.any()) + attr.choices.append(AttrFactory.any()) + attr.choices.append( + AttrFactory.create( + name="this", types=[AttrTypeFactory.native(DataType.ANY_TYPE)] + ) + ) + + result = self.handler.find_ambiguous_choices(attr) + self.assertEqual(["this"], [x.name for x in result]) + + def test_is_simple_type(self): + attr = AttrFactory.native(DataType.STRING) + self.assertTrue(self.handler.is_simple_type(attr)) + + enumeration = ClassFactory.enumeration(2) + self.container.add(enumeration) + attr = AttrFactory.reference(qname=enumeration.qname) + self.assertTrue(self.handler.is_simple_type(attr)) + + complex = ClassFactory.create() + self.container.add(complex) + attr = AttrFactory.reference(qname=complex.qname) + self.assertFalse(self.handler.is_simple_type(attr)) + + def test_create_ref_class(self): + source = ClassFactory.create( + status=Status.RESOLVED, + location="here.xsd", + ns_map={"foo": "bar"}, + ) + attr = AttrFactory.create( + namespace="test", + restrictions=Restrictions(nillable=True), + ) + + result = self.handler.create_ref_class(source, attr, inner=True) + + self.assertTrue(result.local_type) + self.assertEqual("{test}attr_B", result.qname) + self.assertEqual(source.status, result.status) + self.assertEqual(Tag.ELEMENT, result.tag) + self.assertEqual(source.location, result.location) + self.assertEqual(source.ns_map, result.ns_map) + self.assertEqual(attr.restrictions.nillable, result.nillable) + + def test_create_ref_class_creates_unique_inner_names(self): + source = ClassFactory.create( + status=Status.RESOLVED, + location="here.xsd", + ns_map={"foo": "bar"}, + ) + attr = AttrFactory.create(name="a") + source.inner.append(ClassFactory.create(qname="{xs}a")) + result = self.handler.create_ref_class(source, attr, inner=True) + + self.assertEqual("a_1", result.name) + + def test_add_any_type_value(self): + target = ClassFactory.elements(2) + source = AttrFactory.any() + self.handler.add_any_type_value(target, source) + + last = target.attrs[-1] + self.assertEqual("content", last.name) + self.assertEqual(Tag.ANY, last.tag) + self.assertEqual(source.namespace, last.namespace) + self.assertEqual([AttrTypeFactory.native(DataType.ANY_TYPE)], last.types) + self.assertFalse(last.restrictions.is_optional) + self.assertFalse(last.restrictions.is_list) + + def test_add_simply_type_value(self): + target = ClassFactory.elements(2) + source = AttrFactory.native( + DataType.STRING, + restrictions=Restrictions( + max_length=2, nillable=True, path=[("s", 1, 1, 1)] + ), + ) + self.handler.add_simple_type_value(target, source) + + last = target.attrs[-1] + self.assertEqual("value", last.name) + self.assertEqual(Tag.EXTENSION, last.tag) + self.assertIsNone(last.namespace) + self.assertEqual(source.types, last.types) + self.assertFalse(last.restrictions.is_optional) + self.assertFalse(last.restrictions.is_list) + self.assertEqual([], last.restrictions.path) + self.assertFalse(last.restrictions.nillable) + + def test_add_extension(self): + target = ClassFactory.create() + source = AttrFactory.reference("{xs}type") + source.types[0].forward = True + source.types[0].circular = True + self.handler.add_extension(target, source) + + last = target.extensions[-1] + self.assertEqual(Tag.EXTENSION, last.tag) + + expected = replace(source.types[0], forward=False, circular=False) + self.assertEqual(expected, last.type) + self.assertEqual(Restrictions(), last.restrictions) diff --git a/tests/codegen/handlers/test_rename_duplicate_classes.py b/tests/codegen/handlers/test_rename_duplicate_classes.py index b194b82df..12552ec72 100644 --- a/tests/codegen/handlers/test_rename_duplicate_classes.py +++ b/tests/codegen/handlers/test_rename_duplicate_classes.py @@ -20,8 +20,9 @@ def setUp(self): self.container = ClassContainer(config=GeneratorConfig()) self.processor = RenameDuplicateClasses(container=self.container) + @mock.patch.object(RenameDuplicateClasses, "merge_classes") @mock.patch.object(RenameDuplicateClasses, "rename_classes") - def test_run(self, mock_rename_classes): + def test_run(self, mock_rename_classes, mock_merge_classes): classes = [ ClassFactory.create(qname="{foo}A"), ClassFactory.create(qname="{foo}a"), @@ -38,9 +39,13 @@ def test_run(self, mock_rename_classes): mock.call(classes[3:], False), ] ) + self.assertEqual(0, mock_merge_classes.call_count) + @mock.patch.object(RenameDuplicateClasses, "merge_classes") @mock.patch.object(RenameDuplicateClasses, "rename_classes") - def test_run_with_single_package_structure(self, mock_rename_classes): + def test_run_with_single_package_structure( + self, mock_rename_classes, mock_merge_classes + ): classes = [ ClassFactory.create(qname="{foo}a"), ClassFactory.create(qname="{bar}a"), @@ -50,9 +55,13 @@ def test_run_with_single_package_structure(self, mock_rename_classes): self.processor.run() mock_rename_classes.assert_called_once_with(classes, True) + self.assertEqual(0, mock_merge_classes.call_count) + @mock.patch.object(RenameDuplicateClasses, "merge_classes") @mock.patch.object(RenameDuplicateClasses, "rename_classes") - def test_run_with_single_location_source(self, mock_rename_classes): + def test_run_with_single_location_source( + self, mock_rename_classes, mock_merge_classes + ): classes = [ ClassFactory.create(qname="{foo}a"), ClassFactory.create(qname="{bar}a"), @@ -64,9 +73,11 @@ def test_run_with_single_location_source(self, mock_rename_classes): self.processor.run() mock_rename_classes.assert_called_once_with(classes, True) + self.assertEqual(0, mock_merge_classes.call_count) + @mock.patch.object(RenameDuplicateClasses, "merge_classes") @mock.patch.object(RenameDuplicateClasses, "rename_classes") - def test_run_with_clusters_structure(self, mock_rename_classes): + def test_run_with_clusters_structure(self, mock_rename_classes, mock_merge_classes): classes = [ ClassFactory.create(qname="{foo}a"), ClassFactory.create(qname="{bar}a"), @@ -77,6 +88,62 @@ def test_run_with_clusters_structure(self, mock_rename_classes): self.processor.run() mock_rename_classes.assert_called_once_with(classes, True) + self.assertEqual(0, mock_merge_classes.call_count) + + @mock.patch.object(RenameDuplicateClasses, "merge_classes") + @mock.patch.object(RenameDuplicateClasses, "rename_classes") + def test_run_with_same_classes(self, mock_rename_classes, mock_merge_classes): + first = ClassFactory.create() + second = first.clone() + third = ClassFactory.create() + + self.container.extend([first, second, third]) + self.processor.run() + + self.assertEqual(0, mock_rename_classes.call_count) + mock_merge_classes.assert_called_once_with([first, second]) + + @mock.patch.object(RenameDuplicateClasses, "update_class_references") + def test_merge_classes(self, mock_update_class_references): + first = ClassFactory.create() + second = first.clone() + third = first.clone() + fourth = ClassFactory.create() + fifth = ClassFactory.create() + + self.container.extend([first, second, third, fourth, fifth]) + self.processor.run() + + replacements = { + id(second): id(first), + id(third): id(first), + } + + mock_update_class_references.assert_has_calls( + [ + mock.call(first, replacements), + mock.call(fourth, replacements), + mock.call(fifth, replacements), + ] + ) + self.assertEqual([first, fourth, fifth], list(self.container)) + + def test_update_class_references(self): + replacements = {1: 2, 3: 4, 5: 6, 7: 8} + target = ClassFactory.create( + attrs=AttrFactory.list(3), + extensions=ExtensionFactory.list(2), + inner=[ClassFactory.elements(2), ClassFactory.create()], + ) + target.attrs[1].choices = AttrFactory.list(2) + + target.attrs[0].types[0].reference = 1 + target.attrs[1].choices[0].types[0].reference = 3 + target.extensions[1].type.reference = 5 + target.inner[0].attrs[0].types[0].reference = 7 + + self.processor.update_class_references(target, replacements) + self.assertEqual([6, 2, 4, 8], list(target.references)) @mock.patch.object(RenameDuplicateClasses, "rename_class") def test_rename_classes(self, mock_rename_class): diff --git a/tests/codegen/handlers/test_validate_attributes_overrides.py b/tests/codegen/handlers/test_validate_attributes_overrides.py index e7d795918..6dcd4aa17 100644 --- a/tests/codegen/handlers/test_validate_attributes_overrides.py +++ b/tests/codegen/handlers/test_validate_attributes_overrides.py @@ -115,7 +115,7 @@ def test_overrides(self): def test_validate_override(self): attr_a = AttrFactory.create() attr_b = attr_a.clone() - attr_b.parent = ClassFactory.create() + attr_b.parent = ClassFactory.create().qname target = ClassFactory.create() target.attrs.append(attr_a) diff --git a/tests/codegen/models/test_attr.py b/tests/codegen/models/test_attr.py index e570dca36..47788ee4f 100644 --- a/tests/codegen/models/test_attr.py +++ b/tests/codegen/models/test_attr.py @@ -35,6 +35,10 @@ def test_property_key(self): attr = AttrFactory.attribute(name="a", namespace="b") self.assertEqual("Attribute.b.a", attr.key) + def test_property_qname(self): + attr = AttrFactory.attribute(name="a", namespace="b") + self.assertEqual("{b}a", attr.qname) + def test_property_is_property(self): self.assertTrue(AttrFactory.attribute().is_attribute) self.assertTrue(AttrFactory.any_attribute().is_attribute) @@ -65,6 +69,18 @@ def test_property_is_forward_ref(self): attr.types.append(AttrTypeFactory.create("foo", forward=True)) self.assertTrue(attr.is_forward_ref) + def test_property_is_circular_ref(self): + attr = AttrFactory.create() + self.assertFalse(attr.is_circular_ref) + + attr.types.append(AttrTypeFactory.create("foo", forward=True)) + self.assertFalse(attr.is_circular_ref) + + self.assertFalse(attr.is_circular_ref) + + attr.types.append(AttrTypeFactory.create("foo", circular=True)) + self.assertTrue(attr.is_circular_ref) + def test_property_is_group(self): self.assertTrue(AttrFactory.group().is_group) self.assertTrue(AttrFactory.attribute_group().is_group) diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index b86ace5e5..37ac74ed5 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -56,6 +56,7 @@ def test_initialize(self): 50: [ "VacuumInnerClasses", "CreateCompoundFields", + "DisambiguateChoices", "ResetAttributeSequenceNumbers", ], } diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index e358c14cf..1c2e47972 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from dataclasses import field +from decimal import Decimal from typing import Dict, Any from typing import List from typing import Optional @@ -95,17 +96,10 @@ class ChoiceType: {"name": "a", "type": TypeA}, {"name": "b", "type": TypeB}, {"name": "int", "type": int}, - {"name": "int2", "type": int, "nillable": True}, {"name": "float", "type": float}, {"name": "qname", "type": QName}, - { - "name": "tokens", - "type": List[int], - "tokens": True, - "default_factory": return_true - }, {"name": "union", "type": Type["UnionType"], "namespace": "foo"}, - {"name": "p", "type": float, "fixed": True, "default": 1.1}, + {"name": "tokens", "type": List[Decimal], "tokens": True}, { "wildcard": True, "type": object, @@ -128,6 +122,18 @@ class OptionalChoiceType: ) +@dataclass +class AmbiguousChoiceType: + choice: int = field( + metadata={ + "type": "Elements", + "choices": ( + {"name": "a", "type": int}, + {"name": "b", "type": int}, + ), + } + ) + @dataclass class UnionType: diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index 76c0bf5ac..c8129ffe1 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -1,13 +1,23 @@ +import functools import sys import uuid from dataclasses import dataclass, field, fields, make_dataclass +from decimal import Decimal from typing import Dict, Iterator, List, Tuple, Union, get_type_hints from unittest import TestCase, mock from xml.etree.ElementTree import QName from tests.fixtures.artists import Artist from tests.fixtures.books import BookForm -from tests.fixtures.models import ChoiceType, Parent, TypeA, TypeB, TypeNS1, UnionType +from tests.fixtures.models import ( + AmbiguousChoiceType, + ChoiceType, + Parent, + TypeA, + TypeB, + TypeNS1, + UnionType, +) from tests.fixtures.series import Country from tests.fixtures.submodels import ChoiceTypeChild from xsdata.exceptions import XmlContextError @@ -16,7 +26,7 @@ from xsdata.formats.dataclass.models.elements import XmlMeta, XmlType from xsdata.models.datatype import XmlDate from xsdata.utils import text -from xsdata.utils.constants import return_input, return_true +from xsdata.utils.constants import return_input from xsdata.utils.namespaces import build_qname from xsdata.utils.testing import FactoryTestCase, XmlMetaFactory, XmlVarFactory @@ -132,7 +142,7 @@ def test_build_with_no_dataclass_raises_exception(self, *args): def test_build_locates_globalns_per_field(self): actual = self.builder.build(ChoiceTypeChild, None) self.assertEqual(1, len(actual.choices)) - self.assertEqual(9, len(actual.choices[0].elements)) + self.assertEqual(7, len(actual.choices[0].elements)) with self.assertRaises(XmlContextError): self.builder.find_declared_class(object, "foo") @@ -276,6 +286,8 @@ def test_default_xml_type(self): class XmlVarBuilderTests(TestCase): + maxDiff = None + def setUp(self) -> None: self.builder = XmlVarBuilder( class_type=class_types.get_type("dataclasses"), @@ -285,15 +297,14 @@ def setUp(self) -> None: ) super().setUp() - self.maxDiff = None def test_build_with_choice_field(self): globalns = sys.modules[ChoiceType.__module__].__dict__ type_hints = get_type_hints(ChoiceType) class_field = fields(ChoiceType)[0] - self.maxDiff = None actual = self.builder.build( + ChoiceType, "choice", type_hints["choice"], class_field.metadata, @@ -337,18 +348,8 @@ def test_build_with_choice_field(self): factory=list, namespaces=("bar",), ), - "{bar}int2": XmlVarFactory.create( - index=5, - name="choice", - qname="{bar}int2", - types=(int,), - derived=True, - nillable=True, - factory=list, - namespaces=("bar",), - ), "{bar}float": XmlVarFactory.create( - index=6, + index=5, name="choice", qname="{bar}float", types=(float,), @@ -356,26 +357,15 @@ def test_build_with_choice_field(self): namespaces=("bar",), ), "{bar}qname": XmlVarFactory.create( - index=7, + index=6, name="choice", qname="{bar}qname", types=(QName,), factory=list, namespaces=("bar",), ), - "{bar}tokens": XmlVarFactory.create( - index=8, - name="choice", - qname="{bar}tokens", - types=(int,), - tokens_factory=list, - derived=True, - factory=list, - default=return_true, - namespaces=("bar",), - ), "{foo}union": XmlVarFactory.create( - index=9, + index=7, name="choice", qname="{foo}union", types=(UnionType,), @@ -383,20 +373,20 @@ def test_build_with_choice_field(self): factory=list, namespaces=("foo",), ), - "{bar}p": XmlVarFactory.create( - index=10, + "{bar}tokens": XmlVarFactory.create( + index=8, name="choice", - qname="{bar}p", - types=(float,), + qname="{bar}tokens", + types=(Decimal,), + tokens_factory=list, derived=True, factory=list, - default=1.1, namespaces=("bar",), ), }, wildcards=[ XmlVarFactory.create( - index=11, + index=9, name="choice", xml_type=XmlType.WILDCARD, qname="{http://www.w3.org/1999/xhtml}any", @@ -408,17 +398,44 @@ def test_build_with_choice_field(self): ], ) - self.maxDiff = None self.assertEqual(expected, actual) + def test_build_with_ambiguous_choices(self): + type_hints = get_type_hints(AmbiguousChoiceType) + class_field = fields(AmbiguousChoiceType)[0] + + with self.assertRaises(XmlContextError) as cm: + self.builder.build( + AmbiguousChoiceType, + "choice", + type_hints["choice"], + class_field.metadata, + True, + None, + None, + {}, + ) + + self.assertEqual( + "Error on AmbiguousChoiceType::choice: Compound field contains ambiguous types", + str(cm.exception), + ) + def test_build_validates_result(self): with self.assertRaises(XmlContextError) as cm: self.builder.build( - "foo", List[int], {"type": "Attributes"}, True, None, None, None + BookForm, + "foo", + List[int], + {"type": "Attributes"}, + True, + None, + None, + None, ) self.assertEqual( - "Xml type 'Attributes' does not support typing: typing.List[int]", + "Error on BookForm::foo: Xml Attributes does not support typing `typing.List[int]`", str(cm.exception), ) @@ -465,20 +482,22 @@ def test_resolve_namespaces(self): self.assertEqual(("foo", "p"), tuple(sorted(actual))) def test_analyze_types(self): - actual = self.builder.analyze_types(List[List[Union[str, int]]], None) + func = functools.partial(self.builder.analyze_types, BookForm, "foo") + + actual = func(List[List[Union[str, int]]], None) self.assertEqual((list, list, (int, str)), actual) - actual = self.builder.analyze_types(Union[str, int], None) + actual = func(Union[str, int], None) self.assertEqual((None, None, (int, str)), actual) - actual = self.builder.analyze_types(Dict[str, int], None) + actual = func(Dict[str, int], None) self.assertEqual((dict, None, (int, str)), actual) with self.assertRaises(XmlContextError) as cm: - self.builder.analyze_types(List[List[List[int]]], None) + func(List[List[List[int]]], None) self.assertEqual( - "Unsupported typing: typing.List[typing.List[typing.List[int]]]", + "Error on BookForm::foo: Unsupported field typing `typing.List[typing.List[typing.List[int]]]`", str(cm.exception), ) diff --git a/tests/formats/dataclass/parsers/nodes/test_element.py b/tests/formats/dataclass/parsers/nodes/test_element.py index 21f3f7840..6aaa0beae 100644 --- a/tests/formats/dataclass/parsers/nodes/test_element.py +++ b/tests/formats/dataclass/parsers/nodes/test_element.py @@ -371,7 +371,6 @@ def test_build_node_with_dataclass_var(self, mock_ctx_fetch, mock_xsi_type): name="a", qname="a", types=(TypeC,), - derived=True, ) xsi_type = "foo" namespace = self.meta.namespace @@ -384,7 +383,6 @@ def test_build_node_with_dataclass_var(self, mock_ctx_fetch, mock_xsi_type): self.assertIsInstance(actual, ElementNode) self.assertEqual(10, actual.position) - self.assertEqual(DerivedElement, actual.derived_factory) self.assertIs(mock_ctx_fetch.return_value, actual.meta) mock_xsi_type.assert_called_once_with(attrs, ns_map) diff --git a/tests/formats/dataclass/parsers/nodes/test_primitive.py b/tests/formats/dataclass/parsers/nodes/test_primitive.py index 05ff90444..a644789f4 100644 --- a/tests/formats/dataclass/parsers/nodes/test_primitive.py +++ b/tests/formats/dataclass/parsers/nodes/test_primitive.py @@ -2,7 +2,6 @@ from xsdata.exceptions import XmlContextError from xsdata.formats.dataclass.models.elements import XmlType -from xsdata.formats.dataclass.models.generics import DerivedElement from xsdata.formats.dataclass.parsers.nodes import PrimitiveNode from xsdata.formats.dataclass.parsers.utils import ParserUtils from xsdata.utils.testing import XmlVarFactory @@ -16,7 +15,7 @@ def test_bind(self, mock_parse_value): xml_type=XmlType.TEXT, name="foo", qname="foo", types=(int,), format="Nope" ) ns_map = {"foo": "bar"} - node = PrimitiveNode(var, ns_map, False, DerivedElement) + node = PrimitiveNode(var, ns_map, False) objects = [] self.assertTrue(node.bind("foo", "13", "Impossible", objects)) @@ -31,23 +30,12 @@ def test_bind(self, mock_parse_value): format=var.format, ) - def test_bind_derived_mode(self): - var = XmlVarFactory.create( - xml_type=XmlType.TEXT, name="foo", qname="foo", types=(int,), derived=True - ) - ns_map = {"foo": "bar"} - node = PrimitiveNode(var, ns_map, False, DerivedElement) - objects = [] - - self.assertTrue(node.bind("foo", "13", "Impossible", objects)) - self.assertEqual(DerivedElement("foo", 13), objects[-1][1]) - def test_bind_nillable_content(self): var = XmlVarFactory.create( xml_type=XmlType.TEXT, name="foo", qname="foo", types=(str,), nillable=False ) ns_map = {"foo": "bar"} - node = PrimitiveNode(var, ns_map, False, DerivedElement) + node = PrimitiveNode(var, ns_map, False) objects = [] self.assertTrue(node.bind("foo", None, None, objects)) @@ -66,7 +54,7 @@ def test_bind_nillable_bytes_content(self): nillable=False, ) ns_map = {"foo": "bar"} - node = PrimitiveNode(var, ns_map, False, DerivedElement) + node = PrimitiveNode(var, ns_map, False) objects = [] self.assertTrue(node.bind("foo", None, None, objects)) @@ -77,29 +65,25 @@ def test_bind_nillable_bytes_content(self): self.assertIsNone(objects[-1][1]) def test_bind_mixed_with_tail_content(self): - var = XmlVarFactory.create( - xml_type=XmlType.TEXT, name="foo", types=(int,), derived=True - ) - node = PrimitiveNode(var, {}, True, DerivedElement) + var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,)) + node = PrimitiveNode(var, {}, True) objects = [] self.assertTrue(node.bind("foo", "13", "tail", objects)) self.assertEqual((None, "tail"), objects[-1]) - self.assertEqual(DerivedElement("foo", 13), objects[-2][1]) + self.assertEqual(13, objects[-2][1]) def test_bind_mixed_without_tail_content(self): - var = XmlVarFactory.create( - xml_type=XmlType.TEXT, name="foo", types=(int,), derived=True - ) - node = PrimitiveNode(var, {}, True, DerivedElement) + var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,)) + node = PrimitiveNode(var, {}, True) objects = [] self.assertTrue(node.bind("foo", "13", "", objects)) - self.assertEqual(DerivedElement("foo", 13), objects[-1][1]) + self.assertEqual(13, objects[-1][1]) def test_child(self): var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", qname="foo") - node = PrimitiveNode(var, {}, False, DerivedElement) + node = PrimitiveNode(var, {}, False) with self.assertRaises(XmlContextError): node.child("foo", {}, {}, 0) diff --git a/tests/formats/dataclass/parsers/test_dict.py b/tests/formats/dataclass/parsers/test_dict.py index 77b1e164c..61a687237 100644 --- a/tests/formats/dataclass/parsers/test_dict.py +++ b/tests/formats/dataclass/parsers/test_dict.py @@ -1,5 +1,6 @@ import json from dataclasses import asdict, make_dataclass +from decimal import Decimal from typing import List, Optional, Union from xml.etree.ElementTree import QName @@ -232,17 +233,15 @@ def test_bind_simple_type_with_wildcard_var(self): self.assertEqual(2, actual.wildcard) def test_bind_simple_type_with_elements_var(self): - data = {"choice": ["1.0", 1, ["1"], "a", "{a}b"]} + data = {"choice": ["1.0", 1, ["1.2"], "{a}b"]} actual = self.decoder.bind_dataclass(data, ChoiceType) self.assertEqual(1.0, actual.choice[0]) self.assertEqual(1, actual.choice[1]) - self.assertEqual([1], actual.choice[2]) - self.assertEqual(QName("a"), actual.choice[3]) + self.assertEqual([Decimal("1.2")], actual.choice[2]) + self.assertEqual(QName("{a}b"), actual.choice[3]) self.assertIsInstance(actual.choice[3], QName) - self.assertEqual(QName("{a}b"), actual.choice[4]) - self.assertIsInstance(actual.choice[4], QName) data = {"choice": ["!NotAQname"]} with self.assertRaises(ParserError) as cm: @@ -278,13 +277,6 @@ def test_bind_choice_dataclass(self): expected = ChoiceType(choice=[TypeA(x=1), TypeB(x=1, y="a")]) self.assertEqual(expected, self.decoder.bind_dataclass(data, ChoiceType)) - def test_bind_derived_value_with_simple_type(self): - data = {"choice": [{"qname": "int2", "value": 1, "type": None}]} - - actual = self.decoder.bind_dataclass(data, ChoiceType) - expected = ChoiceType(choice=[DerivedElement(qname="int2", value=1)]) - self.assertEqual(expected, actual) - def test_bind_derived_value_with_choice_var(self): data = { "choice": [ @@ -317,6 +309,13 @@ def test_bind_derived_value_with_choice_var(self): str(cm.exception), ) + def test_bind_derived_value_with_simple_type(self): + data = {"choice": [{"qname": "float", "value": 1, "type": None}]} + + actual = self.decoder.bind_dataclass(data, ChoiceType) + expected = ChoiceType(choice=[DerivedElement(qname="float", value=1)]) + self.assertEqual(expected, actual) + def test_bind_wildcard_dataclass(self): data = {"a": None, "wildcard": {"x": 1}} expected = ExtendedType(wildcard=TypeA(x=1)) diff --git a/tests/formats/dataclass/parsers/test_node.py b/tests/formats/dataclass/parsers/test_node.py index f72d8110e..caf86cd01 100644 --- a/tests/formats/dataclass/parsers/test_node.py +++ b/tests/formats/dataclass/parsers/test_node.py @@ -182,7 +182,7 @@ def test_end(self, mock_assemble): objects = [("q", "result")] queue = [] var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", qname="foo") - queue.append(PrimitiveNode(var, {}, False, DerivedElement)) + queue.append(PrimitiveNode(var, {}, False)) self.assertTrue(parser.end(queue, objects, "author", "foobar", None)) self.assertEqual(0, len(queue)) diff --git a/tests/formats/dataclass/parsers/test_xml.py b/tests/formats/dataclass/parsers/test_xml.py index 7bc3f6778..fa5af97b5 100644 --- a/tests/formats/dataclass/parsers/test_xml.py +++ b/tests/formats/dataclass/parsers/test_xml.py @@ -31,7 +31,7 @@ def test_end(self, mock_emit_event): objects = [] queue = [] var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(bool,)) - queue.append(PrimitiveNode(var, {}, False, None)) + queue.append(PrimitiveNode(var, {}, False)) result = self.parser.end(queue, objects, "enabled", "true", None) self.assertTrue(result) diff --git a/tests/formats/dataclass/serializers/test_dict.py b/tests/formats/dataclass/serializers/test_dict.py index a25ecd120..5f5f95c61 100644 --- a/tests/formats/dataclass/serializers/test_dict.py +++ b/tests/formats/dataclass/serializers/test_dict.py @@ -1,11 +1,9 @@ -import json from unittest.case import TestCase from tests.fixtures.books import BookForm, Books from tests.fixtures.datatypes import Telephone from xsdata.exceptions import XmlContextError -from xsdata.formats.dataclass.serializers import DictEncoder -from xsdata.formats.dataclass.serializers.json import DictFactory, JsonSerializer +from xsdata.formats.dataclass.serializers import DictEncoder, DictFactory from xsdata.models.datatype import XmlDate from xsdata.models.xsd import Attribute from xsdata.utils.testing import XmlVarFactory @@ -73,8 +71,7 @@ def test_encode_list_of_objects(self): def test_encode_with_enum(self): obj = Attribute() - serializer = JsonSerializer(dict_factory=DictFactory.FILTER_NONE) - actual = json.loads(serializer.render(obj)) + actual = self.encoder.encode(obj) self.assertEqual("optional", actual["use"]) @@ -85,9 +82,8 @@ def test_convert_namedtuple(self): def test_next_value(self): book = self.books.book[0] - serializer = JsonSerializer() - actual = [name for name, value in serializer.next_value(book)] + actual = [name for name, value in self.encoder.next_value(book)] expected = [ "author", "title", @@ -100,7 +96,7 @@ def test_next_value(self): ] self.assertEqual(expected, actual) - serializer.config.ignore_default_attributes = True + self.encoder.config.ignore_default_attributes = True expected = expected[:-1] - actual = [name for name, value in serializer.next_value(book)] + actual = [name for name, value in self.encoder.next_value(book)] self.assertEqual(expected, actual) diff --git a/tests/formats/dataclass/serializers/test_json.py b/tests/formats/dataclass/serializers/test_json.py index eeb7e5d30..68748c6cf 100644 --- a/tests/formats/dataclass/serializers/test_json.py +++ b/tests/formats/dataclass/serializers/test_json.py @@ -2,8 +2,9 @@ from unittest.case import TestCase from tests.fixtures.books import BookForm, Books +from xsdata.formats.dataclass.serializers import DictFactory from xsdata.formats.dataclass.serializers.config import SerializerConfig -from xsdata.formats.dataclass.serializers.json import DictFactory, JsonSerializer +from xsdata.formats.dataclass.serializers.json import JsonSerializer from xsdata.models.datatype import XmlDate diff --git a/tests/formats/dataclass/test_elements.py b/tests/formats/dataclass/test_elements.py index 6899d10cc..b51d23d1f 100644 --- a/tests/formats/dataclass/test_elements.py +++ b/tests/formats/dataclass/test_elements.py @@ -1,4 +1,5 @@ from dataclasses import make_dataclass +from decimal import Decimal from unittest import mock from unittest.case import TestCase from xml.etree.ElementTree import QName @@ -45,7 +46,7 @@ def test_property_element_types(self): meta = self.context.build(ChoiceType) var = meta.choices[0] self.assertEqual( - {TypeA, TypeB, int, float, QName, UnionType}, var.element_types + {TypeA, TypeB, int, float, QName, UnionType, Decimal}, var.element_types ) def test_find_choice(self): @@ -88,14 +89,11 @@ def test_find_value_choice(self): meta = self.context.build(ChoiceType) var = meta.choices[0] - self.assertIsNone(var.find_value_choice(["1.1", "1.2"], False)) + self.assertEqual(var.elements["tokens"], var.find_value_choice(["1.2"], False)) self.assertIsNone(var.find_value_choice([], False)) - self.assertEqual(var.elements["int2"], var.find_value_choice(None, False)) self.assertEqual(var.elements["qname"], var.find_value_choice("foo", False)) self.assertEqual(var.elements["int"], var.find_value_choice(1, False)) - self.assertEqual(var.elements["tokens"], var.find_value_choice([1, 2], False)) self.assertEqual(var.elements["a"], var.find_value_choice(TypeA(1), True)) - der = make_dataclass("Der", fields=[], bases=(TypeA,)) self.assertEqual(var.elements["a"], var.find_value_choice(der(1), True)) diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index 72a269ac4..160731aef 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -5,6 +5,7 @@ CalculateAttributePaths, CreateCompoundFields, DesignateClassPackages, + DisambiguateChoices, FilterClasses, FlattenAttributeGroups, FlattenClassExtensions, @@ -93,6 +94,7 @@ def __init__(self, config: GeneratorConfig): Steps.FINALIZE: [ VacuumInnerClasses(), CreateCompoundFields(self), + DisambiguateChoices(self), ResetAttributeSequenceNumbers(self), ], } @@ -239,6 +241,14 @@ def add(self, item: Class): """ self.data.setdefault(item.qname, []).append(item) + def remove(self, item: Class): + """Remove class instance from to the container. + + Args: + item: The class instances to remove + """ + self.data[item.qname].remove(item) + def reset(self, item: Class, qname: str): """Update the given class qualified name. diff --git a/xsdata/codegen/handlers/__init__.py b/xsdata/codegen/handlers/__init__.py index e9497e642..e0b3c00ec 100644 --- a/xsdata/codegen/handlers/__init__.py +++ b/xsdata/codegen/handlers/__init__.py @@ -2,6 +2,7 @@ from .calculate_attribute_paths import CalculateAttributePaths from .create_compound_fields import CreateCompoundFields from .designate_class_packages import DesignateClassPackages +from .disambiguate_choices import DisambiguateChoices from .filter_classes import FilterClasses from .flatten_attribute_groups import FlattenAttributeGroups from .flatten_class_extensions import FlattenClassExtensions @@ -24,6 +25,7 @@ "CalculateAttributePaths", "CreateCompoundFields", "DesignateClassPackages", + "DisambiguateChoices", "FilterClasses", "FlattenAttributeGroups", "FlattenClassExtensions", diff --git a/xsdata/codegen/handlers/disambiguate_choices.py b/xsdata/codegen/handlers/disambiguate_choices.py new file mode 100644 index 000000000..fd17955df --- /dev/null +++ b/xsdata/codegen/handlers/disambiguate_choices.py @@ -0,0 +1,291 @@ +from collections import defaultdict +from typing import Iterator + +from xsdata.codegen.mixins import ContainerInterface, RelativeHandlerInterface +from xsdata.codegen.models import Attr, AttrType, Class, Extension, Restrictions +from xsdata.models.enums import DataType, Tag +from xsdata.utils import collections, text +from xsdata.utils.constants import DEFAULT_ATTR_NAME +from xsdata.utils.namespaces import build_qname + + +class DisambiguateChoices(RelativeHandlerInterface): + """Process choices with the same types and disambiguate them. + + Essentially, this handler creates intermediate simple and complex + types to ensure not two elements in a compound field can have the + same type. + + Args: + container: The class container instance + + Attributes: + unnest_classes: Specifies whether to create intermediate + inner or outer classes. + """ + + __slots__ = "unnest_classes" + + def __init__(self, container: ContainerInterface): + super().__init__(container) + self.unnest_classes = container.config.output.unnest_classes + + def process(self, target: Class): + """Process the given class attrs if they contain choices. + + Args: + target: The target class instance + """ + for attr in target.attrs: + if attr.choices: + self.process_compound_field(target, attr) + + def process_compound_field(self, target: Class, attr: Attr): + """Process a compound field. + + A compound field can be created by a mixed wildcard with + explicit children, or because we enabled the configuration + to group repeatable choices. + + Steps: + 1. Merge choices derived from xs:any elements + 2. Find ambiguous choices and create intermediate classes + 3. Reset the attr types if it's not a mixed wildcard. + + + Args: + target: The target class instance + attr: An attr instance that contains choices + """ + self.merge_wildcard_choices(attr) + + for choice in self.find_ambiguous_choices(attr): + self.disambiguate_choice(target, choice) + + if attr.tag == Tag.CHOICE: + types = (tp for choice in attr.choices for tp in choice.types) + attr.types = collections.unique_sequence(types) + + @classmethod + def merge_wildcard_choices(cls, attr: Attr): + """Merge choices derived from xs:any elements. + + It's a compound field it doesn't make sense + to have multiple wildcard choices. Merge them + together. + + Args: + attr: The attr instance that contains choices + """ + choices = [] + namespaces = [] + min_occurs = 0 + max_occurs = 0 + has_wildcard = False + for choice in attr.choices: + if choice.is_wildcard: + min_occurs += choice.restrictions.min_occurs or 0 + max_occurs += choice.restrictions.max_occurs or 0 + namespaces.append(choice.namespace) + has_wildcard = True + else: + choices.append(choice) + + attr.choices = choices + + if has_wildcard: + attr.choices.append( + Attr( + name="content", + types=[AttrType(qname=str(DataType.ANY_TYPE), native=True)], + tag=Tag.ANY, + namespace=" ".join( + collections.unique_sequence(filter(None, namespaces)) + ), + restrictions=Restrictions( + min_occurs=min_occurs, max_occurs=max_occurs + ), + ) + ) + + @classmethod + def find_ambiguous_choices(cls, attr: Attr) -> Iterator[Attr]: + """Find choices with the same types. + + Args: + attr: The attr instance with the choices. + + Yields: + An iterator of the ambiguous choices, except wildcards. + """ + groups = defaultdict(list) + for index, choice in enumerate(attr.choices): + for tp in choice.types: + dt = tp.datatype + if dt: + groups[dt.type.__name__].append(index) + else: + groups[tp.qname].append(index) + + ambiguous = set() + for indexes in groups.values(): + if len(indexes) > 1: + ambiguous.update(indexes) + + for index in ambiguous: + choice = attr.choices[index] + if not choice.is_wildcard: + yield choice + + def disambiguate_choice(self, target: Class, choice: Attr): + """Create intermediate class for the given choice. + + Scenarios: + 1. Choice is derived from xs:anyType + 2. Choice is derived from a xs:anySimpleType + 3. Choice is a reference to xs:complexType or element + + Args: + target: The target class instance + choice: The ambiguous choice attr instance + """ + is_circular = choice.is_circular_ref + inner = not self.unnest_classes and not is_circular + ref_class = self.create_ref_class(target, choice, inner=inner) + + if choice.is_any_type: + self.add_any_type_value(ref_class, choice) + elif self.is_simple_type(choice): + self.add_simple_type_value(ref_class, choice) + else: + self.add_extension(ref_class, choice) + + choice.restrictions = Restrictions( + min_occurs=choice.restrictions.min_occurs, + max_occurs=choice.restrictions.max_occurs, + ) + + ref_type = AttrType( + qname=ref_class.qname, + reference=id(ref_class), + forward=inner, + circular=is_circular, + ) + choice.types = [ref_type] + if not inner: + self.container.add(ref_class) + else: + target.inner.append(ref_class) + + def is_simple_type(self, choice: Attr) -> bool: + """Return whether the choice attr is a simple type reference.""" + if any(tp.native for tp in choice.types): + return True + + source = self.container.find(choice.types[0].qname) + if source and source.is_enumeration: + return True + + return False + + def create_ref_class(self, source: Class, choice: Attr, inner: bool) -> Class: + """Create an intermediate class for the given choice. + + If the reference class is going to be inner, ensure the class name is + unique, otherwise we will still end-up with ambiguous choices. + + Args: + source: The source class instance + choice: The ambiguous choice attr instance + inner: Specifies if the reference class will be inner + """ + name = choice.name + if inner: + name = self.next_available_name(source, name) + + return Class( + qname=build_qname(choice.namespace, name), + status=source.status, + tag=Tag.ELEMENT, + local_type=True, + location=source.location, + ns_map=source.ns_map, + nillable=choice.restrictions.nillable or False, + ) + + @classmethod + def next_available_name(cls, parent: Class, name: str) -> str: + """Find the next available name for an inner class. + + Args: + parent: The parent class instance + name: The name of the inner class + + Returns: + The next available class name by adding a integer suffix. + """ + reserved = {text.alnum(inner.name) for inner in parent.inner} + index = 0 + new_name = name + while True: + cmp = text.alnum(new_name) + + if cmp not in reserved: + return new_name + + index += 1 + new_name = f"{name}_{index}" + + @classmethod + def add_any_type_value(cls, reference: Class, choice: Attr): + """Add a simple any type content value attr to the reference class. + + Args: + reference: The reference class instance + choice: The source choice attr instance + """ + attr = Attr( + name="content", + types=[AttrType(qname=str(DataType.ANY_TYPE), native=True)], + tag=Tag.ANY, + namespace=choice.namespace, + restrictions=Restrictions(min_occurs=1, max_occurs=1), + ) + reference.attrs.append(attr) + + @classmethod + def add_simple_type_value(cls, reference: Class, choice: Attr): + """Add a simple type content value attr to the reference class. + + Args: + reference: The reference class instance + choice: The source choice attr instance + """ + new_attr = Attr( + tag=Tag.EXTENSION, + name=DEFAULT_ATTR_NAME, + namespace=None, + restrictions=choice.restrictions.clone( + min_occurs=1, + max_occurs=1, + path=[], + nillable=False, + ), + types=[tp.clone() for tp in choice.types], + ) + reference.attrs.append(new_attr) + + @classmethod + def add_extension(cls, reference: Class, choice: Attr): + """Add an extension to the reference class from the choice type. + + Args: + reference: The reference class instance + choice: The source choice attr instance + """ + extension = Extension( + tag=Tag.EXTENSION, + type=choice.types[0].clone(forward=False, circular=False), + restrictions=Restrictions(), + ) + reference.extensions.append(extension) diff --git a/xsdata/codegen/handlers/rename_duplicate_classes.py b/xsdata/codegen/handlers/rename_duplicate_classes.py index fbf7c794c..c2eac3c0f 100644 --- a/xsdata/codegen/handlers/rename_duplicate_classes.py +++ b/xsdata/codegen/handlers/rename_duplicate_classes.py @@ -1,7 +1,14 @@ -from typing import List +from typing import Dict, List from xsdata.codegen.mixins import ContainerHandlerInterface -from xsdata.codegen.models import Attr, Class, get_location, get_name, get_qname +from xsdata.codegen.models import ( + Attr, + AttrType, + Class, + get_location, + get_name, + get_qname, +) from xsdata.models.config import StructureStyle from xsdata.utils import collections, namespaces, text @@ -20,7 +27,12 @@ def run(self): groups = collections.group_by(self.container, lambda x: text.alnum(getter(x))) for classes in groups.values(): - if len(classes) > 1: + if len(classes) < 2: + continue + + if all(x == classes[0] for x in classes): + self.merge_classes(classes) + else: self.rename_classes(classes, use_name) def should_use_names(self) -> bool: @@ -36,6 +48,23 @@ def should_use_names(self) -> bool: or len(set(map(get_location, self.container))) == 1 ) + def merge_classes(self, classes: List[Class]): + """Remove the duplicate classes and update all references. + + Args: + classes: A list of duplicate classes + """ + keep = classes[0] + new = keep.ref + + replacements = {} + for item in classes[1:]: + replacements[item.ref] = new + self.container.remove(item) + + for item in self.container: + self.update_class_references(item, replacements) + def rename_classes(self, classes: List[Class], use_name: bool): """Rename all the classes in the list. @@ -98,6 +127,33 @@ def next_qname(self, namespace: str, name: str, use_name: bool) -> str: if cmp not in reserved: return qname + def update_class_references(self, target: Class, replacements: Dict[int, int]): + """Go through all class types and update all references. + + Args: + target: The target class instance to update + replacements: A mapping of old-new class references + """ + + def update_maybe(attr_type: AttrType): + exists = replacements.get(attr_type.reference) + if exists: + attr_type.reference = exists + + for attr in target.attrs: + for tp in attr.types: + update_maybe(tp) + + for choice in attr.choices: + for tp in choice.types: + update_maybe(tp) + + for ext in target.extensions: + update_maybe(ext.type) + + for inner in target.inner: + self.update_class_references(inner, replacements) + def rename_class_dependencies(self, target: Class, reference: int, replace: str): """Search and replace the old qualified class name in all classes. diff --git a/xsdata/codegen/handlers/validate_attributes_overrides.py b/xsdata/codegen/handlers/validate_attributes_overrides.py index 20f0ca739..dd675bdaa 100644 --- a/xsdata/codegen/handlers/validate_attributes_overrides.py +++ b/xsdata/codegen/handlers/validate_attributes_overrides.py @@ -133,7 +133,7 @@ def validate_override(cls, target: Class, child_attr: Attr, parent_attr: Attr): assert parent_attr.parent is not None logger.warning( "Converting parent field `%s::%s` to a list to match child class `%s`", - parent_attr.parent.name, + parent_attr.parent, parent_attr.name, target.name, ) diff --git a/xsdata/codegen/mixins.py b/xsdata/codegen/mixins.py index bfb76588f..ca733ead1 100644 --- a/xsdata/codegen/mixins.py +++ b/xsdata/codegen/mixins.py @@ -70,12 +70,20 @@ def first(self, qname: str) -> Class: @abc.abstractmethod def add(self, item: Class): - """Add class item to the container. + """Add class instance to the container. Args: item: The class instance to add """ + @abc.abstractmethod + def remove(self, item: Class): + """Remove class instance from the container. + + Args: + item: The class instances to remove + """ + @abc.abstractmethod def extend(self, items: List[Class]): """Add a list of classes to the container. @@ -147,7 +155,7 @@ def base_attrs(self, target: Class) -> List[Attr]: attrs.extend(self.base_attrs(base)) for attr in base.attrs: - attr.parent = base + attr.parent = base.qname attrs.append(attr) return attrs diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 0c01ae8ba..24bc84260 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -72,10 +72,10 @@ class Restrictions: pattern: Optional[str] = field(default=None) explicit_timezone: Optional[str] = field(default=None) nillable: Optional[bool] = field(default=None) - sequence: Optional[int] = field(default=None) + sequence: Optional[int] = field(default=None, compare=False) tokens: Optional[bool] = field(default=None) format: Optional[str] = field(default=None) - choice: Optional[int] = field(default=None) + choice: Optional[int] = field(default=None, compare=False) group: Optional[int] = field(default=None) process_contents: Optional[str] = field(default=None) path: List[Tuple[str, int, int, int]] = field(default_factory=list) @@ -178,9 +178,9 @@ def asdict(self, types: Optional[List[Type]] = None) -> Dict: return result - def clone(self) -> "Restrictions": - """Return a deep cloned instance.""" - return replace(self) + def clone(self, **kwargs: Any) -> "Restrictions": + """Return a deep cloned instance and replace any args.""" + return replace(self, **kwargs) @classmethod def from_element(cls, element: ElementBase) -> "Restrictions": @@ -243,9 +243,9 @@ def is_dependency(self, allow_circular: bool) -> bool: self.forward or self.native or (not allow_circular and self.circular) ) - def clone(self) -> "AttrType": + def clone(self, **kwargs: Any) -> "AttrType": """Return a deep cloned instance.""" - return replace(self) + return replace(self, **kwargs) @dataclass @@ -265,7 +265,7 @@ class Attr: namespace: The attr namespace help: The attr help text restrictions: The attr restrictions instance - parent: The class reference of the attr + parent: The parent class qualified name of the attr substitution: The substitution group this attr belongs to """ @@ -281,7 +281,7 @@ class Attr: namespace: Optional[str] = field(default=None) help: Optional[str] = field(default=None, compare=False) restrictions: Restrictions = field(default_factory=Restrictions, compare=False) - parent: Optional["Class"] = field(default=None, compare=False) + parent: Optional[str] = field(default=None, compare=False) substitution: Optional[str] = field(default=None, compare=False) def __post_init__(self): @@ -302,6 +302,11 @@ def key(self) -> str: """ return f"{self.tag}.{self.namespace}.{self.local_name}" + @property + def qname(self) -> str: + """Return the fully qualified name of the attr.""" + return namespaces.build_qname(self.namespace, self.local_name) + @property def is_attribute(self) -> bool: """Return whether this attr represents a xml attribute node.""" @@ -327,6 +332,11 @@ def is_forward_ref(self) -> bool: """Return whether any attr types is a forward or circular reference.""" return any(tp.circular or tp.forward for tp in self.types) + @property + def is_circular_ref(self) -> bool: + """Return whether any attr types is a circular reference.""" + return any(tp.circular for tp in self.types) + @property def is_group(self) -> bool: """Return whether this attr is a reference to a group class.""" @@ -493,7 +503,7 @@ class Class: qname: str tag: str - location: str + location: str = field(compare=False) mixed: bool = field(default=False) abstract: bool = field(default=False) nillable: bool = field(default=False) diff --git a/xsdata/formats/dataclass/client.py b/xsdata/formats/dataclass/client.py index 964d54833..88241c9f8 100644 --- a/xsdata/formats/dataclass/client.py +++ b/xsdata/formats/dataclass/client.py @@ -87,7 +87,7 @@ def __init__( elif not serializer: assert parser is not None serializer = XmlSerializer(context=parser.context) - elif not parser: + else: assert serializer is not None parser = XmlParser(context=serializer.context) diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index d7e804f05..85a7e18bf 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -190,6 +190,7 @@ def build_vars( parent_namespace = getattr(real_clazz.Meta, "namespace", namespace) var = builder.build( + clazz, field.name, type_hints[field.name], field.metadata, @@ -340,6 +341,7 @@ def __init__( def build( self, + model: Type, name: str, type_hint: Any, metadata: Mapping[str, Any], @@ -352,7 +354,8 @@ def build( """Build the binding metadata for a class field. Args: - name: The field name + model: The model class + name: The model field name type_hint: The typing annotations of the field metadata: The field metadata mapping init: Specify whether this field can be initialized @@ -380,18 +383,20 @@ def build( sequence = metadata.get("sequence", None) wrapper = metadata.get("wrapper", None) - origin, sub_origin, types = self.analyze_types(type_hint, globalns) + origin, sub_origin, types = self.analyze_types(model, name, type_hint, globalns) if not self.is_valid(xml_type, origin, sub_origin, types, tokens, init): raise XmlContextError( - f"Xml type '{xml_type}' does not support typing: {type_hint}" + f"Error on {model.__qualname__}::{name}: " + f"Xml {xml_type} does not support typing `{type_hint}`" ) if wrapper is not None and ( not isinstance(origin, type) or not issubclass(origin, (list, set, tuple)) ): raise XmlContextError( - f"a wrapper requires a collection type on attribute {name}" + f"Error on {model.__qualname__}::{name}: " + f"A wrapper field requires a collection type" ) local_name = local_name or self.build_local_name(xml_type, name) @@ -416,7 +421,7 @@ def build( self.index += 1 cur_index = self.index for choice in self.build_choices( - name, choices, origin, globalns, parent_namespace + model, name, choices, origin, globalns, parent_namespace ): if choice.is_element: elements[choice.qname] = choice @@ -444,12 +449,12 @@ def build( wildcards=wildcards, namespaces=namespaces, xml_type=xml_type, - derived=False, wrapper=wrapper, ) def build_choices( self, + model: Type, name: str, choices: List[Dict], factory: Callable, @@ -459,7 +464,8 @@ def build_choices( """Build the binding metadata for a compound dataclass field. Args: - name: The compound field name + model: The model class + name: The model field name choices: The list of choice metadata factory: The compound field values factory globalns: Python's global namespace @@ -483,6 +489,7 @@ def build_choices( metadata["type"] = XmlType.ELEMENT var = self.build( + model, name, type_hint, metadata, @@ -496,8 +503,11 @@ def build_choices( # It's impossible for choice elements to be ignorable, read above! assert var is not None - if var.any_type or any(True for tp in var.types if tp in existing_types): - var.derived = True + if any(True for tp in var.types if tp in existing_types): + raise XmlContextError( + f"Error on {model.__qualname__}::{name}: " + f"Compound field contains ambiguous types" + ) existing_types.update(var.types) @@ -588,7 +598,7 @@ def is_any_type(cls, types: Sequence[Type], xml_type: str) -> bool: @classmethod def analyze_types( - cls, type_hint: Any, globalns: Any + cls, model: Type, name: str, type_hint: Any, globalns: Any ) -> Tuple[Any, Any, Tuple[Type, ...]]: """Analyze a type hint and return the origin, sub origin and the type args. @@ -617,7 +627,10 @@ def analyze_types( return origin, sub_origin, tuple(converter.sort_types(types)) except Exception: - raise XmlContextError(f"Unsupported typing: {type_hint}") + raise XmlContextError( + f"Error on {model.__qualname__}::{name}: " + f"Unsupported field typing `{type_hint}`" + ) def is_valid( self, diff --git a/xsdata/formats/dataclass/models/elements.py b/xsdata/formats/dataclass/models/elements.py index 6179a91b9..c5d9fb006 100644 --- a/xsdata/formats/dataclass/models/elements.py +++ b/xsdata/formats/dataclass/models/elements.py @@ -69,7 +69,6 @@ class XmlVar(MetaMixin): factory: Callable factory for lists tokens_factory: Callable factory for tokens format: Information about the value format - derived: Indicates whether parsed values should be wrapped with a generic type any_type: Indicates if the field supports dynamic value types process_contents: Information about processing contents required: Indicates if the field is mandatory @@ -107,7 +106,6 @@ class XmlVar(MetaMixin): "factory", "tokens_factory", "format", - "derived", "any_type", "process_contents", "required", @@ -144,7 +142,6 @@ def __init__( factory: Optional[Callable], tokens_factory: Optional[Callable], format: Optional[str], - derived: bool, any_type: bool, process_contents: str, required: bool, @@ -167,7 +164,6 @@ def __init__( self.mixed = mixed self.tokens = tokens_factory is not None self.format = format - self.derived = derived self.any_type = any_type self.process_contents = process_contents self.required = required diff --git a/xsdata/formats/dataclass/parsers/dict.py b/xsdata/formats/dataclass/parsers/dict.py index c44423d83..eb4566ab3 100644 --- a/xsdata/formats/dataclass/parsers/dict.py +++ b/xsdata/formats/dataclass/parsers/dict.py @@ -389,6 +389,7 @@ def bind_derived_value(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any: return self.bind_derived_value(meta, choice, data) if not isinstance(params, dict): + # Is this scenario still possible??? value = self.bind_text(meta, var, params) elif xsi_type: clazz: Optional[Type] = self.context.find_type(xsi_type) diff --git a/xsdata/formats/dataclass/parsers/nodes/element.py b/xsdata/formats/dataclass/parsers/nodes/element.py index 88e092e3e..24e71f5e5 100644 --- a/xsdata/formats/dataclass/parsers/nodes/element.py +++ b/xsdata/formats/dataclass/parsers/nodes/element.py @@ -501,7 +501,7 @@ def build_node( if var.clazz: return self.build_element_node( var.clazz, - var.derived, + False, var.nillable, attrs, ns_map, @@ -512,15 +512,16 @@ def build_node( ) if not var.any_type and not var.is_wildcard: - return nodes.PrimitiveNode( - var, ns_map, self.meta.mixed_content, derived_factory - ) + return nodes.PrimitiveNode(var, ns_map, self.meta.mixed_content) datatype = DataType.from_qname(xsi_type) if xsi_type else None - derived = var.derived or var.is_wildcard + derived = var.is_wildcard if datatype: return nodes.StandardNode( - datatype, ns_map, var.nillable, derived_factory if derived else None + datatype, + ns_map, + var.nillable, + derived_factory if derived else None, ) node = None diff --git a/xsdata/formats/dataclass/parsers/nodes/primitive.py b/xsdata/formats/dataclass/parsers/nodes/primitive.py index 4d3a41449..8cbcaef8c 100644 --- a/xsdata/formats/dataclass/parsers/nodes/primitive.py +++ b/xsdata/formats/dataclass/parsers/nodes/primitive.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional from xsdata.exceptions import XmlContextError from xsdata.formats.dataclass.models.elements import XmlVar @@ -13,15 +13,13 @@ class PrimitiveNode(XmlNode): var: The xml var instance ns_map: The element namespace prefix-URI map mixed: Specifies if this node supports mixed content - derived_factory: The derived element factory """ - __slots__ = "var", "ns_map", "derived_factory" + __slots__ = "var", "ns_map" - def __init__(self, var: XmlVar, ns_map: Dict, mixed: bool, derived_factory: Type): + def __init__(self, var: XmlVar, ns_map: Dict, mixed: bool): self.var = var self.ns_map = ns_map - self.derived_factory = derived_factory self.mixed = mixed def bind( @@ -58,9 +56,6 @@ def bind( if obj is None and not self.var.nillable: obj = b"" if bytes in self.var.types else "" - if self.var.derived: - obj = self.derived_factory(qname=qname, value=obj) - objects.append((qname, obj)) if self.mixed: diff --git a/xsdata/formats/dataclass/parsers/nodes/wildcard.py b/xsdata/formats/dataclass/parsers/nodes/wildcard.py index 4449c9d48..34ca69029 100644 --- a/xsdata/formats/dataclass/parsers/nodes/wildcard.py +++ b/xsdata/formats/dataclass/parsers/nodes/wildcard.py @@ -58,7 +58,7 @@ def bind( """ children = self.fetch_any_children(self.position, objects) attributes = ParserUtils.parse_any_attributes(self.attrs, self.ns_map) - derived = self.var.derived or qname != self.var.qname + derived = qname != self.var.qname text = ParserUtils.normalize_content(text) if children else text text = "" if text is None and not self.var.nillable else text tail = ParserUtils.normalize_content(tail) diff --git a/xsdata/formats/dataclass/serializers/__init__.py b/xsdata/formats/dataclass/serializers/__init__.py index 42f982773..05d18896d 100644 --- a/xsdata/formats/dataclass/serializers/__init__.py +++ b/xsdata/formats/dataclass/serializers/__init__.py @@ -1,6 +1,6 @@ from xsdata.formats.dataclass.serializers.code import PycodeSerializer -from xsdata.formats.dataclass.serializers.dict import DictEncoder -from xsdata.formats.dataclass.serializers.json import DictFactory, JsonSerializer +from xsdata.formats.dataclass.serializers.dict import DictEncoder, DictFactory +from xsdata.formats.dataclass.serializers.json import JsonSerializer from xsdata.formats.dataclass.serializers.xml import XmlSerializer __all__ = [ diff --git a/xsdata/formats/dataclass/serializers/json.py b/xsdata/formats/dataclass/serializers/json.py index d62606c1d..77b7e5efe 100644 --- a/xsdata/formats/dataclass/serializers/json.py +++ b/xsdata/formats/dataclass/serializers/json.py @@ -1,30 +1,12 @@ import json from dataclasses import dataclass, field from io import StringIO -from typing import Any, Callable, Dict, TextIO, Tuple +from typing import Any, Callable, TextIO from xsdata.formats.bindings import AbstractSerializer from xsdata.formats.dataclass.serializers import DictEncoder -def filter_none(x: Tuple) -> Dict: - """Convert a key-value pairs to dict, ignoring None values. - - Args: - x: Key-value pairs - - Returns: - The filtered dictionary. - """ - return {k: v for k, v in x if v is not None} - - -class DictFactory: - """Dictionary factory types.""" - - FILTER_NONE = filter_none - - @dataclass class JsonSerializer(DictEncoder, AbstractSerializer): """Json serializer for data classes. diff --git a/xsdata/formats/dataclass/templates/class.jinja2 b/xsdata/formats/dataclass/templates/class.jinja2 index 4075911ae..51ba2c4a5 100644 --- a/xsdata/formats/dataclass/templates/class.jinja2 +++ b/xsdata/formats/dataclass/templates/class.jinja2 @@ -18,7 +18,7 @@ class {{ class_name }}{{"({})".format(base_classes) if base_classes }}: {%- if help %} {{ help|indent(4, first=True) }} {%- endif -%} -{%- if local_name or obj.is_nillable or obj.namespace is not none or target_namespace or obj.local_type %} +{%- if local_name or obj.is_nillable or obj.namespace is not none or target_namespace or (obj.local_type and level == 0) %} class Meta: {%- if obj.local_type %} global_type = False diff --git a/xsdata/utils/namespaces.py b/xsdata/utils/namespaces.py index 441032246..caebe0a83 100644 --- a/xsdata/utils/namespaces.py +++ b/xsdata/utils/namespaces.py @@ -111,7 +111,7 @@ def local_name(qname: str) -> str: return split_qname(qname)[1] -NCNAME_PUNCTUATION = {"\u00B7", "\u0387", ".", "-", "_"} +NCNAME_PUNCTUATION = {"\u00b7", "\u0387", ".", "-", "_"} def is_ncname(name: Optional[str]) -> bool: diff --git a/xsdata/utils/testing.py b/xsdata/utils/testing.py index ee469de6a..11e6bc882 100644 --- a/xsdata/utils/testing.py +++ b/xsdata/utils/testing.py @@ -57,6 +57,8 @@ def load_class(output: str, clazz_name: str) -> Any: class FactoryTestCase(unittest.TestCase): + maxDiff = None + def setUp(self): super().setUp() ClassFactory.reset() @@ -370,7 +372,6 @@ def create( factory: Optional[Callable] = None, tokens_factory: Optional[Callable] = None, format: Optional[str] = None, - derived: bool = False, any_type: bool = False, required: bool = False, nillable: bool = False, @@ -409,7 +410,6 @@ def create( factory=factory, tokens_factory=tokens_factory, format=format, - derived=derived, any_type=any_type, required=required, nillable=nillable,