Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Refactor trusted #338

Merged
merged 9 commits into from
Apr 6, 2023
42 changes: 23 additions & 19 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._utils import LoadContext, get_module, get_type_paths
from .exceptions import UntrustedTypesFoundException

NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {}
NODE_TYPE_MAPPING: dict[tuple[str, int], Type[Node]] = {}
VALID_NODE_CHILD_TYPES = Optional[
Union["Node", List["Node"], Dict[str, "Node"], Type, str, io.BytesIO]
]
Expand Down Expand Up @@ -43,7 +43,7 @@ def check_type(
return module_name + "." + type_name in trusted


def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None:
def audit_tree(tree: Node) -> None:
"""Audit a tree of nodes.

A tree is safe if it only contains trusted types. Audit is skipped if
Expand All @@ -54,24 +54,15 @@ def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None:
tree : skops.io._dispatch.Node
The tree to audit.

trusted : True, or list of str
If ``True``, the tree is considered safe. Otherwise trusted has to be
a list of trusted types names.

An entry in the list is typically of the form
``skops.io._utils.get_module(obj) + "." + obj.__class__.__name__``.

Raises
------
UntrustedTypesFoundException
If the tree contains an untrusted type.
"""
if trusted is True:
if tree.trusted is True:
return

unsafe = tree.get_unsafe_set()
if isinstance(trusted, (list, set)):
unsafe -= set(trusted)
if unsafe:
raise UntrustedTypesFoundException(unsafe)

Expand Down Expand Up @@ -193,10 +184,12 @@ def _get_trusted(
) -> Literal[True] | list[str]:
"""Return a trusted list, or True.

If ``trusted`` is ``False``, we return the ``default``, otherwise the
``trusted`` value is used.
If ``trusted`` is ``False``, we return the ``default``. If a list of
types are being passed, those types, as well as default trusted types,
are returned.

This is a convenience method called by child classes.

"""
if trusted is True:
# if trusted is True, we trust the node
Expand All @@ -206,8 +199,8 @@ def _get_trusted(
# if trusted is False, we only trust the defaults
return get_type_paths(default)

# otherwise, we trust the given list
return get_type_paths(trusted)
# otherwise, we trust the given list and default trusted types
return get_type_paths(trusted) + get_type_paths(default)

def is_self_safe(self) -> bool:
"""True only if the node's type is considered safe.
Expand Down Expand Up @@ -314,10 +307,14 @@ def _construct(self):
return self.cached.construct()


NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode # type: ignore
NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode


def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node:
def get_tree(
state: dict[str, Any],
load_context: LoadContext,
trusted: bool | Sequence[str],
) -> Node:
"""Get the tree of nodes.

This function returns the root node of the tree of nodes. The tree is
Expand All @@ -336,6 +333,13 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node:
load_context : LoadContext
The context of the loading process.

trusted : bool, or list of str
If ``True``, the object will be loaded without any security checks. If
``False``, the object will be loaded only if there are only trusted
objects in the dumped file. If a list of strings, the object will be
loaded only if there are only trusted objects and objects of types
listed in ``trusted`` in the dumped file.

Returns
-------
loaded_tree : Node
Expand Down Expand Up @@ -372,5 +376,5 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node:
f"protocol {protocol}."
)

loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore
loaded_tree = node_cls(state, load_context, trusted=trusted)
return loaded_tree
36 changes: 24 additions & 12 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def __init__(
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [dict])
self.children = {
"key_types": get_tree(state["key_types"], load_context),
"key_types": get_tree(state["key_types"], load_context, trusted=trusted),
"content": {
key: get_tree(value, load_context)
key: get_tree(value, load_context, trusted=trusted)
for key, value in state["content"].items()
},
}
Expand Down Expand Up @@ -96,7 +96,10 @@ def __init__(
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [list])
self.children = {
"content": [get_tree(value, load_context) for value in state["content"]]
"content": [
get_tree(value, load_context, trusted=trusted)
for value in state["content"]
]
}

def _construct(self):
Expand Down Expand Up @@ -125,7 +128,10 @@ def __init__(
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [set])
self.children = {
"content": [get_tree(value, load_context) for value in state["content"]]
"content": [
get_tree(value, load_context, trusted=trusted)
for value in state["content"]
]
}

def _construct(self):
Expand Down Expand Up @@ -154,7 +160,10 @@ def __init__(
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [tuple])
self.children = {
"content": [get_tree(value, load_context) for value in state["content"]]
"content": [
get_tree(value, load_context, trusted=trusted)
for value in state["content"]
]
}

def _construct(self):
Expand Down Expand Up @@ -241,10 +250,12 @@ def __init__(
# TODO: should we trust anything?
self.trusted = self._get_trusted(trusted, [])
self.children = {
"func": get_tree(state["content"]["func"], load_context),
"args": get_tree(state["content"]["args"], load_context),
"kwds": get_tree(state["content"]["kwds"], load_context),
"namespace": get_tree(state["content"]["namespace"], load_context),
"func": get_tree(state["content"]["func"], load_context, trusted=trusted),
"args": get_tree(state["content"]["args"], load_context, trusted=trusted),
"kwds": get_tree(state["content"]["kwds"], load_context, trusted=trusted),
"namespace": get_tree(
state["content"]["namespace"], load_context, trusted=trusted
),
}

def _construct(self):
Expand Down Expand Up @@ -375,7 +386,7 @@ def __init__(

content = state.get("content")
if content is not None:
attrs = get_tree(content, load_context)
attrs = get_tree(content, load_context, trusted=trusted)
else:
attrs = None

Expand Down Expand Up @@ -432,7 +443,7 @@ def __init__(
) -> None:
super().__init__(state, load_context, trusted)
self.children = {
"obj": get_tree(state["content"]["obj"], load_context),
"obj": get_tree(state["content"]["obj"], load_context, trusted=trusted),
"func": state["content"]["func"],
}
# TODO: what do we trust?
Expand All @@ -458,6 +469,7 @@ def __init__(
super().__init__(state, load_context, trusted)
self.content = state["content"]
self.children = {}
self.trusted = self._get_trusted(trusted, PRIMITIVE_TYPE_NAMES)

def is_safe(self) -> bool:
# JsonNode is always considered safe.
Expand Down Expand Up @@ -552,7 +564,7 @@ def __init__(
) -> None:
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [])
self.children["attrs"] = get_tree(state["attrs"], load_context)
self.children["attrs"] = get_tree(state["attrs"], load_context, trusted=trusted)

def _construct(self):
op = getattr(operator, self.class_name)
Expand Down
21 changes: 13 additions & 8 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def __init__(
}
elif self.type == "json":
self.children = {
"content": [ # type: ignore
get_tree(o, load_context) for o in state["content"] # type: ignore
"content": [
get_tree(o, load_context, trusted=trusted) for o in state["content"]
],
"shape": get_tree(state["shape"], load_context),
"shape": get_tree(state["shape"], load_context, trusted=trusted),
}
else:
raise ValueError(f"Unknown type {self.type}.")
Expand Down Expand Up @@ -126,8 +126,8 @@ def __init__(
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, [np.ma.MaskedArray])
self.children = {
"data": get_tree(state["content"]["data"], load_context),
"mask": get_tree(state["content"]["mask"], load_context),
"data": get_tree(state["content"]["data"], load_context, trusted=trusted),
"mask": get_tree(state["content"]["mask"], load_context, trusted=trusted),
}

def _construct(self):
Expand Down Expand Up @@ -155,7 +155,10 @@ def __init__(
trusted: bool | Sequence[str] = False,
) -> None:
super().__init__(state, load_context, trusted)
self.children = {"content": get_tree(state["content"], load_context)}
# TODO
self.children = {
"content": get_tree(state["content"], load_context, trusted=trusted)
}
self.trusted = self._get_trusted(trusted, [np.random.RandomState])

def _construct(self):
Expand Down Expand Up @@ -185,7 +188,7 @@ def __init__(
super().__init__(state, load_context, trusted)
self.children = {
"bit_generator_state": get_tree(
state["content"]["bit_generator"], load_context
state["content"]["bit_generator"], load_context, trusted=trusted
)
}
self.trusted = self._get_trusted(trusted, [np.random.Generator])
Expand Down Expand Up @@ -224,7 +227,9 @@ def __init__(
trusted: bool | Sequence[str] = False,
) -> None:
super().__init__(state, load_context, trusted)
self.children = {"content": get_tree(state["content"], load_context)}
self.children = {
"content": get_tree(state["content"], load_context, trusted=trusted)
}
# TODO: what should we trust?
self.trusted = self._get_trusted(trusted, [])

Expand Down
17 changes: 8 additions & 9 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any:
``False``, the object will be loaded only if there are only trusted
objects in the dumped file. If a list of strings, the object will be
loaded only if there are only trusted objects and objects of types
listed in ``trusted`` are in the dumped file.
listed in ``trusted`` in the dumped file.

Returns
-------
Expand All @@ -127,8 +127,8 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any:
with ZipFile(file, "r") as input_zip:
schema = json.loads(input_zip.read("schema.json"))
load_context = LoadContext(src=input_zip, protocol=schema["protocol"])
tree = get_tree(schema, load_context)
audit_tree(tree, trusted)
tree = get_tree(schema, load_context, trusted=trusted)
audit_tree(tree)
instance = tree.construct()

return instance
Expand All @@ -154,7 +154,7 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any:
``False``, the object will be loaded only if there are only trusted
objects in the dumped file. If a list of strings, the object will be
loaded only if there are only trusted objects and objects of types
listed in ``trusted`` are in the dumped file.
listed in ``trusted`` in the dumped file.

Returns
-------
Expand All @@ -167,8 +167,8 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any:
with ZipFile(io.BytesIO(data), "r") as zip_file:
schema = json.loads(zip_file.read("schema.json"))
load_context = LoadContext(src=zip_file, protocol=schema["protocol"])
tree = get_tree(schema, load_context)
audit_tree(tree, trusted)
tree = get_tree(schema, load_context, trusted=trusted)
audit_tree(tree)
instance = tree.construct()

return instance
Expand Down Expand Up @@ -210,9 +210,8 @@ def get_untrusted_types(

with ZipFile(content, "r") as zip_file:
schema = json.loads(zip_file.read("schema.json"))
tree = get_tree(
schema, load_context=LoadContext(src=zip_file, protocol=schema["protocol"])
)
load_context = LoadContext(src=zip_file, protocol=schema["protocol"])
tree = get_tree(schema, load_context=load_context, trusted=False)
untrusted_types = tree.get_unsafe_set()

return sorted(untrusted_types)
10 changes: 6 additions & 4 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def __init__(
super().__init__(state, load_context, trusted)
reduce = state["__reduce__"]
self.children = {
"attrs": get_tree(state["content"], load_context),
"args": get_tree(reduce["args"], load_context),
"attrs": get_tree(state["content"], load_context, trusted=trusted),
"args": get_tree(reduce["args"], load_context, trusted=trusted),
"constructor": constructor,
}

Expand Down Expand Up @@ -210,9 +210,11 @@ def __init__(
get_module(_DictWithDeprecatedKeysNode) + "._DictWithDeprecatedKeys"
]
self.children = {
"main": get_tree(state["content"]["main"], load_context),
"main": get_tree(state["content"]["main"], load_context, trusted=trusted),
"_deprecated_key_to_new_key": get_tree(
state["content"]["_deprecated_key_to_new_key"], load_context
state["content"]["_deprecated_key_to_new_key"],
load_context,
trusted=trusted,
),
}

Expand Down
13 changes: 10 additions & 3 deletions skops/io/_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterator, Literal
from typing import Any, Callable, Iterator, Literal, Sequence
from zipfile import ZipFile

from ._audit import VALID_NODE_CHILD_TYPES, Node, get_tree
Expand Down Expand Up @@ -281,7 +281,9 @@ def walk_tree(

def visualize(
file: Path | str | bytes,
*,
show: Literal["all", "untrusted", "trusted"] = "all",
trusted: bool | Sequence[str] = False,
sink: Callable[..., None] = pretty_print_tree,
**kwargs: Any,
) -> None:
Expand All @@ -307,8 +309,13 @@ def visualize(
show: "all" or "untrusted" or "trusted"
Whether to print all nodes, only untrusted nodes, or only trusted nodes.

sink: function (default=:func:`~pretty_print_tree`)
trusted: bool, or list of str, default=False
If ``True``, all nodes will be treated as trusted. If ``False``, only
default types are trusted. If a list of strings, where those strongs
describe the trusted types, these types are trusted on top of the
default trusted types.

sink: function (default=:func:`~pretty_print_tree`)
This function should take at least two arguments, an iterator of
:class:`~NodeInfo` instances and an indicator of what to show. The
``NodeInfo`` contains the information about the node, namely:
Expand Down Expand Up @@ -348,7 +355,7 @@ def visualize(
with zf as zip_file:
schema = json.loads(zip_file.read("schema.json"))
load_context = LoadContext(src=zip_file, protocol=schema["protocol"])
tree = get_tree(schema, load_context=load_context)
tree = get_tree(schema, load_context=load_context, trusted=trusted)

nodes = walk_tree(tree)
# TODO: it would be nice to print html representation if inside a notebook
Expand Down
Loading