Skip to content

Commit

Permalink
ENH support types with a safe __reduce__
Browse files Browse the repository at this point in the history
  • Loading branch information
adrinjalali committed Jan 31, 2025
1 parent f032cc0 commit aa70377
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
5 changes: 0 additions & 5 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,6 @@ def _check_table(self) -> None:
raise ValueError("Trying to add table with no columns")

def format(self) -> str:
if self._is_pandas_df:
pass # type: ignore
else:
self.table.keys()

table = PrettyTable()
table.set_style(TableStyle.MARKDOWN)
for key, values in self.table.items():
Expand Down
3 changes: 2 additions & 1 deletion skops/io/_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ def get_tree(
type_name = f"{state['__module__']}.{state['__class__']}"
raise TypeError(
f" Can't find loader {state['__loader__']} for type {type_name} and "
f"protocol {protocol}."
f"protocol {protocol}. You might need to update skops to load this "
"file."
)

loaded_tree = node_cls(state, load_context, trusted=trusted)
Expand Down
38 changes: 38 additions & 0 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# This method is for objects which can either be persisted with json, or
# the ones for which we can get/set attributes through
# __getstate__/__setstate__ or reading/writing to __dict__.

# We first check if the object can be serialized using json.
try:
# if we can simply use json, then we're done.
obj_str = json.dumps(obj)
Expand All @@ -405,6 +407,23 @@ def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
except Exception:
pass

# Then we check if the output of __reduce__ is of the form
# (constructor, (constructor_args,))
# If the constructor is the same as the object's type, then we consider it
# safe to call it with the specified arguments.

reduce_output = obj.__reduce__()
# Note that we use "==" to compare types instead of "is", since we only accept
# exact matches here.
if len(reduce_output) == 2 and reduce_output[0] == type(obj):
return {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "ConstructorFromReduceNode",
"content": get_state(reduce_output[1], save_context),
}

# Otherwise we recover the object from the __dict__ or __getstate__
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
Expand All @@ -427,6 +446,24 @@ def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
return res


class ConstructorFromReduceNode(Node):
    """Node that rebuilds an object by calling its type with stored arguments.

    The ``"content"`` entry of ``state`` is loaded into a child tree; when
    constructed, that child is unpacked as the positional arguments passed to
    the type resolved from this node's module and class names.
    """

    def __init__(
        self,
        state: dict[str, Any],
        load_context: LoadContext,
        trusted: Optional[Sequence[str]] = None,
    ) -> None:
        super().__init__(state, load_context, trusted)
        # The only child is the tree describing the constructor arguments.
        args_tree = get_tree(state["content"], load_context, trusted=trusted)
        self.children = {"content": args_tree}

    def _construct(self):
        """Return a new instance built from the loaded constructor arguments."""
        # Resolve the type first, then materialise its arguments.
        # NOTE(review): module_name/class_name are presumably set by
        # Node.__init__ from the state -- confirm against the base class.
        constructor = gettype(self.module_name, self.class_name)
        constructor_args = self.children["content"].construct()
        return constructor(*constructor_args)


class ObjectNode(Node):
def __init__(
self,
Expand Down Expand Up @@ -670,6 +707,7 @@ def _construct(self):
("MethodNode", PROTOCOL): MethodNode,
("PartialNode", PROTOCOL): PartialNode,
("TypeNode", PROTOCOL): TypeNode,
("ConstructorFromReduceNode", PROTOCOL): ConstructorFromReduceNode,
("ObjectNode", PROTOCOL): ObjectNode,
("JsonNode", PROTOCOL): JsonNode,
("OperatorFuncNode", PROTOCOL): OperatorFuncNode,
Expand Down
15 changes: 15 additions & 0 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import warnings
from collections import Counter, OrderedDict, defaultdict
from datetime import datetime
from functools import partial, wraps
from pathlib import Path
from zipfile import ZIP_DEFLATED, ZipFile
Expand Down Expand Up @@ -1119,3 +1120,17 @@ def test_dictionary(cls):
loaded_obj = loads(dumps(obj))
assert obj == loaded_obj
assert type(obj) is cls


def test_datetime():
    """A datetime survives a dumps/loads round trip when its type is trusted."""
    obj = datetime.now()
    loaded_obj = loads(dumps(obj), trusted=[datetime])
    assert obj == loaded_obj
    # Check the *loaded* object's type; checking ``obj`` would be tautological
    # because it was created as a datetime just above.
    assert type(loaded_obj) is datetime


def test_slice():
    """A slice object survives a dumps/loads round trip."""
    obj = slice(1, 2, 3)
    loaded_obj = loads(dumps(obj))
    assert obj == loaded_obj
    # Check the *loaded* object's type; checking ``obj`` would be tautological
    # because it was created as a slice just above.
    assert type(loaded_obj) is slice

0 comments on commit aa70377

Please sign in to comment.