Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prohibit parent fields on restriction extensions #908

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions tests/codegen/handlers/test_validate_attributes_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,46 @@ def setUp(self):
self.container = ClassContainer(config=GeneratorConfig())
self.processor = ValidateAttributesOverrides(container=self.container)

def test_prohibit_parent_attrs(self):
child = ClassFactory.create(
status=Status.FLATTENING,
attrs=[
AttrFactory.create(name="el", tag=Tag.ELEMENT),
AttrFactory.create(name="at", tag=Tag.ATTRIBUTE),
],
)

parent = ClassFactory.create(
status=Status.FLATTENED,
attrs=[
AttrFactory.element(default="foo"),
AttrFactory.attribute(),
AttrFactory.extension(),
AttrFactory.element(default="bar"),
],
)

child.extensions.append(
ExtensionFactory.reference(parent.qname, tag=Tag.RESTRICTION)
)
self.container.extend((parent, child))
self.processor.process(child)

self.assertEqual(4, len(child.attrs))

self.assertEqual(parent.attrs[0].name, child.attrs[0].name)
self.assertEqual([], child.attrs[0].types)
self.assertIsNone(child.attrs[0].default)
self.assertTrue(child.attrs[0].is_prohibited)

self.assertEqual(parent.attrs[3].name, child.attrs[1].name)
self.assertEqual([], child.attrs[1].types)
self.assertIsNone(child.attrs[1].default)
self.assertTrue(child.attrs[1].is_prohibited)

