Skip to content

Commit

Permalink
A little hardening of the auditing of Nodes
Browse files Browse the repository at this point in the history
Two measures to harden the auditing (a little bit):

- Type annotate the Node's children to prevent setting invalid types.
- Change all the tests that use loads to only load trusted types instead
  of using trusted=True

The latter is important because when setting trusted=True, the whole
machinery of checking types is not executed, so any bugs that may be
contained there will not be revealed. In particular, this shows that for
persisting methods, we had a child with a str type and that would raise
an error, i.e. loading method types was not possible for users who
passed trusted!=True.

Additional changes

As a consequence of the last point, the auditing code has been changed
to accept str as type. Alternatively, we can make the change explained
here:

skops-dev#338 (comment)

i.e. not storing the method name in children.

Another "victim" of this change is that the previously dead code checking
for primitive types inside of get_unsafe_set has been removed. This code
was supposed to check whether the type is a primitive type, but it was
defective: get_module(child) would raise an error if an instance of the
type were passed. We could theoretically fix that code, but it would
still be dead code, because primitive types are stored as json.

Another small change is to exclude the code in skops/io/old from mypy
checks. Otherwise, we would have to update its type signatures if
signatures in the persistence code change.
  • Loading branch information
BenjaminBossan committed Apr 5, 2023
1 parent e18aa1a commit beaa234
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ omit = [
]

[tool.mypy]
exclude = "(\\w+/)*test_\\w+\\.py$"
exclude = "(\\w+/)*test_\\w+\\.py$|old"
ignore_missing_imports = true
no_implicit_optional = true
20 changes: 9 additions & 11 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import io
from contextlib import contextmanager
from typing import Any, Generator, Literal, Sequence, Type, Union
from typing import Any, Generator, Literal, Optional, Sequence, Type, Union

from ._protocol import PROTOCOL
from ._trusted_types import PRIMITIVE_TYPE_NAMES
from ._utils import LoadContext, get_module, get_type_paths
from .exceptions import UntrustedTypesFoundException

NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {}
VALID_NODE_CHILD_TYPES = Optional[
Union["Node", list["Node"], dict[str, "Node"], Type, str, io.BytesIO]
]


