Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Dealing with skops persistence protocol updates #322

1 change: 1 addition & 0 deletions skops/hub_utils/_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def _create_config(
does not support it. For more info, see
https://intel.github.io/scikit-learn-intelex/.
"""

# so that we don't have to explicitly add keys and they're added as a
# dictionary if they are not found
# see: https://stackoverflow.com/a/13151294/2536294
Expand Down
34 changes: 24 additions & 10 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from contextlib import contextmanager
from typing import Any, Generator, Literal, Sequence, Type, Union

from ._protocol import PROTOCOL
from ._trusted_types import PRIMITIVE_TYPE_NAMES
from ._utils import LoadContext, get_module, get_type_paths
from .exceptions import UntrustedTypesFoundException

NODE_TYPE_MAPPING = {} # type: ignore
NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {}


def check_type(
Expand Down Expand Up @@ -311,7 +312,7 @@ def _construct(self):
return self.cached.construct()


NODE_TYPE_MAPPING["CachedNode"] = CachedNode
NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode # type: ignore


def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node:
Expand Down Expand Up @@ -347,14 +348,27 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node:
# node's ``construct`` method caches the instance.
return load_context.get_object(saved_id)

try:
node_cls = NODE_TYPE_MAPPING[state["__loader__"]]
except KeyError:
type_name = f"{state['__module__']}.{state['__class__']}"
raise TypeError(
f" Can't find loader {state['__loader__']} for type {type_name}."
)
loader: str = state["__loader__"]
protocol = load_context.protocol
key = (loader, protocol)

if key in NODE_TYPE_MAPPING:
node_cls = NODE_TYPE_MAPPING[key]
else:
# What probably happened here is that we released a new protocol. If
# there is no specific key for the old protocol, it means it is safe to
# use the current protocol instead, because this node was not changed.
key_new = (loader, PROTOCOL)
try:
node_cls = NODE_TYPE_MAPPING[key_new]
except KeyError:
# If we still cannot find the loader for this key, something went
# wrong.
type_name = f"{state['__module__']}.{state['__class__']}"
raise TypeError(
f" Can't find loader {state['__loader__']} for type {type_name} and "
f"protocol {protocol}."
)

loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore

return loaded_tree
53 changes: 22 additions & 31 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np

from ._audit import Node, get_tree
from ._protocol import PROTOCOL
from ._trusted_types import (
PRIMITIVE_TYPE_NAMES,
SCIPY_UFUNC_TYPE_NAMES,
Expand Down Expand Up @@ -180,13 +181,9 @@ def isnamedtuple(self, t) -> bool:

def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__class__": obj.__name__,
"__module__": get_module(obj),
"__loader__": "FunctionNode",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
},
}
return res

Expand All @@ -201,26 +198,20 @@ def __init__(
super().__init__(state, load_context, trusted)
# TODO: what do we trust?
self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES)
self.children = {"content": state["content"]}
self.children = {}

def _construct(self):
return _import_obj(
self.children["content"]["module_path"],
self.children["content"]["function"],
)
return gettype(self.module_name, self.class_name)

def _get_function_name(self) -> str:
return (
self.children["content"]["module_path"]
+ "."
+ self.children["content"]["function"]
)
return f"{self.module_name}.{self.class_name}"

def get_unsafe_set(self) -> set[str]:
if (self.trusted is True) or (self._get_function_name() in self.trusted):
fn_name = self._get_function_name()
if (self.trusted is True) or (fn_name in self.trusted):
return set()

return {self._get_function_name()}
return {fn_name}


def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
Expand Down Expand Up @@ -586,18 +577,18 @@ def _construct(self):
]

NODE_TYPE_MAPPING = {
"DictNode": DictNode,
"ListNode": ListNode,
"SetNode": SetNode,
"TupleNode": TupleNode,
"BytesNode": BytesNode,
"BytearrayNode": BytearrayNode,
"SliceNode": SliceNode,
"FunctionNode": FunctionNode,
"MethodNode": MethodNode,
"PartialNode": PartialNode,
"TypeNode": TypeNode,
"ObjectNode": ObjectNode,
"JsonNode": JsonNode,
"OperatorFuncNode": OperatorFuncNode,
("DictNode", PROTOCOL): DictNode,
("ListNode", PROTOCOL): ListNode,
("SetNode", PROTOCOL): SetNode,
("TupleNode", PROTOCOL): TupleNode,
("BytesNode", PROTOCOL): BytesNode,
("BytearrayNode", PROTOCOL): BytearrayNode,
("SliceNode", PROTOCOL): SliceNode,
("FunctionNode", PROTOCOL): FunctionNode,
("MethodNode", PROTOCOL): MethodNode,
("PartialNode", PROTOCOL): PartialNode,
("TypeNode", PROTOCOL): TypeNode,
("ObjectNode", PROTOCOL): ObjectNode,
("JsonNode", PROTOCOL): JsonNode,
("OperatorFuncNode", PROTOCOL): OperatorFuncNode,
}
30 changes: 8 additions & 22 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

from ._audit import Node, get_tree
from ._general import function_get_state
from ._protocol import PROTOCOL
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

Expand Down Expand Up @@ -195,22 +197,6 @@ def _construct(self):
return gettype(self.module_name, self.class_name)(bit_generator=bit_generator)


# For numpy.ufunc we need to get the type from the type's module, but for other
# functions we get it from the object's module directly. Therefore we set a
# special get_state method for them here. The load is the same as other functions.
def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__, # ufunc
"__module__": get_module(type(obj)), # numpy
"__loader__": "FunctionNode",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
},
}
return res


def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# we use numpy's internal save mechanism to store the dtype by
# saving/loading an empty array with that dtype.
Expand Down Expand Up @@ -247,16 +233,16 @@ def _construct(self):
(np.generic, ndarray_get_state),
(np.ndarray, ndarray_get_state),
(np.ma.MaskedArray, maskedarray_get_state),
(np.ufunc, ufunc_get_state),
(np.ufunc, function_get_state),
(np.dtype, dtype_get_state),
(np.random.RandomState, random_state_get_state),
(np.random.Generator, random_generator_get_state),
]
# mapping from loader key to the Node class that creates the instance of that type
NODE_TYPE_MAPPING = {
"NdArrayNode": NdArrayNode,
"MaskedArrayNode": MaskedArrayNode,
"DTypeNode": DTypeNode,
"RandomStateNode": RandomStateNode,
"RandomGeneratorNode": RandomGeneratorNode,
("NdArrayNode", PROTOCOL): NdArrayNode,
("MaskedArrayNode", PROTOCOL): MaskedArrayNode,
("DTypeNode", PROTOCOL): DTypeNode,
("RandomStateNode", PROTOCOL): RandomStateNode,
("RandomGeneratorNode", PROTOCOL): RandomGeneratorNode,
}
16 changes: 10 additions & 6 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from ._utils import LoadContext, SaveContext, _get_state, get_state

# We load the dispatch functions from the corresponding modules and register
# them.
# them. Old protocols are found in the 'old/' directory, with the protocol
# version appended to the corresponding module name.
modules = ["._general", "._numpy", "._scipy", "._sklearn"]
modules.extend([".old._general_v0"])
for module_name in modules:
# register exposed functions for get_state and get_tree
module = importlib.import_module(module_name, package="skops.io")
Expand Down Expand Up @@ -123,9 +125,9 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any:

"""
with ZipFile(file, "r") as input_zip:
schema = input_zip.read("schema.json")
load_context = LoadContext(src=input_zip)
tree = get_tree(json.loads(schema), load_context)
schema = json.loads(input_zip.read("schema.json"))
load_context = LoadContext(src=input_zip, protocol=schema["protocol"])
tree = get_tree(schema, load_context)
audit_tree(tree, trusted)
instance = tree.construct()

Expand Down Expand Up @@ -164,7 +166,7 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any:

with ZipFile(io.BytesIO(data), "r") as zip_file:
schema = json.loads(zip_file.read("schema.json"))
load_context = LoadContext(src=zip_file)
load_context = LoadContext(src=zip_file, protocol=schema["protocol"])
tree = get_tree(schema, load_context)
audit_tree(tree, trusted)
instance = tree.construct()
Expand Down Expand Up @@ -208,7 +210,9 @@ def get_untrusted_types(

with ZipFile(content, "r") as zip_file:
schema = json.loads(zip_file.read("schema.json"))
tree = get_tree(schema, load_context=LoadContext(src=zip_file))
tree = get_tree(
schema, load_context=LoadContext(src=zip_file, protocol=schema["protocol"])
)
untrusted_types = tree.get_unsafe_set()

return sorted(untrusted_types)
26 changes: 26 additions & 0 deletions skops/io/_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""The current protocol of the skops version

Notes on updating the protocol:

Every time that a backwards incompatible change to the skops format is made
for the first time within a release, the protocol should be bumped to the next
higher number. The old version of the Node, which knows how to deal with the
old state, should be preserved, registered, and tested. Let's give an example:

- There is a BC breaking change in FunctionNode.
- Since it's the first BC breaking change in the skops format in this release,
bump skops.io._protocol.PROTOCOL (this file) from version X to X+1.
- Move the old FunctionNode code into 'skops/io/old/_general_vX.py', where 'X'
is the old protocol.
- Register the _general_vX.FunctionNode in NODE_TYPE_MAPPING inside of
_persist.py.
- Write a test in test_persist_old.py that shows that the old state can
still be loaded. Look at test_persist_old.test_function_v0 for inspiration.

Now, if a user loads a FunctionNode state with version X using skops with
version Y>X, the old code will be used instead of the new one. For all other
node types, if there is no loader for version X, skops will automatically use
version Y instead.

"""
PROTOCOL = 1
3 changes: 2 additions & 1 deletion skops/io/_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from scipy.sparse import load_npz, save_npz, spmatrix

from ._audit import Node
from ._protocol import PROTOCOL
from ._utils import LoadContext, SaveContext, get_module


Expand Down Expand Up @@ -65,5 +66,5 @@ def _construct(self):
NODE_TYPE_MAPPING = {
# use 'spmatrix' to check if a matrix is a sparse matrix because that is
# what scipy.sparse.issparse checks
"SparseMatrixNode": SparseMatrixNode,
("SparseMatrixNode", PROTOCOL): SparseMatrixNode,
}
8 changes: 5 additions & 3 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from sklearn.cluster import Birch

from ._protocol import PROTOCOL

try:
# TODO: remove once support for sklearn<1.2 is dropped. See #187
from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys
Expand Down Expand Up @@ -232,8 +234,8 @@ def _construct(self):

# mapping from loader key to the Node class that creates the instance of that type
NODE_TYPE_MAPPING = {
"SGDNode": SGDNode,
"TreeNode": TreeNode,
("SGDNode", PROTOCOL): SGDNode,
("TreeNode", PROTOCOL): TreeNode,
}

# TODO: remove once support for sklearn<1.2 is dropped.
Expand All @@ -244,5 +246,5 @@ def _construct(self):
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state)
)
NODE_TYPE_MAPPING[
"_DictWithDeprecatedKeysNode"
("_DictWithDeprecatedKeysNode", PROTOCOL)
] = _DictWithDeprecatedKeysNode # type: ignore
9 changes: 4 additions & 5 deletions skops/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Any, Type
from zipfile import ZipFile

from ._protocol import PROTOCOL


# The following two functions are copied from cpython's pickle.py file.
# ---------------------------------------------------------------------
Expand Down Expand Up @@ -83,10 +85,6 @@ def get_module(obj: Any) -> str:
return whichmodule(obj, obj.__name__)


# For now, there is just one protocol version
DEFAULT_PROTOCOL = 0


@dataclass(frozen=True)
class SaveContext:
"""Context required for saving the objects
Expand All @@ -105,7 +103,7 @@ class SaveContext:
"""

zip_file: ZipFile
protocol: int = DEFAULT_PROTOCOL
protocol: int = PROTOCOL
memo: dict[int, Any] = field(default_factory=dict)

def memoize(self, obj: Any) -> int:
Expand Down Expand Up @@ -135,6 +133,7 @@ class LoadContext:
"""

src: ZipFile
protocol: int
memo: dict[int, Any] = field(default_factory=dict)

def memoize(self, obj: Any, id: int) -> None:
Expand Down
Loading