-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FEAT Function to visualize skops files #317
Changes from 4 commits
4e3dba2
5a955e5
e753d25
5960d81
50a2125
37bf95f
f91b313
80bb8af
eb56d88
3cd95d6
0e6c71f
1853dcf
8e3a735
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
from dataclasses import dataclass | ||
from functools import singledispatch | ||
from pathlib import Path | ||
from typing import Callable, Iterator, Literal, Sequence | ||
from zipfile import ZipFile | ||
|
||
from ._audit import Node, get_tree | ||
from ._general import FunctionNode, JsonNode | ||
from ._numpy import NdArrayNode | ||
from ._scipy import SparseMatrixNode | ||
from ._utils import LoadContext | ||
|
||
|
||
@dataclass
class PrintConfig:
    """Settings that control how nodes are rendered as text.

    The ``tag_*`` strings are appended to a node's value depending on whether
    the node itself is safe, the ``line*`` strings form the prefix printed in
    front of each node, and the ``color_*`` strings are terminal escape
    sequences wrapped around the value when ``use_colors`` is true.
    """

    # suffix added to the value of a node that is safe / unsafe by itself
    tag_safe: str = ""
    tag_unsafe: str = " [UNSAFE]"

    # prefix of every printed line, and the filler repeated per nesting level
    line_start: str = "├─"
    line: str = "──"

    # coloring of the node value; ``color_end`` closes the colored span
    use_colors: bool = True
    color_safe: str = "\033[32m"  # green
    color_unsafe: str = "\033[31m"  # red
    color_child_unsafe: str = "\033[33m"  # yellow
    color_end: str = "\033[0m"
|
||
|
||
# Module-level default configuration; it is the default value of the ``config``
# parameter of ``walk_tree`` and the ``print_config`` parameter of
# ``visualize_tree``.
print_config = PrintConfig()
|
||
|
||
@dataclass
class FormattedNode:
    """One node of the walked structure, already rendered to strings.

    Attributes
    ----------
    level : int
        Nesting level of the node.
    key : str
        The key to the node.
    val : str
        The value of the node.
    visible : bool
        Whether the node should be shown.
    """

    level: int
    key: str
    val: str
    visible: bool
|
||
|
||
def pretty_print_tree(
    formatted_nodes: Iterator[FormattedNode], config: PrintConfig
) -> None:
    """Print each visible node on its own line, prefixed according to ``config``.

    Parameters
    ----------
    formatted_nodes : iterator of FormattedNode
        The nodes to print. Nodes whose ``visible`` attribute is false are
        skipped.

    config : PrintConfig
        Supplies ``line_start`` (printed at the start of every line) and
        ``line`` (repeated ``level - 1`` times after it).

    """
    # TODO: the "tree" lines could be made prettier since all nodes are known
    # here
    for formatted_node in formatted_nodes:
        if not formatted_node.visible:
            continue

        # Read the settings from the ``config`` argument rather than a module
        # global, so that a caller-supplied configuration is honored. For
        # level 0 the repeat count is negative, which yields an empty string.
        prefix = config.line_start + (formatted_node.level - 1) * config.line
        print(f"{prefix}{formatted_node.key}: {formatted_node.val}")
|
||
|
||
def _check_visibility(
    node: Node,
    node_is_safe: bool,
    node_and_children_are_safe: bool,
    show: Literal["all", "untrusted", "trusted"],
) -> bool:
    """Decide whether a node should be displayed for the given ``show`` mode.

    ``"all"`` shows every node and ``"untrusted"`` shows a node when
    ``node_and_children_are_safe`` is false. Any other value is handled as
    ``"trusted"`` and shows a node when ``node_is_safe`` is true. The ``node``
    argument itself is not consulted.
    """
    if show == "all":
        return True
    if show == "untrusted":
        return not node_and_children_are_safe
    # only trusted
    return node_is_safe
|
||
|
||
def _check_node_and_children_safe(node: Node, trusted: bool | Sequence[str]) -> bool:
    """Return whether ``node`` reports no unsafe entries that are not trusted.

    With ``trusted=True`` the answer is always ``True`` and the node is not
    queried. With ``trusted=False`` the result of ``node.get_unsafe_set()``
    must be empty. Otherwise the names in ``trusted`` are removed from that set
    first.
    """
    # Note: this is very inefficient, because get_unsafe_set will be called many
    # times on the same node (since parents recursively call children) but maybe
    # that's acceptable for this context. If not, caching could be an option.
    if trusted is True:
        return True

    unsafe = node.get_unsafe_set()
    if trusted is not False:
        unsafe = unsafe - set(trusted)
    return not unsafe
