From 4a24ed6a879433a0b541b3a8357a298370808893 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 28 Jun 2023 09:32:51 +0800 Subject: [PATCH 1/4] feat: allow users to register custom encoders Signed-off-by: Frost Ming --- tests/test_items.py | 14 ++++++++++++++ tomlkit/__init__.py | 4 ++++ tomlkit/api.py | 22 ++++++++++++++++++++++ tomlkit/items.py | 27 +++++++++++++++++++++++++-- 4 files changed, 65 insertions(+), 2 deletions(-) diff --git a/tests/test_items.py b/tests/test_items.py index 45aea258..bbdeb00d 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -946,3 +946,17 @@ def test_copy_copy(): ) def test_escape_key(key_str, escaped): assert api.key(key_str).as_string() == escaped + + +def test_custom_encoders(): + import decimal + + @api.register_encoder + def encode_decimal(obj): + if isinstance(obj, decimal.Decimal): + return api.float_(str(obj)) + raise TypeError + + assert api.item(decimal.Decimal("1.23")).as_string() == "1.23" + assert api.dumps({"foo": decimal.Decimal("1.23")}) == "foo = 1.23\n" + api.unregister_encoder(encode_decimal) diff --git a/tomlkit/__init__.py b/tomlkit/__init__.py index acc7046c..c2dd53fe 100644 --- a/tomlkit/__init__.py +++ b/tomlkit/__init__.py @@ -18,9 +18,11 @@ from tomlkit.api import loads from tomlkit.api import nl from tomlkit.api import parse +from tomlkit.api import register_encoder from tomlkit.api import string from tomlkit.api import table from tomlkit.api import time +from tomlkit.api import unregister_encoder from tomlkit.api import value from tomlkit.api import ws @@ -52,4 +54,6 @@ "TOMLDocument", "value", "ws", + "register_encoder", + "unregister_encoder", ] diff --git a/tomlkit/api.py b/tomlkit/api.py index 8ec5653c..686fd1c0 100644 --- a/tomlkit/api.py +++ b/tomlkit/api.py @@ -1,14 +1,17 @@ from __future__ import annotations +import contextlib import datetime as _datetime from collections.abc import Mapping from typing import IO from typing import Iterable +from typing import TypeVar from tomlkit._utils import parse_rfc3339 from tomlkit.container import Container from tomlkit.exceptions import UnexpectedCharError +from tomlkit.items import CUSTOM_ENCODERS from tomlkit.items import AoT from tomlkit.items import Array from tomlkit.items import Bool @@ -16,6 +19,7 @@ from tomlkit.items import Date from tomlkit.items import DateTime from tomlkit.items import DottedKey +from tomlkit.items import Encoder from tomlkit.items import Float from tomlkit.items import InlineTable from tomlkit.items import Integer @@ -284,3 +288,21 @@ def nl() -> Whitespace: def comment(string: str) -> Comment: """Create a comment item.""" return Comment(Trivia(comment_ws=" ", comment="# " + string)) + + +E = TypeVar("E", bound=Encoder) + + +def register_encoder(encoder: E) -> E: + """Add a custom encoder, which should be a function that will be called + if the value can't otherwise be converted. It should takes a single value + and return a TOMLKit item or raise a ``TypeError``. + """ + CUSTOM_ENCODERS.append(encoder) + return encoder + + +def unregister_encoder(encoder: Encoder) -> None: + """Unregister a custom encoder.""" + with contextlib.suppress(ValueError): + CUSTOM_ENCODERS.remove(encoder) diff --git a/tomlkit/items.py b/tomlkit/items.py index 683c1893..b0651d3e 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -13,6 +13,7 @@ from enum import Enum from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Collection from typing import Iterable from typing import Iterator @@ -57,6 +58,15 @@ class _CustomDict(MutableMapping, dict): ItemT = TypeVar("ItemT", bound="Item") +Encoder = Callable[[Any], "Item"] +CUSTOM_ENCODERS: list[Encoder] = [] + + +class _EncodeError(TypeError, ValueError): + """An internal error raised when item() fails to encode a value. + It should be a TypeError, but due to historical reasons + it needs to subclass ValueError as well. + """ @overload @@ -218,8 +228,21 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I Trivia(), value.isoformat(), ) - - raise ValueError(f"Invalid type {type(value)}") + else: + for encoder in CUSTOM_ENCODERS: + try: + rv = encoder(value) + except TypeError: + pass + else: + if not isinstance(rv, Item): + raise _EncodeError( + f"Custom encoder {encoder} returned {type(rv)}, " + f"expected Item" + ) + return rv + + raise _EncodeError(f"Invalid type {type(value)}") class StringType(Enum): From b4cb96eb2f88a6d4975d86e256c4bdace518cdcc Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 28 Jun 2023 09:34:28 +0800 Subject: [PATCH 2/4] rename the error Signed-off-by: Frost Ming --- tomlkit/items.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tomlkit/items.py b/tomlkit/items.py index b0651d3e..3a2eec9c 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -62,8 +62,8 @@ class _CustomDict(MutableMapping, dict): CUSTOM_ENCODERS: list[Encoder] = [] -class _EncodeError(TypeError, ValueError): - """An internal error raised when item() fails to encode a value. +class _ConvertError(TypeError, ValueError): + """An internal error raised when item() fails to convert a value. It should be a TypeError, but due to historical reasons it needs to subclass ValueError as well. """ @@ -236,13 +236,13 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I pass else: if not isinstance(rv, Item): - raise _EncodeError( + raise _ConvertError( f"Custom encoder {encoder} returned {type(rv)}, " f"expected Item" ) return rv - raise _EncodeError(f"Invalid type {type(value)}") + raise _ConvertError(f"Invalid type {type(value)}") class StringType(Enum): From 327eed0c97a3eb6b2b19113bf09c509ecfeb4c68 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 28 Jun 2023 09:36:33 +0800 Subject: [PATCH 3/4] add more tests Signed-off-by: Frost Ming --- tests/test_items.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_items.py b/tests/test_items.py index bbdeb00d..485e47fb 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -958,5 +958,9 @@ def encode_decimal(obj): raise TypeError assert api.item(decimal.Decimal("1.23")).as_string() == "1.23" + + with pytest.raises(TypeError): + api.item(object()) + assert api.dumps({"foo": decimal.Decimal("1.23")}) == "foo = 1.23\n" api.unregister_encoder(encode_decimal) From fdb51422bdd09d1ab6f3dc4f247ff2e9bb3280e4 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 28 Jun 2023 09:42:30 +0800 Subject: [PATCH 4/4] improve the message Signed-off-by: Frost Ming --- tomlkit/items.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tomlkit/items.py b/tomlkit/items.py index 3a2eec9c..41dccc37 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -237,8 +237,7 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I else: if not isinstance(rv, Item): raise _ConvertError( - f"Custom encoder {encoder} returned {type(rv)}, " - f"expected Item" + f"Custom encoder returned {type(rv)}, not a subclass of Item" ) return rv