def check_type(
Expand Down Expand Up @@ -168,7 +170,7 @@ def __init__(
# 3. set self.children, where children are states of child nodes; do not
# construct the children objects yet
self.trusted = self._get_trusted(trusted, [])
self.children: dict[str, Any] = {}
self.children: dict[str, VALID_NODE_CHILD_TYPES] = {}

def construct(self):
"""Construct the object.
Expand Down Expand Up @@ -269,15 +271,11 @@ def get_unsafe_set(self) -> set[str]:
if not check_type(get_module(child), child.__name__, self.trusted):
# if the child is a type, we check its safety
res.add(get_module(child) + "." + child.__name__)
elif isinstance(child, io.BytesIO):
elif isinstance(child, (io.BytesIO, str)):
# We trust BytesIO objects, which are read by other
# libraries such as numpy, scipy.
continue
elif check_type(
get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES
):
# if the child is a primitive type, we don't need to check its
# safety.
# libraries such as numpy, scipy. We trust str but have to
# be careful that anything with str is dealt with
# appropriately.
continue
else:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Sequence, Type
from typing import Any, Sequence, Type

from sklearn.cluster import Birch

Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
constructor: Type[Any] | Callable[..., Any],
constructor: Type[Any],
trusted: bool | Sequence[str] = False,
) -> None:
super().__init__(state, load_context, trusted)
Expand Down
4 changes: 2 additions & 2 deletions skops/io/_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Iterator, Literal
from zipfile import ZipFile

from ._audit import Node, get_tree
from ._audit import VALID_NODE_CHILD_TYPES, Node, get_tree
from ._general import FunctionNode, JsonNode, ListNode
from ._numpy import NdArrayNode
from ._scipy import SparseMatrixNode
Expand Down Expand Up @@ -168,7 +168,7 @@ def pretty_print_tree(


def walk_tree(
node: Node | dict[str, Node] | list[Node],
node: VALID_NODE_CHILD_TYPES | dict[str, VALID_NODE_CHILD_TYPES],
node_name: str = "root",
level: int = 0,
is_last: bool = False,
Expand Down
52 changes: 38 additions & 14 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def _unsupported_estimators(type_filter=None):
)
def test_can_persist_non_fitted(estimator):
"""Check that non-fitted estimators can be persisted."""
loaded = loads(dumps(estimator), trusted=True)
dumped = dumps(estimator)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(estimator.get_params(), loaded.get_params())


Expand Down Expand Up @@ -458,7 +460,9 @@ def split(self, X, **kwargs):
)
def test_cross_validator(cv):
est = CVEstimator(cv=cv).fit(None, None)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
X, y = make_classification(
n_samples=N_SAMPLES, n_features=N_FEATURES, random_state=0
)
Expand Down Expand Up @@ -500,7 +504,9 @@ def test_numpy_object_dtype_2d_array(transpose):
if transpose:
est.obj_array_ = est.obj_array_.T

loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)


Expand Down Expand Up @@ -615,7 +621,8 @@ def test_identical_numpy_arrays_not_duplicated():
X = np.random.random((10, 5))
estimator = EstimatorIdenticalArrays().fit(X)
dumped = dumps(estimator)
loaded = loads(dumped, trusted=True)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(estimator.__dict__, loaded.__dict__)

# check number of numpy arrays stored on disk
Expand Down Expand Up @@ -719,7 +726,9 @@ def test_for_base_case_returns_as_expected(self):
bound_function = obj.bound_method
transformer = FunctionTransformer(func=bound_function)

loaded_transformer = loads(dumps(transformer), trusted=True)
dumped = dumps(transformer)
untrusted_types = get_untrusted_types(data=dumped)
loaded_transformer = loads(dumped, trusted=untrusted_types)
loaded_obj = loaded_transformer.func.__self__

self.assert_transformer_persisted_correctly(loaded_transformer, transformer)
Expand All @@ -736,7 +745,9 @@ def test_when_object_is_changed_after_init_works_as_expected(self):

transformer = FunctionTransformer(func=bound_function)

loaded_transformer = loads(dumps(transformer), trusted=True)
dumped = dumps(transformer)
untrusted_types = get_untrusted_types(data=dumped)
loaded_transformer = loads(dumped, trusted=untrusted_types)
loaded_obj = loaded_transformer.func.__self__

self.assert_transformer_persisted_correctly(loaded_transformer, transformer)
Expand All @@ -749,19 +760,23 @@ def test_works_when_given_multiple_bound_methods_attached_to_single_instance(sel
func=obj.bound_method, inverse_func=obj.other_bound_method
)

loaded_transformer = loads(dumps(transformer), trusted=True)
dumped = dumps(transformer)
untrusted_types = get_untrusted_types(data=dumped)
loaded_transformer = loads(dumped, trusted=untrusted_types)

# check that both func and inverse_func are from the same object instance
loaded_0 = loaded_transformer.func.__self__
loaded_1 = loaded_transformer.inverse_func.__self__
assert loaded_0 is loaded_1

@pytest.mark.xfail(reason="Failing due to circular self reference")
@pytest.mark.xfail(reason="Failing due to circular self reference", strict=True)
def test_scipy_stats(self, tmp_path):
from scipy import stats

estimator = FunctionTransformer(func=stats.zipf)
loads(dumps(estimator), trusted=True)
dumped = dumps(estimator)
untrusted_types = get_untrusted_types(data=dumped)
loads(dumped, trusted=untrusted_types)


class CustomEstimator(BaseEstimator):
Expand Down Expand Up @@ -862,7 +877,9 @@ def test_dump_and_load_with_file_wrapper(tmp_path):
)
def test_when_given_object_referenced_twice_loads_as_one_object(obj):
an_object = {"obj_1": obj, "obj_2": obj}
persisted_object = loads(dumps(an_object), trusted=True)
dumped = dumps(an_object)
untrusted_types = get_untrusted_types(data=dumped)
persisted_object = loads(dumped, trusted=untrusted_types)

assert persisted_object["obj_1"] is persisted_object["obj_2"]

Expand All @@ -876,7 +893,9 @@ def fit(self, X, y, **fit_params):

def test_estimator_with_bytes():
est = EstimatorWithBytes().fit(None, None)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)


Expand Down Expand Up @@ -934,13 +953,17 @@ def test_persist_operator(op):
_, func = op
# unfitted
est = FunctionTransformer(func)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)

# fitted
X, y = get_input(est)
est.fit(X, y)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)

# Technically, we don't need to call transform. However, if this is skipped,
Expand Down Expand Up @@ -973,7 +996,8 @@ def test_persist_function(func):
estimator.fit(X, y)

dumped = dumps(estimator)
loaded = loads(dumped, trusted=True)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)

# check that loaded estimator is identical
assert_params_equal(estimator.__dict__, loaded.__dict__)
Expand Down

0 comments on commit beaa234

Please sign in to comment.