|
||
|
||
# use singledispatch so that we can register specialized visualization functions
# NOTE(review): a reviewer remarked that they would not mind these formatting
# functions being moved to the node classes.
@singledispatch
def format_node(node: Node) -> str:
    """Format the name of the node.

    The default is the fully qualified name of the class, e.g.
    ``"sklearn.preprocessing._data.MinMaxScaler"``. Node types for which a more
    specific output is desirable can be registered with this function.

    """
    module, cls = node.module_name, node.class_name
    return f"{module}.{cls}"
|
||
|
||
@format_node.register
def _format_function_node(node: FunctionNode) -> str:
    """Format a function node as ``<node class> => <referenced function>``."""
    # walk_tree does not descend into the children of a FunctionNode, so the
    # function it refers to is made visible here instead
    content = node.children["content"]
    target = f"{content['module_path']}.{content['function']}"
    return f"{node.module_name}.{node.class_name} => {target}"
|
||
|
||
@format_node.register
def _format_json_node(node: JsonNode) -> str:
    """Format a JSON node by embedding its ``content`` attribute."""
    content = node.content
    return f"json-type({content})"
|
||
|
||
def walk_tree(
    node: Node | dict[str, Node] | Sequence[Node],
    trusted: bool | Sequence[str] = False,
    show: Literal["all", "untrusted", "trusted"] = "all",
    node_name: str = "root",
    level: int = 0,
    config: PrintConfig = print_config,
) -> Iterator[FormattedNode]:
    """Walk ``node`` recursively and yield one ``FormattedNode`` per ``Node``.

    Dicts are walked item by item, using each key as the name of its value.
    Lists and tuples are walked element by element, every element keeping the
    name passed in. Neither adds a nesting level. Anything named
    ``"key_types"`` is skipped entirely. Children of a ``Node`` are walked one
    level deeper, except for array, sparse matrix and function nodes, whose
    children are not visited.

    Raises
    ------
    TypeError
        If ``node`` is neither a ``Node`` nor a dict, list or tuple.

    """
    if node_name == "key_types":
        return

    # options that every recursive call receives unchanged
    shared = {"trusted": trusted, "show": show, "config": config}

    # composite types: walk all items
    if isinstance(node, dict):
        for key, val in node.items():
            yield from walk_tree(val, node_name=key, level=level, **shared)
        return

    if isinstance(node, (list, tuple)):
        for val in node:
            yield from walk_tree(val, node_name=node_name, level=level, **shared)
        return

    # no match: raise error
    if not isinstance(node, Node):
        raise TypeError(f"{type(node)}")

    # the actual formatting happens here
    node_is_safe = node.is_self_safe()
    node_and_children_are_safe = _check_node_and_children_safe(node, trusted)
    visible = _check_visibility(
        node,
        node_is_safe=node_is_safe,
        node_and_children_are_safe=node_and_children_are_safe,
        show=show,
    )

    text = format_node(node)
    tag = config.tag_safe if node_is_safe else config.tag_unsafe
    if tag:
        text += f" {tag}"

    if config.use_colors:
        # green: node and children safe; yellow: only a child is unsafe;
        # red: the node itself is unsafe
        if node_and_children_are_safe:
            color = config.color_safe
        else:
            color = config.color_child_unsafe if node_is_safe else config.color_unsafe
        text = f"{color}{text}{config.color_end}"

    yield FormattedNode(level=level, key=node_name, val=text, visible=visible)

    # types whose children it makes no sense to visit
    if isinstance(node, (NdArrayNode, SparseMatrixNode)):
        if node.type != "json":
            return
        if node.type == "json":
            return

    if isinstance(node, FunctionNode):
        return

    # every other node type (JsonNode included) falls through to the recursion
    yield from walk_tree(node.children, node_name=node_name, level=level + 1, **shared)
