Skip to content

Commit f61205f

Browse files
committed
Add ParserConfig.class_factory
1 parent 4eadb4a commit f61205f

File tree

11 files changed

+131
-31
lines changed

11 files changed

+131
-31
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ exclude: tests/fixtures
22

33
repos:
44
- repo: https://github.com/asottile/pyupgrade
5-
rev: v2.19.4
5+
rev: v2.20.0
66
hooks:
77
- id: pyupgrade
88
args: [--py37-plus]

docs/examples.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Advance Topics
2222
:maxdepth: 1
2323

2424
examples/custom-property-names
25+
examples/custom-class-factory
2526

2627

2728
Test Suites
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
====================
2+
Custom class factory
3+
====================
4+
5+
6+
It's not recommended to modify the generated models. If you need to add any pre/post
7+
initialization logic or even validations you can use the parser config to override the
8+
default class factory.
9+
10+
.. doctest::
11+
12+
>>> from dataclasses import dataclass
13+
>>> from xsdata.formats.dataclass.parsers import JsonParser
14+
>>> from xsdata.formats.dataclass.parsers.config import ParserConfig
15+
...
16+
>>> def custom_class_factory(clazz, params):
17+
... if clazz.__name__ == "Person":
18+
... return clazz(**{k: v.upper() for k, v in params.items()})
19+
...
20+
... return clazz(**params)
21+
...
22+
23+
>>> config = ParserConfig(class_factory=custom_class_factory)
24+
>>> parser = JsonParser(config=config)
25+
...
26+
>>> @dataclass
27+
... class Person:
28+
... first_name: str
29+
... last_name: str
30+
...
31+
>>> json_str = """{"first_name": "chris", "last_name": "foo"}"""
32+
...
33+
...
34+
>>> print(parser.from_string(json_str, Person))
35+
Person(first_name='CHRIS', last_name='FOO')

tests/formats/dataclass/parsers/nodes/test_union.py

+41-10
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,47 @@
66
from xsdata.exceptions import ParserError
77
from xsdata.formats.dataclass.context import XmlContext
88
from xsdata.formats.dataclass.models.elements import XmlType
9+
from xsdata.formats.dataclass.parsers.config import ParserConfig
910
from xsdata.formats.dataclass.parsers.nodes import UnionNode
1011
from xsdata.models.mixins import attribute
1112
from xsdata.utils.testing import XmlVarFactory
1213

1314

1415
class UnionNodeTests(TestCase):
16+
def setUp(self) -> None:
17+
super().setUp()
18+
19+
self.context = XmlContext()
20+
self.config = ParserConfig()
21+
1522
def test_child(self):
1623
attrs = {"id": "1"}
1724
ns_map = {"ns0": "xsdata"}
18-
ctx = XmlContext()
1925
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", qname="foo")
20-
node = UnionNode(position=0, var=var, context=ctx, attrs={}, ns_map={})
26+
node = UnionNode(
27+
position=0,
28+
var=var,
29+
config=self.config,
30+
context=self.context,
31+
attrs={},
32+
ns_map={},
33+
)
2134
self.assertEqual(node, node.child("foo", attrs, ns_map, 10))
2235

2336
self.assertEqual(1, node.level)
2437
self.assertEqual([("start", "foo", attrs, ns_map)], node.events)
2538
self.assertIsNot(attrs, node.events[0][2])
2639

2740
def test_bind_appends_end_event_when_level_not_zero(self):
28-
ctx = XmlContext()
2941
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", qname="foo")
30-
node = UnionNode(position=0, var=var, context=ctx, attrs={}, ns_map={})
42+
node = UnionNode(
43+
position=0,
44+
var=var,
45+
config=self.config,
46+
context=self.context,
47+
attrs={},
48+
ns_map={},
49+
)
3150
node.level = 1
3251
objects = []
3352

@@ -43,12 +62,18 @@ def test_bind_returns_best_matching_object(self):
4362
item2 = make_dataclass("Item2", [("a", int, attribute())])
4463
root = make_dataclass("Root", [("item", Union[str, int, item2, item])])
4564

46-
ctx = XmlContext()
47-
meta = ctx.build(root)
65+
meta = self.context.build(root)
4866
var = next(meta.find_children("item"))
4967
attrs = {"a": "1", "b": 2}
5068
ns_map = {}
51-
node = UnionNode(position=0, var=var, context=ctx, attrs=attrs, ns_map=ns_map)
69+
node = UnionNode(
70+
position=0,
71+
var=var,
72+
config=self.config,
73+
context=self.context,
74+
attrs=attrs,
75+
ns_map=ns_map,
76+
)
5277
objects = []
5378

