diff --git a/docs/api/codegen.rst b/docs/api/codegen.rst index a2309ce60..6b207641e 100644 --- a/docs/api/codegen.rst +++ b/docs/api/codegen.rst @@ -22,8 +22,11 @@ like naming conventions and aliases. OutputFormat GeneratorConventions GeneratorAliases + GeneratorSubstitutions StructureStyle DocstringStyle + ObjectType GeneratorAlias + GeneratorSubstitution NameConvention NameCase diff --git a/tests/codegen/handlers/test_class_designate.py b/tests/codegen/handlers/test_class_designate.py index 88c49dfd7..71ef0be02 100644 --- a/tests/codegen/handlers/test_class_designate.py +++ b/tests/codegen/handlers/test_class_designate.py @@ -3,6 +3,7 @@ from xsdata.exceptions import CodeGenerationError from xsdata.models.config import GeneratorAlias from xsdata.models.config import GeneratorConfig +from xsdata.models.config import ObjectType from xsdata.models.config import StructureStyle from xsdata.models.enums import Namespace from xsdata.utils.testing import AttrFactory @@ -190,8 +191,10 @@ def test_combine_ns_package(self): result = self.handler.combine_ns_package(namespace) self.assertEqual(["generated", "bar", "foo", "add"], result) - alias = GeneratorAlias(source=namespace, target="add.again") - self.config.aliases.package_name.append(alias) + alias = GeneratorAlias( + type=ObjectType.PACKAGE, source=namespace, target="add.again" + ) + self.config.aliases.alias.append(alias) result = self.handler.combine_ns_package(namespace) self.assertEqual(["generated", "add", "again"], result) diff --git a/tests/fixtures/calculator/AddRQ.xml b/tests/fixtures/calculator/AddRQ.xml index 420f345d1..bebf7c5c6 100644 --- a/tests/fixtures/calculator/AddRQ.xml +++ b/tests/fixtures/calculator/AddRQ.xml @@ -1,9 +1,9 @@ - - + + 1 3 - - + + diff --git a/tests/fixtures/hello/HelloRQ.xml b/tests/fixtures/hello/HelloRQ.xml index 87f9a5c7b..a0de134df 100644 --- a/tests/fixtures/hello/HelloRQ.xml +++ b/tests/fixtures/hello/HelloRQ.xml @@ -1,8 +1,8 @@ - - + + chris - - + + diff --git a/tests/formats/dataclass/test_filters.py b/tests/formats/dataclass/test_filters.py index 652603b21..3977c4f01 100644 --- a/tests/formats/dataclass/test_filters.py +++ b/tests/formats/dataclass/test_filters.py @@ -7,7 +7,9 @@ from xsdata.models.config import DocstringStyle from xsdata.models.config import GeneratorAlias from xsdata.models.config import GeneratorConfig +from xsdata.models.config import GeneratorSubstitution from xsdata.models.config import NameCase +from xsdata.models.config import ObjectType from xsdata.models.enums import DataType from xsdata.models.enums import Namespace from xsdata.models.enums import Tag @@ -33,16 +35,19 @@ def setUp(self) -> None: self.filters = Filters(config) def test_class_name(self): - self.filters.class_aliases["boom"] = "Bang" + self.filters.aliases[ObjectType.CLASS]["boom"] = "Bang" + self.filters.substitutions[ObjectType.CLASS]["Abc"] = "Cba" self.assertEqual("XsString", self.filters.class_name("xs:string")) self.assertEqual("FooBarBam", self.filters.class_name("foo:bar_bam")) self.assertEqual("ListType", self.filters.class_name("List")) self.assertEqual("TypeType", self.filters.class_name(".*")) self.assertEqual("Bang", self.filters.class_name("boom")) + self.assertEqual("Cbad", self.filters.class_name("abcd")) def test_field_name(self): - self.filters.field_aliases["boom"] = "Bang" + self.filters.aliases[ObjectType.FIELD]["boom"] = "Bang" + self.filters.substitutions[ObjectType.FIELD]["abc"] = "cba" self.assertEqual("value", self.filters.field_name("", "cls")) self.assertEqual("foo", self.filters.field_name("foo", "cls")) @@ -53,9 +58,11 @@ def test_field_name(self): self.assertEqual("value_1", self.filters.field_name("1", "cls")) self.assertEqual("Bang", self.filters.field_name("boom", "cls")) self.assertEqual("value_minus_1_1", self.filters.field_name("-1.1", "cls")) + self.assertEqual("cbad", self.filters.field_name("abcd", "cls")) def test_constant_name(self): - self.filters.field_aliases["boom"] = "Bang" + self.filters.aliases[ObjectType.FIELD]["boom"] = "Bang" + self.filters.substitutions[ObjectType.FIELD]["ABC"] = "CBA" self.assertEqual("VALUE", self.filters.constant_name("", "cls")) self.assertEqual("FOO", self.filters.constant_name("foo", "cls")) @@ -66,9 +73,13 @@ def test_constant_name(self): self.assertEqual("VALUE_1", self.filters.constant_name("1", "cls")) self.assertEqual("Bang", self.filters.constant_name("boom", "cls")) self.assertEqual("VALUE_MINUS_1", self.filters.constant_name("-1", "cls")) + self.assertEqual("CBAD", self.filters.constant_name("ABCD", "cls")) def test_module_name(self): - self.filters.module_aliases["http://github.com/tefra/xsdata"] = "xsdata" + self.filters.aliases[ObjectType.MODULE].update( + {"http://github.com/tefra/xsdata": "xsdata"} + ) + self.filters.substitutions[ObjectType.MODULE].update({"xsdata": "data"}) self.assertEqual("foo_bar", self.filters.module_name("fooBar")) self.assertEqual("foo_bar_wtf", self.filters.module_name("fooBar.wtf")) @@ -77,22 +88,24 @@ def test_module_name(self): self.assertEqual("foo_bar_bam", self.filters.module_name("foo:bar_bam")) self.assertEqual("bar_bam", self.filters.module_name("urn:bar_bam")) self.assertEqual( - "pypi_org_project_xsdata", + "pypi_org_project_data", self.filters.module_name("http://pypi.org/project/xsdata/"), ) self.assertEqual( - "xsdata", self.filters.module_name("http://github.com/tefra/xsdata") + "data", self.filters.module_name("http://github.com/tefra/xsdata") ) def test_package_name(self): - self.filters.package_aliases["boom"] = "bang" - self.filters.package_aliases["boom.boom"] = "booom" + self.filters.aliases[ObjectType.PACKAGE]["boom"] = "bang" + self.filters.aliases[ObjectType.PACKAGE]["boom.boom"] = "booom" + self.filters.substitutions[ObjectType.PACKAGE]["bam"] = "boom" self.assertEqual( "foo.bar_bar.pkg_1", self.filters.package_name("Foo.BAR_bar.1") ) self.assertEqual("foo.bang.pkg_1", self.filters.package_name("Foo.boom.1")) self.assertEqual("booom", self.filters.package_name("boom.boom")) + self.assertEqual("boom.boom", self.filters.package_name("bam.bam")) self.assertEqual("", self.filters.package_name("")) def test_type_name(self): @@ -774,11 +787,17 @@ def test__init(self): config.conventions.field_name.case = NameCase.PASCAL config.conventions.module_name.safe_prefix = "safe_module" config.conventions.module_name.case = NameCase.SNAKE - config.aliases.class_name.append(GeneratorAlias("a", "b")) - config.aliases.class_name.append(GeneratorAlias("c", "d")) - config.aliases.field_name.append(GeneratorAlias("e", "f")) - config.aliases.package_name.append(GeneratorAlias("g", "h")) - config.aliases.module_name.append(GeneratorAlias("i", "j")) + config.aliases.alias.append(GeneratorAlias(ObjectType.CLASS, "a", "b")) + config.aliases.alias.append(GeneratorAlias(ObjectType.CLASS, "c", "d")) + config.aliases.alias.append(GeneratorAlias(ObjectType.FIELD, "e", "f")) + config.aliases.alias.append(GeneratorAlias(ObjectType.PACKAGE, "g", "h")) + config.aliases.alias.append(GeneratorAlias(ObjectType.MODULE, "i", "j")) + config.substitutions.substitution.append( + GeneratorSubstitution(ObjectType.FIELD, "k", "l") + ) + config.substitutions.substitution.append( + GeneratorSubstitution(ObjectType.PACKAGE, "m", "n") + ) filters = Filters(config) @@ -794,7 +813,18 @@ def test__init(self): self.assertEqual("cAB", filters.package_name("cAB")) self.assertEqual("c_ab", filters.module_name("cAB")) - self.assertEqual({"a": "b", "c": "d"}, filters.class_aliases) - self.assertEqual({"e": "f"}, filters.field_aliases) - self.assertEqual({"g": "h"}, filters.package_aliases) - self.assertEqual({"i": "j"}, filters.module_aliases) + expected_aliases = { + ObjectType.CLASS: {"a": "b", "c": "d"}, + ObjectType.FIELD: {"e": "f"}, + ObjectType.MODULE: {"i": "j"}, + ObjectType.PACKAGE: {"g": "h"}, + } + self.assertEqual(expected_aliases, filters.aliases) + + expected_substitutions = { + ObjectType.CLASS: {}, + ObjectType.FIELD: {"k": "l"}, + ObjectType.MODULE: {}, + ObjectType.PACKAGE: {"m": "n"}, + } + self.assertEqual(expected_substitutions, filters.substitutions) diff --git a/tests/models/test_config.py b/tests/models/test_config.py index 95b29dbcc..2f75e1edf 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -23,11 +23,10 @@ def test_create(self): expected = ( '\n' - f'\n' + '\n' ' \n' " generated\n" - ' dataclasses\n' + ' dataclasses\n' " filenames\n" " reStructuredText\n" " false\n" @@ -41,12 +40,18 @@ def test_create(self): ' \n' " \n" " \n" - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' + ' \n' + ' \n' + ' \n' + ' \n' + ' \n' + ' \n' + ' \n' + ' \n' " \n" + " \n" + ' \n' + " \n" "\n" ) self.assertEqual(expected, file_path.read_text()) @@ -63,6 +68,7 @@ def test_read(self): ' \n' " \n" " \n" + " \n" "\n" ) file_path = Path(tempfile.mktemp()) @@ -91,6 +97,7 @@ def test_read(self): ' \n' " \n" " \n" + " \n" "\n" ) self.assertEqual(expected, file_path.read_text()) diff --git a/tests/utils/test_namespaces.py b/tests/utils/test_namespaces.py index 56e770999..c47774594 100644 --- a/tests/utils/test_namespaces.py +++ b/tests/utils/test_namespaces.py @@ -22,11 +22,11 @@ def test_load_prefix(self): self.assertEqual("ns0", load_prefix("a", ns_map)) self.assertEqual("ns0", load_prefix("a", ns_map)) self.assertEqual("xs", load_prefix(Namespace.XS.uri, ns_map)) - self.assertEqual("soap-env", load_prefix(Namespace.SOAP_ENV.uri, ns_map)) + self.assertEqual("soapenv", load_prefix(Namespace.SOAP_ENV.uri, ns_map)) expected = { "ns0": "a", - "soap-env": "http://schemas.xmlsoap.org/soap/envelope/", + "soapenv": "http://schemas.xmlsoap.org/soap/envelope/", "xs": "http://www.w3.org/2001/XMLSchema", } self.assertEqual(expected, ns_map) @@ -35,13 +35,13 @@ def test_generate_prefix(self): ns_map: Dict = {} self.assertEqual("ns0", generate_prefix("a", ns_map)) self.assertEqual("xs", generate_prefix(Namespace.XS.uri, ns_map)) - self.assertEqual("soap-env", generate_prefix(Namespace.SOAP_ENV.uri, ns_map)) + self.assertEqual("soapenv", generate_prefix(Namespace.SOAP_ENV.uri, ns_map)) self.assertEqual("ns3", generate_prefix("b", ns_map)) expected = { "ns0": "a", "ns3": "b", - "soap-env": "http://schemas.xmlsoap.org/soap/envelope/", + "soapenv": "http://schemas.xmlsoap.org/soap/envelope/", "xs": "http://www.w3.org/2001/XMLSchema", } self.assertEqual(expected, ns_map) diff --git a/xsdata/codegen/handlers/class_designate.py b/xsdata/codegen/handlers/class_designate.py index c39d8fc28..563dfb783 100644 --- a/xsdata/codegen/handlers/class_designate.py +++ b/xsdata/codegen/handlers/class_designate.py @@ -15,6 +15,7 @@ from xsdata.codegen.models import get_location from xsdata.codegen.models import get_target_namespace from xsdata.exceptions import CodeGenerationError +from xsdata.models.config import ObjectType from xsdata.models.config import StructureStyle from xsdata.models.enums import COMMON_SCHEMA_DIR from xsdata.utils import collections @@ -141,8 +142,11 @@ def group_common_paths(cls, paths: Iterable[str]) -> List[List[str]]: def combine_ns_package(self, namespace: Optional[str]) -> List[str]: result = self.container.config.output.package.split(".") - aliases = self.container.config.aliases.package_name - alias = collections.first(x.target for x in aliases if x.source == namespace) + alias = collections.first( + alias.target + for alias in self.container.config.aliases.alias + if alias.type == ObjectType.PACKAGE and alias.source == namespace + ) if alias: result.extend(alias.split(".")) diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index 29702ce4e..0304d2ada 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -1,6 +1,7 @@ import math import re import textwrap +from collections import defaultdict from typing import Any from typing import Callable from typing import Dict @@ -20,18 +21,14 @@ from xsdata.codegen.models import Class from xsdata.formats.converter import converter from xsdata.models.config import DocstringStyle -from xsdata.models.config import GeneratorAlias from xsdata.models.config import GeneratorConfig +from xsdata.models.config import ObjectType from xsdata.models.config import OutputFormat from xsdata.utils import collections from xsdata.utils import namespaces from xsdata.utils import text -def index_aliases(aliases: List[GeneratorAlias]) -> Dict: - return {alias.source: alias.target for alias in aliases} - - class Filters: DEFAULT_KEY = "default" @@ -39,10 +36,8 @@ class Filters: UNESCAPED_DBL_QUOTE_REGEX = re.compile(r"([^\\])\"") __slots__ = ( - "class_aliases", - "field_aliases", - "package_aliases", - "module_aliases", + "aliases", + "substitutions", "class_case", "field_case", "constant_case", @@ -61,10 +56,15 @@ class Filters: ) def __init__(self, config: GeneratorConfig): - self.class_aliases: Dict = index_aliases(config.aliases.class_name) - self.field_aliases: Dict = index_aliases(config.aliases.field_name) - self.package_aliases: Dict = index_aliases(config.aliases.package_name) - self.module_aliases: Dict = index_aliases(config.aliases.module_name) + + self.aliases: Dict[ObjectType, Dict[str, str]] = defaultdict(dict) + for alias in config.aliases.alias: + self.aliases[alias.type][alias.source] = alias.target + + self.substitutions: Dict[ObjectType, Dict[str, str]] = defaultdict(dict) + for sub in config.substitutions.substitution: + self.substitutions[sub.type][sub.search] = sub.replace + self.class_case: Callable = config.conventions.class_name.case self.field_case: Callable = config.conventions.field_name.case self.constant_case: Callable = config.conventions.constant_name.case @@ -146,11 +146,17 @@ def class_params(self, obj: Class): def class_name(self, name: str) -> str: """Convert the given string to a class name according to the selected conventions or use an existing alias.""" - alias = self.class_aliases.get(name) - if alias: - return alias + alias = self.aliases[ObjectType.CLASS].get(name) + name = alias or self.safe_name(name, self.class_safe_prefix, self.class_case) - return self.safe_name(name, self.class_safe_prefix, self.class_case) + return self.apply_substitutions(name, self.substitutions[ObjectType.CLASS]) + + @classmethod + def apply_substitutions(cls, name: str, substitutions: Dict) -> str: + for search, replace in substitutions.items(): + name = name.replace(search, replace) + + return name def field_definition( self, @@ -183,14 +189,12 @@ def field_name(self, name: str, class_name: str) -> str: Provide the class name as context for the naming schemes. """ - alias = self.field_aliases.get(name) - if alias: - return alias - - return self.safe_name( + name = self.aliases[ObjectType.FIELD].get(name) or self.safe_name( name, self.field_safe_prefix, self.field_case, class_name=class_name ) + return self.apply_substitutions(name, self.substitutions[ObjectType.FIELD]) + def constant_name(self, name: str, class_name: str) -> str: """ Convert the given name to a constant name according to the selected @@ -198,42 +202,41 @@ def constant_name(self, name: str, class_name: str) -> str: Provide the class name as context for the naming schemes. """ - alias = self.field_aliases.get(name) - if alias: - return alias - - return self.safe_name( - name, self.constant_safe_prefix, self.constant_case, class_name=class_name + name = self.aliases[ObjectType.FIELD].get(name) or self.safe_name( + name, + self.constant_safe_prefix, + self.constant_case, + class_name=class_name, ) + return self.apply_substitutions(name, self.substitutions[ObjectType.FIELD]) + def module_name(self, name: str) -> str: """Convert the given string to a module name according to the selected conventions or use an existing alias.""" - alias = self.module_aliases.get(name) - if alias: - return alias - - return self.safe_name( + name = self.aliases[ObjectType.MODULE].get(name) or self.safe_name( namespaces.clean_uri(name), self.module_safe_prefix, self.module_case ) + return self.apply_substitutions(name, self.substitutions[ObjectType.MODULE]) + def package_name(self, name: str) -> str: """Convert the given string to a package name according to the selected conventions or use an existing alias.""" - alias = self.package_aliases.get(name) - if alias: - return alias + name = self.aliases[ObjectType.PACKAGE].get(name) or name if not name: return name - return ".".join( - self.package_aliases.get(part) + name = ".".join( + self.aliases[ObjectType.PACKAGE].get(part) or self.safe_name(part, self.package_safe_prefix, self.package_case) for part in name.split(".") ) + return self.apply_substitutions(name, self.substitutions[ObjectType.PACKAGE]) + def type_name(self, attr_type: AttrType) -> str: """Return native python type name or apply class name conventions.""" datatype = attr_type.datatype diff --git a/xsdata/models/config.py b/xsdata/models/config.py index 505813d3e..5c5f8f4fa 100644 --- a/xsdata/models/config.py +++ b/xsdata/models/config.py @@ -19,6 +19,7 @@ from xsdata.formats.dataclass.serializers import XmlSerializer from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.formats.dataclass.serializers.writers import XmlEventWriter +from xsdata.models.enums import Namespace from xsdata.models.mixins import array_element from xsdata.models.mixins import attribute from xsdata.models.mixins import element @@ -131,6 +132,22 @@ class DocstringStyle(Enum): BLANK = "Blank" +class ObjectType(Enum): + """ + Object type enumeration. + + :cvar CLASS: class + :cvar FIELD: field + :cvar MODULE: module + :cvar PACKAGE: package + """ + + CLASS = "class" + FIELD = "field" + MODULE = "module" + PACKAGE = "package" + + @dataclass class OutputFormat: """ @@ -265,10 +282,12 @@ class GeneratorAlias: filename or target namespace depending the selected output structure. + :param type: :param type: The target object type :param source: The source name from schema definition :param target: The target name of the object. """ + type: ObjectType = attribute(required=True) source: str = attribute(required=True) target: str = attribute(required=True) @@ -276,22 +295,46 @@ class GeneratorAlias: @dataclass class GeneratorAliases: """ - Generator aliases for classes, fields, packages and modules that bypass the - global naming conventions. + Generator aliases for classes, fields, packages and modules names. The + process overrides the naming conventions. .. warning:: The generator doesn't validate aliases. - :param class_name: list of class name aliases - :param field_name: list of field name aliases - :param package_name: list of package name aliases - :param module_name: list of module name aliases + :param alias: The list of aliases + """ + + alias: List[GeneratorAlias] = array_element() + + +@dataclass +class GeneratorSubstitution: + """ + Search and replace substitution for a specific target type. + + :param type: The target object type + :param search: The search case sensitive string value + :param replace: The replacement case sensitive string value + """ + + type: ObjectType = attribute(required=True) + search: str = attribute(required=True) + replace: str = attribute(required=True) + + +@dataclass +class GeneratorSubstitutions: + """ + Generator search and replace substitutions for classes, fields, packages + and modules names. The process overrides both aliases and naming + conventions. + + .. warning:: The generator doesn't validate substitutions. + + :param substitution: The list of substitutions """ - class_name: List[GeneratorAlias] = array_element() - field_name: List[GeneratorAlias] = array_element() - package_name: List[GeneratorAlias] = array_element() - module_name: List[GeneratorAlias] = array_element() + substitution: List[GeneratorSubstitution] = array_element() @dataclass @@ -313,19 +356,23 @@ class Meta: output: GeneratorOutput = element(default_factory=GeneratorOutput) conventions: GeneratorConventions = element(default_factory=GeneratorConventions) aliases: GeneratorAliases = element(default_factory=GeneratorAliases) + substitutions: GeneratorSubstitutions = element( + default_factory=GeneratorSubstitutions + ) @classmethod def create(cls) -> "GeneratorConfig": obj = cls() - obj.aliases.class_name.append(GeneratorAlias("fooType", "Foo")) - obj.aliases.class_name.append(GeneratorAlias("ABCSomething", "ABCSomething")) - obj.aliases.field_name.append( - GeneratorAlias("ChangeofGauge", "change_of_gauge") - ) - obj.aliases.package_name.append( - GeneratorAlias("http://www.w3.org/1999/xhtml", "xtml") + + for ns in Namespace: + obj.aliases.alias.append( + GeneratorAlias(type=ObjectType.PACKAGE, source=ns.uri, target=ns.prefix) + ) + + obj.substitutions.substitution.append( + GeneratorSubstitution(type=ObjectType.CLASS, search="Class", replace="Type") ) - obj.aliases.module_name.append(GeneratorAlias("2010.1", "2020a")) + return obj @classmethod @@ -341,8 +388,16 @@ def read(cls, path: Path) -> "GeneratorConfig": fail_on_unknown_properties=False, fail_on_converter_warnings=True, ) + + # I already hate it but it's needed to maintain compatibility + xml_str = path.read_text() + xml_str = xml_str.replace("