Skip to content

Commit

Permalink
ENH correctly restore default_factory of a defaultdict (#433)
Browse files Browse the repository at this point in the history
* ENH correctly restore default_factory of a defaultdict

* MNT add changelog for new version and bump version

* add OrderedDict trusted

* rename builtin type names to container type names

* add test for OrderedDict

* test type as well
  • Loading branch information
adrinjalali authored Aug 8, 2024
1 parent 17d6b2e commit 4467309
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 7 deletions.
7 changes: 6 additions & 1 deletion docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ skops Changelog
:depth: 1
:local:

v0.11
-----
- Correctly restore ``default_factory`` when saving and loading a ``defaultdict``.
:pr:`433` by `Adrin Jalali`_.

v0.10
----
-----
- Removes Pythn 3.8 support and adds Python 3.12 Support :pr:`418` by :user:`Thomas Lazarus <lazarust>`.
- Removes a shortcut to add `sklearn-intelex` as a not dependency.
:pr:`420` by :user:`Thomas Lazarus < lazarust > `.
Expand Down
48 changes: 46 additions & 2 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import operator
import uuid
from collections import defaultdict
from functools import partial
from reprlib import Repr
from types import FunctionType, MethodType
Expand All @@ -14,6 +15,7 @@
from ._audit import Node, get_tree
from ._protocol import PROTOCOL
from ._trusted_types import (
CONTAINER_TYPE_NAMES,
NUMPY_DTYPE_TYPE_NAMES,
NUMPY_UFUNC_TYPE_NAMES,
PRIMITIVE_TYPE_NAMES,
Expand Down Expand Up @@ -63,7 +65,7 @@ def __init__(
trusted: Optional[Sequence[str]] = None,
) -> None:
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [dict])
self.trusted = self._get_trusted(trusted, [dict, "collections.OrderedDict"])
self.children = {
"key_types": get_tree(state["key_types"], load_context, trusted=trusted),
"content": {
Expand All @@ -80,6 +82,45 @@ def _construct(self):
return content


def defaultdict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "DefaultDictNode",
}
content = {}
# explicitly pass a dict object instead of _DictWithDeprecatedKeys and
# later construct a _DictWithDeprecatedKeys object.
content["main"] = get_state(dict(obj), save_context)
content["default_factory"] = get_state(obj.default_factory, save_context)
res["content"] = content
return res


class DefaultDictNode(Node):
def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: Optional[Sequence[str]] = None,
) -> None:
super().__init__(state, load_context, trusted)
self.trusted = ["collections.defaultdict"]
self.children = {
"main": get_tree(state["content"]["main"], load_context, trusted=trusted),
"default_factory": get_tree(
state["content"]["default_factory"],
load_context,
trusted=trusted,
),
}

def _construct(self):
instance = defaultdict(**self.children["main"].construct())
instance.default_factory = self.children["default_factory"].construct()
return instance


def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
Expand Down Expand Up @@ -298,7 +339,8 @@ def __init__(
super().__init__(state, load_context, trusted)
# TODO: what do we trust?
self.trusted = self._get_trusted(
trusted, PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES
trusted,
PRIMITIVE_TYPE_NAMES + CONTAINER_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES,
)
# We use a bare Node type here since a Node only checks the type in the
# dict using __class__ and __module__ keys.
Expand Down Expand Up @@ -597,6 +639,7 @@ def _construct(self):
# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
(dict, dict_get_state),
(defaultdict, defaultdict_get_state),
(list, list_get_state),
(set, set_get_state),
(tuple, tuple_get_state),
Expand All @@ -616,6 +659,7 @@ def _construct(self):

NODE_TYPE_MAPPING = {
("DictNode", PROTOCOL): DictNode,
("DefaultDictNode", PROTOCOL): DefaultDictNode,
("ListNode", PROTOCOL): ListNode,
("SetNode", PROTOCOL): SetNode,
("TupleNode", PROTOCOL): TupleNode,
Expand Down
4 changes: 4 additions & 0 deletions skops/io/_trusted_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

PRIMITIVE_TYPE_NAMES = ["builtins." + t.__name__ for t in PRIMITIVES_TYPES]

CONTAINER_TYPES = [list, set, map, tuple]

CONTAINER_TYPE_NAMES = ["builtins." + t.__name__ for t in CONTAINER_TYPES]

SKLEARN_ESTIMATOR_TYPE_NAMES = [
get_type_name(estimator_class)
for _, estimator_class in all_estimators()
Expand Down
2 changes: 1 addition & 1 deletion skops/io/tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def lgbm(self):
def trusted(self):
# TODO: adjust once more types are trusted by default
return [
"collections.defaultdict",
"collections.OrderedDict",
"lightgbm.basic.Booster",
"lightgbm.sklearn.LGBMClassifier",
"lightgbm.sklearn.LGBMRegressor",
Expand Down
28 changes: 25 additions & 3 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import operator
import sys
import warnings
from collections import Counter
from collections import Counter, OrderedDict, defaultdict
from functools import partial, wraps
from pathlib import Path
from zipfile import ZIP_DEFLATED, ZipFile
Expand Down Expand Up @@ -56,6 +56,7 @@
from skops.io._audit import NODE_TYPE_MAPPING, get_tree
from skops.io._sklearn import UNSUPPORTED_TYPES
from skops.io._trusted_types import (
CONTAINER_TYPE_NAMES,
NUMPY_DTYPE_TYPE_NAMES,
NUMPY_UFUNC_TYPE_NAMES,
PRIMITIVE_TYPE_NAMES,
Expand Down Expand Up @@ -247,7 +248,9 @@ def _tested_ufuncs():


def _tested_types():
for full_name in PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES:
for full_name in (
PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + CONTAINER_TYPE_NAMES
):
module_name, _, type_name = full_name.rpartition(".")
yield gettype(module_name=module_name, cls_or_func=type_name)

Expand Down Expand Up @@ -396,7 +399,9 @@ def test_can_trust_ufuncs(ufunc):


@pytest.mark.parametrize(
"type_", _tested_types(), ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES
"type_",
_tested_types(),
ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + CONTAINER_TYPE_NAMES,
)
def test_can_trust_types(type_):
dumped = dumps(type_)
Expand Down Expand Up @@ -1078,3 +1083,20 @@ def test_trusted_bool_raises(tmp_path):

with pytest.raises(TypeError, match="trusted must be a list of strings"):
loads(dumps(10), trusted=True) # type: ignore


def test_defaultdict():
"""Test that we correctly restore a defaultdict."""
obj = defaultdict(set)
obj["foo"] = "bar"
obj_loaded = loads(dumps(obj))
assert obj_loaded == obj
assert obj_loaded.default_factory == obj.default_factory


@pytest.mark.parametrize("cls", [dict, OrderedDict])
def test_dictionary(cls):
obj = cls({1: 5, 6: 3, 2: 4})
loaded_obj = loads(dumps(obj))
assert obj == loaded_obj
assert type(obj) is cls

0 comments on commit 4467309

Please sign in to comment.