|
||
|
||
# NOTE(review): a reviewer questioned whether "tree" needs to be part of this
# function's name; the suggested alternative is cut off in the review comment.
def visualize_tree(
    file: Path | str,  # TODO: from bytes
    trusted: bool | Sequence[str] = False,
    show: Literal["all", "untrusted", "trusted"] = "all",
    sink: Callable[[Iterator[FormattedNode], PrintConfig], None] = pretty_print_tree,
    print_config: PrintConfig = print_config,
) -> None:
    """Visualize the contents of a skops file.

    Reads ``schema.json`` from the given zip archive, builds the node tree from
    it and hands the formatted nodes to ``sink``. By default they are printed
    as a tree view. Nodes are marked according to whether they, or any of
    their children, are considered unsafe.

    Parameters
    ----------
    file: str or pathlib.Path
        Path of the skops file to inspect.

    trusted: bool, or list of str, default=False
        Controls the combined "node and children are safe" check. If ``True``,
        that check passes for every node. If ``False``, it passes only for
        nodes that report no unsafe types. If a list of strings, the listed
        type names are disregarded when checking for unsafe types.

    show: "all" or "untrusted" or "trusted"
        Whether to mark all nodes, only untrusted nodes, or only trusted nodes
        as visible.

    sink: function
        Called with two arguments: an iterator of ``FormattedNode`` and a
        ``PrintConfig``. Each ``FormattedNode`` carries:

        - the level of nesting (int)
        - the key of the node (str)
        - the value of the node as a string representation (str)
        - the visibility of the node, depending on the ``show`` argument (bool)

        The second argument is the value of ``print_config``.

    print_config: :class:`~PrintConfig`
        Object whose attributes determine how nodes are formatted, e.g.
        ``use_colors`` determines if colors are used. It is used when walking
        the tree and is also passed on to ``sink``.

    """
    with ZipFile(file, "r") as archive:
        schema = json.loads(archive.read("schema.json"))
        context = LoadContext(src=archive)
        tree = get_tree(schema, load_context=context)

    formatted = walk_tree(tree, trusted=trusted, show=show, config=print_config)
    sink(formatted, print_config)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy as np | ||
import pytest | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.pipeline import FeatureUnion, Pipeline | ||
from sklearn.preprocessing import ( | ||
FunctionTransformer, | ||
MinMaxScaler, | ||
PolynomialFeatures, | ||
StandardScaler, | ||
) | ||
|
||
import skops.io as sio | ||
from skops.io._visualize import visualize_tree | ||
|
||
|
||
class TestVisualizeTree:
    """Smoke tests and visible-node counts for ``visualize_tree``."""

    @staticmethod
    def _count_visible(file, show):
        # collect everything handed to the sink, then count the visible nodes
        collected = []
        visualize_tree(file, sink=lambda nodes, _: collected.extend(nodes), show=show)
        return sum(1 for node in collected if node.visible)

    @pytest.fixture
    def simple(self):
        return MinMaxScaler(feature_range=(-555, 123))

    @pytest.fixture
    def simple_file(self, simple, tmp_path):
        target = tmp_path / "estimator.skops"
        sio.dump(simple, target)
        return target

    @pytest.fixture
    def pipeline(self):
        # nested estimator, assembled from the innermost part outwards
        polys = FeatureUnion(
            [
                ("poly1", PolynomialFeatures()),
                ("poly2", PolynomialFeatures(degree=3, include_bias=False)),
            ]
        )
        scaled_poly = Pipeline(
            [
                ("polys", polys),
                ("square-root", FunctionTransformer(np.sqrt)),
                ("scale", MinMaxScaler()),
            ]
        )
        features = FeatureUnion(
            [
                ("scaler", StandardScaler()),
                ("scaled-poly", scaled_poly),
            ]
        )
        clf = LogisticRegression(random_state=0, solver="liblinear")
        model = Pipeline([("features", features), ("clf", clf)])
        return model.fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2])

    @pytest.fixture
    def pipeline_file(self, pipeline, tmp_path):
        target = tmp_path / "estimator.skops"
        sio.dump(pipeline, target)
        return target

    @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"])
    def test_print_simple(self, simple_file, show):
        visualize_tree(simple_file, show=show)

    @pytest.mark.parametrize(
        "show_tell", [("all", 8), ("trusted", 8), ("untrusted", 0)]
    )
    def test_inspect_simple(self, simple_file, show_tell):
        show, expected_elements = show_tell
        assert self._count_visible(simple_file, show) == expected_elements

    @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"])
    def test_print_pipeline(self, pipeline_file, show):
        visualize_tree(pipeline_file, show=show)

    @pytest.mark.parametrize(
        "show_tell", [("all", 129), ("trusted", 127), ("untrusted", 19)]
    )
    def test_inspect_pipeline(self, pipeline_file, show_tell):
        show, expected_elements = show_tell
        assert self._count_visible(pipeline_file, show) == expected_elements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it'd be okay to move this to nodes and cache them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah, there is already
node.is_safe()
>__<