5479
self.assertTrue(node.bind("item", "1", None, objects))
@@ -73,11 +98,17 @@ def test_bind_returns_best_matching_object(self):
7398
self.assertEqual("a", objects[-1][1])
7499

75100
def test_bind_raises_parser_error_on_failure(self):
76-
ctx = XmlContext()
77-
meta = ctx.build(UnionType)
101+
meta = self.context.build(UnionType)
78102
var = next(meta.find_children("element"))
79103

80-
node = UnionNode(position=0, var=var, context=ctx, attrs={}, ns_map={})
104+
node = UnionNode(
105+
position=0,
106+
var=var,
107+
config=self.config,
108+
context=self.context,
109+
attrs={},
110+
ns_map={},
111+
)
81112

82113
with self.assertRaises(ParserError) as cm:
83114
node.bind("element", None, None, [])

xsdata/formats/dataclass/parsers/config.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from dataclasses import dataclass
2+
from typing import Callable
3+
from typing import Dict
24
from typing import Optional
5+
from typing import Type
6+
from typing import TypeVar
7+
8+
T = TypeVar("T")
9+
10+
11+
def default_class_factory(cls: Type[T], params: Dict) -> T:
12+
return cls(**params) # type: ignore
313

414

515
@dataclass
@@ -10,6 +20,7 @@ class ParserConfig:
1020
:param base_url: Specify a base URL when parsing from memory and
1121
you need support for relative links eg xinclude
1222
:param process_xinclude: Enable xinclude statements processing
23+
:param class_factory: Override default object instantiation
1324
:param fail_on_unknown_properties: Skip unknown properties or
1425
fail with exception
1526
:param fail_on_converter_warnings: Turn converter warnings to
@@ -18,5 +29,6 @@ class ParserConfig:
1829

1930
base_url: Optional[str] = None
2031
process_xinclude: bool = False
32+
class_factory: Callable[[Type[T], Dict], T] = default_class_factory
2133
fail_on_unknown_properties: bool = True
2234
fail_on_converter_warnings: bool = False

xsdata/formats/dataclass/parsers/json.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,7 @@ def verify_type(self, clazz: Optional[Type[T]], data: Union[Dict, List]) -> Type
9090
if list_type != isinstance(data, list):
9191
if list_type:
9292
raise ParserError("Document is object, expected array")
93-
else:
94-
raise ParserError("Document is array, expected object")
93+
raise ParserError("Document is array, expected object")
9594

9695
return clazz # type: ignore
9796

@@ -128,7 +127,7 @@ def bind_dataclass(self, data: Dict, clazz: Type[T]) -> T:
128127
params[var.name] = self.bind_value(meta, var, value)
129128

130129
try:
131-
return clazz(**params) # type: ignore
130+
return self.config.class_factory(clazz, params)
132131
except TypeError as e:
133132
raise ParserError(e)
134133