@mock.patch.object(ValidateAttributesOverrides, "resolve_conflict")
@mock.patch.object(ValidateAttributesOverrides, "validate_override")
def test_process(self, mock_validate_override, mock_resolve_conflict):
def test_validate_attrs(self, mock_validate_override, mock_resolve_conflict):
class_a = ClassFactory.create(
status=Status.FLATTENING,
attrs=[
Expand All @@ -35,8 +72,12 @@ def test_process(self, mock_validate_override, mock_resolve_conflict):
class_b = ClassFactory.elements(2, status=Status.FLATTENED)
class_c = ClassFactory.create(status=Status.FLATTENED)

class_b.extensions.append(ExtensionFactory.reference(class_c.qname))
class_a.extensions.append(ExtensionFactory.reference(class_b.qname))
class_b.extensions.append(
ExtensionFactory.reference(class_c.qname, tag=Tag.EXTENSION)
)
class_a.extensions.append(
ExtensionFactory.reference(class_b.qname, tag=Tag.EXTENSION)
)

class_c.attrs.append(class_a.attrs[0].clone())
class_c.attrs.append(class_a.attrs[1].clone())
Expand All @@ -52,7 +93,7 @@ def test_process(self, mock_validate_override, mock_resolve_conflict):
class_a.attrs[1], class_c.attrs[1]
)

def test_process_remove_non_overriding_prohibited_attrs(self):
def test_validate_attrs_remove_non_overriding_prohibited_attrs(self):
target = ClassFactory.elements(1)
target.attrs[0].restrictions.max_occurs = 0

Expand Down
6 changes: 6 additions & 0 deletions tests/codegen/models/test_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def test__eq__(self):
attr.namespace = __file__
self.assertNotEqual(attr, clone)

def test_can_be_restricted(self):
self.assertFalse(AttrFactory.create(tag=Tag.ATTRIBUTE).can_be_restricted())
self.assertFalse(AttrFactory.create(tag=Tag.EXTENSION).can_be_restricted())
self.assertFalse(AttrFactory.create(tag=Tag.RESTRICTION).can_be_restricted())
self.assertTrue(AttrFactory.create(tag=Tag.ELEMENT).can_be_restricted())

def test_property_key(self):
attr = AttrFactory.attribute(name="a", namespace="b")
self.assertEqual("Attribute.b.a", attr.key)
Expand Down
10 changes: 10 additions & 0 deletions tests/codegen/models/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ def test_property_is_global_type(self):
obj.attrs.append(AttrFactory.create(tag=Tag.EXTENSION))
self.assertFalse(obj.is_global_type)

def test_property_is_restricted(self):
obj = ClassFactory.create()
ext = ExtensionFactory.create(tag=Tag.EXTENSION)
obj.extensions.append(ext)

self.assertFalse(obj.is_restricted)

ext.tag = Tag.RESTRICTION
self.assertTrue(obj.is_restricted)

def test_property_is_simple_type(self):
obj = ClassFactory.elements(2)

Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from typing import Dict, Any
from typing import List
from typing import Optional
from typing import Type
Expand Down Expand Up @@ -30,6 +30,7 @@ class TypeC:
y: str
z: float
fixed: str = field(init=False, default="ignored")
restricted: Any = field(init=False, metadata={"type": "Ignore"})


@dataclass
Expand Down
52 changes: 39 additions & 13 deletions xsdata/codegen/handlers/validate_attributes_overrides.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Set

from xsdata.codegen.mixins import RelativeHandlerInterface
from xsdata.codegen.models import Attr, Class, get_slug
Expand All @@ -9,30 +9,56 @@


class ValidateAttributesOverrides(RelativeHandlerInterface):
"""
Check override attributes are valid.

Steps:
1. The attribute is a valid override, leave it alone
2. The attribute is unnecessary remove it
3. The attribute is an invalid override, rename one of them
"""
"""Validate override and restricted attributes."""

__slots__ = ()

def process(self, target: Class):
base_attrs_map = self.base_attrs_map(target)
# We need the original class attrs before validation, in order to
# prohibit the rest of the parent attrs later...
restricted_attrs = {
attr.slug for attr in target.attrs if attr.can_be_restricted()
}
self.validate_attrs(target, base_attrs_map)
if target.is_restricted:
self.prohibit_parent_attrs(target, restricted_attrs, base_attrs_map)

@classmethod
def prohibit_parent_attrs(
cls,
target: Class,
restricted_attrs: Set[str],
base_attrs_map: Dict[str, List[Attr]],
):
"""
Prepend prohibited parent attrs to the target class.

Reset the types and default value in order to avoid conflicts
later.
"""
for slug, attrs in reversed(base_attrs_map.items()):
attr = attrs[0]
if attr.can_be_restricted() and slug not in restricted_attrs:
attr_restricted = attr.clone()
attr_restricted.restrictions.max_occurs = 0
attr_restricted.default = None
attr_restricted.types.clear()
target.attrs.insert(0, attr_restricted)

@classmethod
def validate_attrs(cls, target: Class, base_attrs_map: Dict[str, List[Attr]]):
for attr in list(target.attrs):
base_attrs = base_attrs_map.get(attr.slug)

if base_attrs:
base_attr = base_attrs[0]
if self.overrides(attr, base_attr):
self.validate_override(target, attr, base_attr)
if cls.overrides(attr, base_attr):
cls.validate_override(target, attr, base_attr)
else:
self.resolve_conflict(attr, base_attr)
cls.resolve_conflict(attr, base_attr)
elif attr.is_prohibited:
self.remove_attribute(target, attr)
cls.remove_attribute(target, attr)

@classmethod
def overrides(cls, a: Attr, b: Attr) -> bool:
Expand Down
10 changes: 10 additions & 0 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ def get_native_types(self) -> Iterator[Type]:
if datatype:
yield datatype.type

def can_be_restricted(self) -> bool:
"""Return whether this attribute can be restricted."""
return self.xml_type not in (Tag.ATTRIBUTE, None)


@dataclass(unsafe_hash=True)
class Extension:
Expand Down Expand Up @@ -460,6 +464,12 @@ def is_mixed(self) -> bool:
"""Return whether this class supports mixed content."""
return self.mixed or any(x.mixed for x in self.attrs)

@property
def is_restricted(self) -> bool:
return any(
True for extension in self.extensions if extension.tag == Tag.RESTRICTION
)

@property
def is_service(self) -> bool:
"""Return whether this instance is derived from wsdl:operation."""
Expand Down
4 changes: 3 additions & 1 deletion xsdata/formats/dataclass/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def score(value: Any) -> float:
return 0.0

if self.is_model(obj):
return sum(score(getattr(obj, var.name)) for var in self.get_fields(obj))
return sum(
score(getattr(obj, var.name, None)) for var in self.get_fields(obj)
)

return score(obj)

Expand Down