Skip to content

Commit bad72ee

Browse files
committed
Fix code issues
1 parent 091330f commit bad72ee

File tree

4 files changed

+38
-34
lines changed

4 files changed

+38
-34
lines changed

tests/codegen/models/test_class.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ def test_dependencies(self):
2929
],
3030
choices=[
3131
AttrChoiceFactory.create(
32-
name="x", types=[AttrTypeFactory.create(qname="choiceAttr")]
32+
name="x",
33+
types=[
34+
AttrTypeFactory.create(qname="choiceAttr"),
35+
AttrTypeFactory.xs_string(),
36+
],
3337
),
3438
AttrChoiceFactory.create(
3539
name="x",
@@ -81,7 +85,7 @@ def test_dependencies(self):
8185
"{http://www.w3.org/2001/XMLSchema}foobar",
8286
"{xsdata}foo",
8387
]
84-
self.assertEqual(expected, list(obj.dependencies()))
88+
self.assertCountEqual(expected, list(obj.dependencies()))
8589

8690
def test_property_has_suffix_attr(self):
8791
obj = ClassFactory.create()

tests/formats/dataclass/test_elements.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
from xsdata.formats.dataclass.models.elements import XmlAttributes
1010
from xsdata.formats.dataclass.models.elements import XmlElement
1111
from xsdata.formats.dataclass.models.elements import XmlElements
12-
from xsdata.formats.dataclass.models.elements import XmlMeta
1312
from xsdata.formats.dataclass.models.elements import XmlText
1413
from xsdata.formats.dataclass.models.elements import XmlVar
1514
from xsdata.formats.dataclass.models.elements import XmlWildcard
16-
from xsdata.models.enums import FormType
1715

1816

1917
@dataclass
@@ -64,6 +62,10 @@ def test_find_choice(self):
6462
var = XmlVar(name="foo", qname="foo")
6563
self.assertIsNone(var.find_choice("foo"))
6664

65+
def test_find_choice_typed(self):
66+
var = XmlVar(name="foo", qname="foo")
67+
self.assertIsNone(var.find_choice_typed(int))
68+
6769

6870
class XmlElementTests(TestCase):
6971
def test_property_is_element(self):

xsdata/codegen/models.py

+15-25
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,14 @@ def merge(self, source: "Restrictions"):
127127
if source.sequential and (is_list or not self.is_list):
128128
self.sequential = source.sequential
129129

130-
if source.choice:
131-
self.choice = source.choice
130+
self.choice = source.choice or self.choice
131+
self.tokens = source.tokens or self.tokens
132132

133-
if not self.tokens and source.tokens:
134-
self.tokens = True
135-
136-
# Update min occurs if current value is None and the new value is more than one.
133+
# Update min occurs if current value is None or the new value is more than one.
137134
if self.min_occurs is None or (min_occurs is not None and min_occurs != 1):
138135
self.min_occurs = min_occurs
139136

140-
# Update max occurs if current value is None and the new value is more than one.
137+
# Update max occurs if current value is None or the new value is more than one.
141138
if self.max_occurs is None or (max_occurs is not None and max_occurs != 1):
142139
self.max_occurs = max_occurs
143140

@@ -190,8 +187,8 @@ class AttrType:
190187
"""
191188

192189
qname: str
193-
index: int = field(default_factory=int)
194-
alias: Optional[str] = field(default=None)
190+
index: int = field(default_factory=int, compare=False)
191+
alias: Optional[str] = field(default=None, compare=False)
195192
native: bool = field(default=False)
196193
forward: bool = field(default=False)
197194
circular: bool = field(default=False)
@@ -522,28 +519,21 @@ def dependencies(self) -> Iterator[str]:
522519
Collect:
523520
* base classes
524521
* attribute types
522+
* attribute choice types
525523
* recursively go through the inner classes
526524
* Ignore inner class references
527525
* Ignore native types.
528526
"""
529527

530-
seen = set()
528+
types = {ext.type for ext in self.extensions}
529+
531530
for attr in self.attrs:
532-
for attr_type in attr.types:
533-
if attr_type.is_dependency and attr_type.name not in seen:
534-
yield attr_type.qname
535-
seen.add(attr_type.name)
536-
537-
for attr_choice in attr.choices:
538-
for attr_type in attr_choice.types:
539-
if attr_type.is_dependency and attr_type.name not in seen:
540-
yield attr_type.qname
541-
seen.add(attr_type.name)
542-
543-
for ext in self.extensions:
544-
if ext.type.is_dependency and ext.type.name not in seen:
545-
yield ext.type.qname
546-
seen.add(ext.type.name)
531+
types.update(attr.types)
532+
types.update(tp for choice in attr.choices for tp in choice.types)
533+
534+
for tp in types:
535+
if tp.is_dependency:
536+
yield tp.qname
547537

548538
for inner in self.inner:
549539
yield from inner.dependencies()

xsdata/formats/dataclass/context.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from xsdata.formats.dataclass.models.elements import XmlMeta
2323
from xsdata.formats.dataclass.models.elements import XmlVar
2424
from xsdata.models.enums import NamespaceType
25-
from xsdata.utils.collections import first
2625
from xsdata.utils.constants import EMPTY_SEQUENCE
2726
from xsdata.utils.namespaces import build_qname
2827

@@ -150,8 +149,8 @@ def get_type_hints(self, clazz: Type, parent_ns: Optional[str]) -> Iterator[XmlV
150149
xml_clazz = XmlType.to_xml_class(xml_type)
151150
namespace = var.metadata.get("namespace")
152151
namespaces = self.resolve_namespaces(xml_type, namespace, parent_ns)
153-
first_namespace = first(x for x in namespaces if x and x[0] != "#")
154-
qname = build_qname(first_namespace, local_name)
152+
default_namespace = self.default_namespace(namespaces)
153+
qname = build_qname(default_namespace, local_name)
155154

156155
choices = list(
157156
self.build_choices(
@@ -190,12 +189,12 @@ def build_choices(
190189
xml_type = choice.get("tag", XmlType.ELEMENT)
191190
namespace = choice.get("namespace")
192191
namespaces = self.resolve_namespaces(xml_type, namespace, parent_namespace)
193-
first_namespace = first(x for x in namespaces if x and x[0] != "#")
192+
default_namespace = self.default_namespace(namespaces)
194193

195194
types = self.real_types(_eval_type(choice["type"], globalns, None))
196195
is_class = any(is_dataclass(clazz) for clazz in types)
197196
xml_clazz = XmlType.to_xml_class(xml_type)
198-
qname = build_qname(first_namespace, choice.get("name", "any"))
197+
qname = build_qname(default_namespace, choice.get("name", "any"))
199198

200199
yield xml_clazz(
201200
name=parent_name,
@@ -242,6 +241,15 @@ def resolve_namespaces(
242241
result.add(ns)
243242
return list(result)
244243

244+
@classmethod
245+
def default_namespace(cls, namespaces: List[str]) -> Optional[str]:
246+
"""Return the first valid namespace uri or None."""
247+
for namespace in namespaces:
248+
if namespace and not namespace.startswith("#"):
249+
return namespace
250+
251+
return None
252+
245253
@classmethod
246254
def default_value(cls, var: Field) -> Any:
247255
"""Return the default value/factory for the given field."""

0 commit comments

Comments
 (0)