Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT Function to visualize skops files #317

Merged
Merged
4 changes: 2 additions & 2 deletions skops/io/_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def __init__(
trusted: bool | Sequence[str] = False,
) -> None:
super().__init__(state, load_context, trusted)
type = state["type"]
self.type = state["type"]
self.trusted = self._get_trusted(trusted, [spmatrix])
if type != "scipy":
if self.type != "scipy":
raise TypeError(
f"Cannot load object of type {self.module_name}.{self.class_name}"
)
Expand Down
262 changes: 262 additions & 0 deletions skops/io/_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from functools import singledispatch
from pathlib import Path
from typing import Callable, Iterator, Literal, Sequence
from zipfile import ZipFile

from ._audit import Node, get_tree
from ._general import FunctionNode, JsonNode
from ._numpy import NdArrayNode
from ._scipy import SparseMatrixNode
from ._utils import LoadContext


@dataclass
class PrintConfig:
# fmt: off
tag_safe: str = "" # noqa: E222
tag_unsafe: str = " [UNSAFE]"

line_start: str = "├─"
line: str = "──" # noqa: E222

use_colors: bool = True
color_safe: str = '\033[32m' # green # noqa: E222
color_unsafe: str = '\033[31m' # red # noqa: E222
color_child_unsafe: str = '\033[33m' # yellow
color_end: str = '\033[0m' # noqa: E222
# fmt: on


print_config = PrintConfig()


@dataclass
class FormattedNode:
level: int
key: str # the key to the node
val: str # the value of the node
visible: bool # whether it should be shown


def pretty_print_tree(
formatted_nodes: Iterator[FormattedNode], config: PrintConfig
) -> None:
# TODO: the "tree" lines could be made prettier since all nodes are known
# here
for formatted_node in formatted_nodes:
if not formatted_node.visible:
continue

line = print_config.line_start
line += (formatted_node.level - 1) * print_config.line
line += f"{formatted_node.key}: {formatted_node.val}"
print(line)


def _check_visibility(
node: Node,
node_is_safe: bool,
node_and_children_are_safe: bool,
show: Literal["all", "untrusted", "trusted"],
) -> bool:
if show == "all":
should_print = True
elif show == "untrusted":
should_print = not node_and_children_are_safe
else: # only trusted
should_print = node_is_safe
return should_print


def _check_node_and_children_safe(node: Node, trusted: bool | Sequence[str]) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be okay to move this to nodes and cache them.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, there is already node.is_safe() >__<

# Note: this is very inefficient, because get_unsafe_set will be called many
# times on the same node (since parents recursively call children) but maybe
# that's acceptable for this context. If not, caching could be an option.
if trusted is True:
node_and_children_are_safe = True
elif trusted is False:
node_and_children_are_safe = not node.get_unsafe_set()
else:
node_and_children_are_safe = not (node.get_unsafe_set() - set(trusted))
return node_and_children_are_safe


# use singledispatch so that we can register specialized visualization functions
@singledispatch
def format_node(node: Node) -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also don't mind moving these to nodes.

"""Format the name of the node.

By default, this is just the fully qualified name of the class, e.g.
``"sklearn.preprocessing._data.MinMaxScaler"``. But for some types of nodes,
having a more specific output is desirable. These node types can be
registered with this function.

"""
return f"{node.module_name}.{node.class_name}"


@format_node.register
def _format_function_node(node: FunctionNode) -> str:
# if a FunctionNode, children are not visited, but safety should still be checked
child = node.children["content"]
fn_name = f"{child['module_path']}.{child['function']}"
return f"{node.module_name}.{node.class_name} => {fn_name}"


@format_node.register
def _format_json_node(node: JsonNode) -> str:
return f"json-type({node.content})"


def walk_tree(
node: Node | dict[str, Node] | Sequence[Node],
trusted: bool | Sequence[str] = False,
show: Literal["all", "untrusted", "trusted"] = "all",
node_name: str = "root",
level: int = 0,
config: PrintConfig = print_config,
) -> Iterator[FormattedNode]:
# helper function to pretty-print the nodes
if node_name == "key_types":
# _check_key_types_schema(node)
return

# COMPOSITE TYPES: CHECK ALL ITEMS
if isinstance(node, dict):
for key, val in node.items():
yield from walk_tree(
val,
node_name=key,
level=level,
trusted=trusted,
show=show,
config=config,
)
return

if isinstance(node, (list, tuple)):
for val in node:
yield from walk_tree(
val,
node_name=node_name,
level=level,
trusted=trusted,
show=show,
config=config,
)
return

# NO MATCH: RAISE ERROR
if not isinstance(node, Node):
raise TypeError(f"{type(node)}")

# THE ACTUAL FORMATTING HAPPENS HERE
node_is_safe = node.is_self_safe()
node_and_children_are_safe = _check_node_and_children_safe(node, trusted)
visible = _check_visibility(
node,
node_is_safe=node_is_safe,
node_and_children_are_safe=node_and_children_are_safe,
show=show,
)

node_val = format_node(node)
tag = config.tag_safe if node_is_safe else config.tag_unsafe
if tag:
node_val += f" {tag}"