@@ -257,22 +256,22 @@ def bind_complex_type(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any:
257256
if var.is_clazz_union:
258257
# Union of dataclasses
259258
return self.bind_best_dataclass(data, var.types)
260-
elif var.elements:
259+
if var.elements:
261260
# Compound field with multiple choices
262261
return self.bind_best_dataclass(data, var.element_types)
263-
elif var.any_type or var.is_wildcard:
262+
if var.any_type or var.is_wildcard:
264263
# xs:anyType element, check all meta classes
265264
return self.bind_best_dataclass(data, meta.element_types)
266-
else:
267-
assert var.clazz is not None
268265

269-
subclasses = set(self.context.get_subclasses(var.clazz))
270-
if subclasses:
271-
# field annotation is an abstract/base type
272-
subclasses.add(var.clazz)
273-
return self.bind_best_dataclass(data, subclasses)
266+
assert var.clazz is not None
267+
268+
subclasses = set(self.context.get_subclasses(var.clazz))
269+
if subclasses:
270+
# field annotation is an abstract/base type
271+
subclasses.add(var.clazz)
272+
return self.bind_best_dataclass(data, subclasses)
274273

275-
return self.bind_dataclass(data, var.clazz)
274+
return self.bind_dataclass(data, var.clazz)
276275

277276
def bind_derived_value(self, meta: XmlMeta, var: XmlVar, data: Dict) -> T:
278277
"""Bind derived element entry point."""

xsdata/formats/dataclass/parsers/nodes/element.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ElementNode(XmlNode):
3131
:param context: Model context provider
3232
:param position: The node position of objects cache
3333
:param mixed: The node supports mixed content
34-
:param derived: The xml element is derived from a base type
34+
:param derived_factory: Derived element factory
3535
:param xsi_type: The xml type substitution
3636
"""
3737

@@ -92,7 +92,7 @@ def bind(
9292
if isinstance(params[key], PendingCollection):
9393
params[key] = params[key].evaluate()
9494

95-
obj = self.meta.clazz(**params)
95+
obj = self.config.class_factory(self.meta.clazz, params)
9696
if self.derived_factory:
9797
obj = self.derived_factory(qname=qname, value=obj, type=self.xsi_type)
9898

@@ -330,6 +330,7 @@ def build_node(
330330
var=var,
331331
attrs=attrs,
332332
ns_map=ns_map,
333+
config=self.config,
333334
context=self.context,
334335
position=position,
335336
)

xsdata/formats/dataclass/parsers/nodes/primitive.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class PrimitiveNode(XmlNode):
1515
1616
:param var: Class field xml var instance
1717
:param ns_map: Namespace prefix-URI map
18+
:param derived_factory: Derived element factory
1819
"""
1920

2021
__slots__ = "var", "ns_map", "derived_factory"

xsdata/formats/dataclass/parsers/nodes/standard.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ class StandardNode(XmlNode):
1515
1616
:param datatype: Standard xsi data type
1717
:param ns_map: Namespace prefix-URI map
18-
:param derived: Specify whether the value needs to be wrapped with
19-
:class:`~xsdata.formats.dataclass.models.generics.DerivedElement`
2018
:param nillable: Specify whether the node supports nillable content
19+
:param derived_factory: Optional derived element factory
2120
"""
2221

2322
__slots__ = "datatype", "ns_map", "nillable", "derived_factory"

xsdata/formats/dataclass/parsers/nodes/union.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from xsdata.formats.dataclass.context import XmlContext
1414
from xsdata.formats.dataclass.models.elements import XmlVar
1515
from xsdata.formats.dataclass.parsers.bases import NodeParser
16+
from xsdata.formats.dataclass.parsers.config import ParserConfig
1617
from xsdata.formats.dataclass.parsers.mixins import EventsHandler
1718
from xsdata.formats.dataclass.parsers.mixins import XmlNode
1819
from xsdata.formats.dataclass.parsers.utils import ParserUtils
@@ -32,18 +33,35 @@ class UnionNode(XmlNode):
3233
:param attrs: Key-value attribute mapping
3334
:param ns_map: Namespace prefix-URI map
3435
:param position: The node position of objects cache
36+
:param config: Parser configuration
3537
:param context: Model context provider
3638
"""
3739

38-
__slots__ = "var", "attrs", "ns_map", "position", "context", "level", "events"
40+
__slots__ = (
41+
"var",
42+
"attrs",
43+
"ns_map",
44+
"position",
45+
"config",
46+
"context",
47+
"level",
48+
"events",
49+
)
3950

4051
def __init__(
41-
self, var: XmlVar, attrs: Dict, ns_map: Dict, position: int, context: XmlContext
52+
self,
53+
var: XmlVar,
54+
attrs: Dict,
55+
ns_map: Dict,
56+
position: int,
57+
config: ParserConfig,
58+
context: XmlContext,
4259
):
4360
self.var = var
4461
self.attrs = attrs
4562
self.ns_map = ns_map
4663
self.position = position
64+
self.config = config
4765
self.context = context
4866
self.level = 0
4967
self.events: List[Tuple[str, str, Any, Any]] = []
@@ -94,7 +112,9 @@ def parse_class(self, clazz: Type[T]) -> Optional[T]:
94112
with warnings.catch_warnings():
95113
warnings.filterwarnings("error", category=ConverterWarning)
96114

97-
parser = NodeParser(context=self.context, handler=EventsHandler)
115+
parser = NodeParser(
116+
config=self.config, context=self.context, handler=EventsHandler
117+
)
98118
return parser.parse(self.events, clazz)
99119
except Exception:
100120
return None

xsdata/formats/dataclass/parsers/nodes/wildcard.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class WildcardNode(XmlNode):
2020
:param attrs: Key-value attribute mapping
2121
:param ns_map: Namespace prefix-URI map
2222
:param position: The node position of objects cache
23+
:param factory: Wildcard element factory
2324
"""
2425

2526
__slots__ = "var", "attrs", "ns_map", "position", "factory"

0 commit comments

Comments
 (0)