Skip to content

Commit

Permalink
cleanup and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
adrinjalali committed Feb 13, 2025
1 parent dd8a00c commit da97b31
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
6 changes: 2 additions & 4 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,9 @@ def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# safe to call it with the specified arguments.

reduce_output = obj.__reduce__()
# note that we do "=="" to compare types instead of "is", since we only accept
# exact matches here.
if len(reduce_output) == 2 and reduce_output[0] == type(obj):
if len(reduce_output) == 2 and reduce_output[0] is type(obj):
return {
"__class__": obj.__class__.__name__,
"__class__": type(obj).__name__,
"__module__": get_module(type(obj)),
"__loader__": "ConstructorFromReduceNode",
"content": get_state(reduce_output[1], save_context),
Expand Down
28 changes: 28 additions & 0 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,3 +1134,31 @@ def test_slice():
loaded_obj = loads(dumps(obj))
assert obj == loaded_obj
assert type(obj) is slice


# This class is here as opposed to inside the test because it needs to be importable.
reduce_calls = 0


class CustomReduce:
def __init__(self, value):
self.value = value

def __reduce__(self):
global reduce_calls
reduce_calls += 1
return (type(self), (self.value,))


def test_custom_reduce():
obj = CustomReduce(10)
dumped = dumps(obj)

# make sure __reduce__ is called, once.
assert reduce_calls == 1

with pytest.raises(TypeError, match="Untrusted types found"):
loads(dumped)

loaded_obj = loads(dumps(obj), trusted=[CustomReduce])
assert obj.value == loaded_obj.value

0 comments on commit da97b31

Please sign in to comment.