if config.use_colors:
if node_and_children_are_safe:
color = config.color_safe
elif node_is_safe:
color = config.color_child_unsafe
else:
color = config.color_unsafe
node_val = f"{color}{node_val}{config.color_end}"

yield FormattedNode(level=level, key=node_name, val=node_val, visible=visible)

# TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT
if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"):
# _check_array_schema(node)
return

if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type == "json"):
# _check_array_json_schema(node)
return

if isinstance(node, FunctionNode):
# _check_function_schema(node)
return

if isinstance(node, JsonNode):
# _check_json_schema(node)
pass

# RECURSE
yield from walk_tree(
node.children,
node_name=node_name,
level=level + 1,
trusted=trusted,
show=show,
config=config,
)


def visualize_tree(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it matters that it's a tree. What about just visualize? Or inspect (although there's the inspect module)?

file: Path | str, # TODO: from bytes
trusted: bool | Sequence[str] = False,
show: Literal["all", "untrusted", "trusted"] = "all",
sink: Callable[[Iterator[FormattedNode], PrintConfig], None] = pretty_print_tree,
print_config: PrintConfig = print_config,
) -> None:
"""Visualize the contents of a skops file.

Shows the schema of a skops file as a tree view. In particular, highlights
untrusted nodes. A node is considered untrusted if at least one of its child
nodes is untrusted.

Parameters
----------
file: str or pathlib.Path
The file name of the object to be loaded.

trusted: bool, or list of str, default=False
If ``True``, the object will be loaded without any security checks. If
``False``, the object will be loaded only if there are only trusted
objects in the dumped file. If a list of strings, the object will be
loaded only if there are only trusted objects and objects of types
listed in ``trusted`` are in the dumped file.

show: "all" or "untrusted" or "trusted"
Whether to print all nodes, only untrusted nodes, or only trusted nodes.

sink: function
This function should take two arguments, an iterator of
``FormattedNode`` and a ``PrintConfig``. The ``FormattedNode`` contains
the information about the node, namely:

- the level of nesting (int)
- the key of the node (str)
- the value of the node as a string representation (str)
- the visibility of the node, depending on the ``show`` argument (bool)

The second argument is the print config (see description of next argument).

print_config: :class:`~PrintConfig`
The ``PrintConfig`` is a simple object with attributes that determine
how the node should be visualized, e.g. the ``use_colors`` attribute
determines if colors should be used.

"""
with ZipFile(file, "r") as zip_file:
schema = json.loads(zip_file.read("schema.json"))
tree = get_tree(schema, load_context=LoadContext(src=zip_file))

nodes = walk_tree(tree, trusted=trusted, show=show, config=print_config)
sink(nodes, print_config)
79 changes: 79 additions & 0 deletions skops/io/tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import pytest
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import (
FunctionTransformer,
MinMaxScaler,
PolynomialFeatures,
StandardScaler,
)

import skops.io as sio
from skops.io._visualize import visualize_tree


class TestVisualizeTree:
@pytest.fixture
def simple(self):
return MinMaxScaler(feature_range=(-555, 123))

@pytest.fixture
def simple_file(self, simple, tmp_path):
f_name = tmp_path / "estimator.skops"
sio.dump(simple, f_name)
return f_name

@pytest.fixture
def pipeline(self):
# fmt: off
pipeline = Pipeline([
("features", FeatureUnion([
("scaler", StandardScaler()),
("scaled-poly", Pipeline([
("polys", FeatureUnion([
("poly1", PolynomialFeatures()),
("poly2", PolynomialFeatures(degree=3, include_bias=False))
])),
("square-root", FunctionTransformer(np.sqrt)),
("scale", MinMaxScaler()),
])),
])),
("clf", LogisticRegression(random_state=0, solver="liblinear")),
]).fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2])
# fmt: on
return pipeline

@pytest.fixture
def pipeline_file(self, pipeline, tmp_path):
f_name = tmp_path / "estimator.skops"
sio.dump(pipeline, f_name)
return f_name

@pytest.mark.parametrize("show", ["all", "trusted", "untrusted"])
def test_print_simple(self, simple_file, show):
visualize_tree(simple_file, show=show)

@pytest.mark.parametrize(
"show_tell", [("all", 8), ("trusted", 8), ("untrusted", 0)]
)
def test_inspect_simple(self, simple_file, show_tell):
nodes = []
show, expected_elements = show_tell
visualize_tree(simple_file, sink=lambda n, _: nodes.extend(list(n)), show=show)
assert len([node for node in nodes if node.visible]) == expected_elements

@pytest.mark.parametrize("show", ["all", "trusted", "untrusted"])
def test_print_pipeline(self, pipeline_file, show):
visualize_tree(pipeline_file, show=show)

@pytest.mark.parametrize(
"show_tell", [("all", 129), ("trusted", 127), ("untrusted", 19)]
)
def test_inspect_pipeline(self, pipeline_file, show_tell):
nodes = []
show, expected_elements = show_tell
visualize_tree(
pipeline_file, sink=lambda n, _: nodes.extend(list(n)), show=show
)
assert len([node for node in nodes if node.visible]) == expected_elements