diff --git a/.gitignore b/.gitignore
index 8992bea1..7dc6c890 100644
--- a/.gitignore
+++ b/.gitignore
@@ -117,6 +117,8 @@ node_modules
 
 # Vim
 *.swp
+# MacOS
+.DS_Store
 
 exports
 trash
diff --git a/docs/changes.rst b/docs/changes.rst
index ad0e7afd..aebd161b 100644
--- a/docs/changes.rst
+++ b/docs/changes.rst
@@ -13,6 +13,8 @@ v0.8
 ----
 - Adds the abillity to set the :attr:`.Section.folded` property when using
   :meth:`.Card.add`. :pr:`361` by :user:`Thomas Lazarus `.
+- Add the CLI command to update Skops files to the latest Skops persistence format.
+  (:func:`.cli._update.main`). :pr:`333` by :user:`Edoardo Abati `.
 
 v0.7
 ----
diff --git a/docs/persistence.rst b/docs/persistence.rst
index b2351e0f..bf74e38a 100644
--- a/docs/persistence.rst
+++ b/docs/persistence.rst
@@ -130,8 +130,13 @@ for more details.
 Command Line Interface
 ######################
 
-Skops has a command line interface to convert scikit-learn models persisted with
-``Pickle`` to ``Skops`` files.
+Skops has a command line interface to:
+
+- convert scikit-learn models persisted with ``Pickle`` to ``Skops`` files.
+- update ``Skops`` files to the latest version.
+
+``skops convert``
+~~~~~~~~~~~~~~~~~
 
 To convert a file from the command line, use the ``skops convert`` entrypoint.
 
@@ -151,6 +156,22 @@ For example, to convert all ``.pkl`` flies in the current directory:
 Further help for the different supported options can be found by calling
 ``skops convert --help`` in a terminal.
 
+``skops update``
+~~~~~~~~~~~~~~~~
+
+To update a ``Skops`` file from the command line, use the ``skops update`` command.
+Skops will check the protocol version of the file to determine if it needs to be updated to the current version.
+
+The below command is an example of how to create an updated version of a file
+``my_model.skops`` and save it as ``my_model-updated.skops``:
+
+.. code-block:: console
+
+    skops update my_model.skops -o my_model-updated.skops
+
+Further help for the different supported options can be found by calling
+``skops update --help`` in a terminal.
+
 Visualization
 #############
 
diff --git a/skops/cli/_update.py b/skops/cli/_update.py
new file mode 100644
index 00000000..c648e37b
--- /dev/null
+++ b/skops/cli/_update.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import shutil
+import tempfile
+import zipfile
+from pathlib import Path
+
+from skops.cli._utils import get_log_level
+from skops.io import dump, load
+from skops.io._protocol import PROTOCOL
+
+
+def _update_file(
+    input_file: str | Path,
+    output_file: str | Path | None = None,
+    inplace: bool = False,
+    logger: logging.Logger = logging.getLogger(),
+) -> None:
+    """Function that is called by ``skops update`` entrypoint.
+
+    Loads a skops model from the input path, updates it to the current skops format, and
+    saves to an output file. It will overwrite the input file if `inplace` is True.
+
+    Parameters
+    ----------
+    input_file : str, or Path
+        Path of input skops model to load.
+
+    output_file : str, or Path, default=None
+        Path to save the updated skops model to.
+
+    inplace : bool, default=False
+        Whether to update and overwrite the input file in place.
+
+    logger : logging.Logger, default=logging.getLogger()
+        Logger to use for logging.
+
+    Raises
+    ------
+    ValueError
+        If both ``output_file`` and ``inplace`` are specified.
+    """
+    if inplace:
+        if output_file is None:
+            output_file = input_file
+        else:
+            raise ValueError(
+                "Cannot specify both an output file path and the inplace flag. Please"
+                " choose whether you want to create a new file or overwrite the input"
+                " file."
+            )
+
+    with zipfile.ZipFile(input_file, "r") as zip_file:
+        input_file_schema = json.loads(zip_file.read("schema.json"))
+
+    if input_file_schema["protocol"] == PROTOCOL:
+        logger.warning(
+            "File was not updated because already up to date with the current protocol:"
+            f" {PROTOCOL}"
+        )
+        return None
+
+    if input_file_schema["protocol"] > PROTOCOL:
+        logger.warning(
+            "File cannot be updated because its protocol is more recent than the "
+            f"current protocol: {PROTOCOL}"
+        )
+        return None
+
+    if output_file is None:
+        logger.warning(
+            f"File can be updated to the current protocol: {PROTOCOL}. Please"
+            " specify an output file path or use the `inplace` flag to create the"
+            " updated Skops file."
+        )
+        return None
+
+    # Deserialize only once it is certain that an updated file will be written.
+    # NOTE(review): ``trusted=True`` loads without skops' type checks -- confirm that
+    # updating files from untrusted sources is out of scope.
+    input_model = load(input_file, trusted=True)
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # Keep only the file name: joining ``tmp_dir`` with an absolute path would
+        # discard ``tmp_dir``, and a nested relative path would need missing folders.
+        tmp_output_file = Path(tmp_dir) / f"{Path(output_file).name}.tmp"
+        dump(input_model, tmp_output_file)
+        shutil.move(str(tmp_output_file), str(output_file))
+    logger.info(f"Updated skops file written to {output_file}")
+
+
+def format_parser(
+    parser: argparse.ArgumentParser | None = None,
+) -> argparse.ArgumentParser:
+    """Adds arguments and help to parent CLI parser for the `update` method."""
+
+    if not parser:  # used in tests
+        parser = argparse.ArgumentParser()
+
+    parser_subgroup = parser.add_argument_group("update")
+    parser_subgroup.add_argument("input", help="Path to an input file to update.")
+
+    parser_subgroup.add_argument(
+        "-o",
+        "--output-file",
+        help="Specify the output file name for the updated skops file.",
+        default=None,
+    )
+    parser_subgroup.add_argument(
+        "--inplace",
+        help="Update and overwrite the input file in place.",
+        action="store_true",
+    )
+    parser_subgroup.add_argument(
+        "-v",
+        "--verbose",
+        help=(
+            "Increases verbosity of logging. Can be used multiple times to increase "
+            "verbosity further."
+        ),
+        action="count",
+        dest="loglevel",
+        default=0,
+    )
+    return parser
+
+
+def main(
+    parsed_args: argparse.Namespace,
+    logger: logging.Logger = logging.getLogger(),
+) -> None:
+    """Configure logging and run the update described by ``parsed_args``."""
+    output_file = Path(parsed_args.output_file) if parsed_args.output_file else None
+    input_file = Path(parsed_args.input)
+    inplace = parsed_args.inplace
+
+    logging.basicConfig(format="%(levelname)-8s: %(message)s")
+    logger.setLevel(level=get_log_level(parsed_args.loglevel))
+
+    _update_file(
+        input_file=input_file,
+        output_file=output_file,
+        inplace=inplace,
+        logger=logger,
+    )
diff --git a/skops/cli/entrypoint.py b/skops/cli/entrypoint.py
index c98ef8d3..e23e839c 100644
--- a/skops/cli/entrypoint.py
+++ b/skops/cli/entrypoint.py
@@ -1,6 +1,7 @@
 import argparse
 
 import skops.cli._convert
+import skops.cli._update
 
 
 def main_cli(command_line_args=None):
@@ -32,6 +33,10 @@ def main_cli(command_line_args=None):
             "method": skops.cli._convert.main,
             "format_parser": skops.cli._convert.format_parser,
         },
+        "update": {
+            "method": skops.cli._update.main,
+            "format_parser": skops.cli._update.format_parser,
+        },
     }
 
     for func_name, values in function_map.items():
diff --git a/skops/cli/tests/test_entrypoint.py b/skops/cli/tests/test_entrypoint.py
index 2fd3cb60..d463ffab 100644
--- a/skops/cli/tests/test_entrypoint.py
+++ b/skops/cli/tests/test_entrypoint.py
@@ -39,3 +39,23 @@ def test_convert_works_as_expected(
         )
 
         assert caplog.at_level(logging.WARNING)
+
+    @mock.patch("skops.cli._update._update_file")
+    def test_update_works_as_expected(
+        self,
+        update_file_mock: mock.MagicMock,
+    ):
+        """
+        To make sure the parser is configured correctly, when 'update'
+        is the first argument.
+        """
+
+        args = ["update", "abc.skops", "-o", "abc-new.skops"]
+
+        main_cli(args)
+        update_file_mock.assert_called_once_with(
+            input_file=pathlib.Path("abc.skops"),
+            output_file=pathlib.Path("abc-new.skops"),
+            inplace=False,
+            logger=mock.ANY,
+        )
diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py
new file mode 100644
index 00000000..0169294e
--- /dev/null
+++ b/skops/cli/tests/test_update.py
@@ -0,0 +1,248 @@
+import logging
+import pathlib
+from functools import partial
+from unittest import mock
+
+import numpy as np
+import pytest
+
+from skops.cli import _update
+from skops.io import _persist, _protocol, dump, load
+
+
+class TestUpdate:
+    model_name = "some_model_name"
+
+    @pytest.fixture
+    def safe_obj(self) -> np.ndarray:
+        return np.array([1, 2, 3, 4])
+
+    @pytest.fixture
+    def skops_path(self, tmp_path: pathlib.Path) -> pathlib.Path:
+        return tmp_path / f"{self.model_name}.skops"
+
+    @pytest.fixture
+    def new_skops_path(self, tmp_path: pathlib.Path) -> pathlib.Path:
+        return tmp_path / f"{self.model_name}-new.skops"
+
+    @pytest.fixture
+    def dump_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray):
+        """Dump an object using the current protocol version."""
+        dump(safe_obj, skops_path)
+
+    @pytest.fixture
+    @mock.patch(
+        "skops.io._persist.SaveContext", partial(_persist.SaveContext, protocol=0)
+    )
+    def dump_old_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray):
+        """Dump an object using an old protocol version so that the file needs
+        updating.
+        """
+        dump(safe_obj, skops_path)
+
+    @pytest.fixture
+    @mock.patch(
+        "skops.io._persist.SaveContext",
+        partial(_persist.SaveContext, protocol=_protocol.PROTOCOL + 1),
+    )
+    def dump_new_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray):
+        """Dump an object using a new protocol version so that the file cannot be
+        updated.
+        """
+        dump(safe_obj, skops_path)
+
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_base_case_works_as_expected(
+        self,
+        skops_path: pathlib.Path,
+        new_skops_path: pathlib.Path,
+        inplace: bool,
+        dump_old_file,
+        safe_obj,
+    ):
+        mock_logger = mock.MagicMock()
+        expected_output_file = new_skops_path if not inplace else skops_path
+
+        _update._update_file(
+            input_file=skops_path,
+            output_file=new_skops_path if not inplace else None,
+            inplace=inplace,
+            logger=mock_logger,
+        )
+        updated_obj = load(expected_output_file)
+
+        assert np.array_equal(updated_obj, safe_obj)
+
+        # Check logging messages
+        mock_logger.info.assert_called_once_with(
+            f"Updated skops file written to {expected_output_file}"
+        )
+        mock_logger.warning.assert_not_called()
+        mock_logger.error.assert_not_called()
+        mock_logger.debug.assert_not_called()
+
+    @mock.patch("skops.cli._update.dump")
+    def test_only_diff(
+        self, mock_dump: mock.MagicMock, skops_path: pathlib.Path, dump_old_file
+    ):
+        mock_logger = mock.MagicMock()
+        _update._update_file(
+            input_file=skops_path,
+            output_file=None,
+            inplace=False,
+            logger=mock_logger,
+        )
+        mock_dump.assert_not_called()
+
+        # Check logging messages
+        mock_logger.warning.assert_called_once_with(
+            f"File can be updated to the current protocol: {_protocol.PROTOCOL}. Please"
+            " specify an output file path or use the `inplace` flag to create the"
+            " updated Skops file."
+        )
+        mock_logger.info.assert_not_called()
+        mock_logger.error.assert_not_called()
+        mock_logger.debug.assert_not_called()
+
+    def test_no_update_same_protocol(
+        self,
+        skops_path: pathlib.Path,
+        new_skops_path: pathlib.Path,
+        dump_file,
+    ):
+        mock_logger = mock.MagicMock()
+        _update._update_file(
+            input_file=skops_path,
+            output_file=new_skops_path,
+            inplace=False,
+            logger=mock_logger,
+        )
+        mock_logger.warning.assert_called_once_with(
+            "File was not updated because already up to date with the current protocol:"
+            f" {_protocol.PROTOCOL}"
+        )
+        assert not new_skops_path.exists()
+
+    def test_no_update_newer_protocol(
+        self,
+        skops_path: pathlib.Path,
+        new_skops_path: pathlib.Path,
+        dump_new_file,
+    ):
+        mock_logger = mock.MagicMock()
+        _update._update_file(
+            input_file=skops_path,
+            output_file=new_skops_path,
+            inplace=False,
+            logger=mock_logger,
+        )
+        mock_logger.warning.assert_called_once_with(
+            "File cannot be updated because its protocol is more recent than the "
+            f"current protocol: {_protocol.PROTOCOL}"
+        )
+        assert not new_skops_path.exists()
+
+    def test_error_with_output_file_and_inplace(
+        self, skops_path: pathlib.Path, new_skops_path: pathlib.Path, dump_file
+    ):
+        with pytest.raises(ValueError):
+            _update._update_file(
+                input_file=skops_path, output_file=new_skops_path, inplace=True
+            )
+
+    def test_relative_output_file_in_subdirectory(
+        self,
+        skops_path: pathlib.Path,
+        tmp_path: pathlib.Path,
+        monkeypatch,
+        dump_old_file,
+        safe_obj,
+    ):
+        """The temporary file must not inherit the directories of ``output_file``."""
+        (tmp_path / "nested").mkdir()
+        monkeypatch.chdir(tmp_path)
+        output_file = pathlib.Path("nested") / "updated.skops"
+
+        _update._update_file(
+            input_file=skops_path,
+            output_file=output_file,
+            logger=mock.MagicMock(),
+        )
+
+        assert np.array_equal(load(output_file), safe_obj)
+
+
+class TestMain:
+    @pytest.fixture
+    def tmp_logger(self) -> logging.Logger:
+        return logging.getLogger()
+
+    @pytest.mark.parametrize("output_flag", ["-o", "--output-file"])
+    @mock.patch("skops.cli._update._update_file")
+    def test_output_argument(
+        self, mock_update: mock.MagicMock, output_flag: str, tmp_logger: logging.Logger
+    ):
+        input_path = "abc.skops"
+        output_path = "abc-new.skops"
+        namespace, _ = _update.format_parser().parse_known_args(
+            [input_path, output_flag, output_path]
+        )
+
+        _update.main(namespace, tmp_logger)
+        mock_update.assert_called_once_with(
+            input_file=pathlib.Path(input_path),
+            output_file=pathlib.Path(output_path),
+            inplace=False,
+            logger=tmp_logger,
+        )
+
+    @mock.patch("skops.cli._update._update_file")
+    def test_inplace_argument(
+        self, mock_update: mock.MagicMock, tmp_logger: logging.Logger
+    ):
+        input_path = "abc.skops"
+        namespace, _ = _update.format_parser().parse_known_args(
+            [input_path, "--inplace"]
+        )
+
+        _update.main(namespace, tmp_logger)
+        mock_update.assert_called_once_with(
+            input_file=pathlib.Path(input_path),
+            output_file=None,
+            inplace=True,
+            logger=tmp_logger,
+        )
+
+    @mock.patch("skops.cli._update._update_file")
+    @pytest.mark.parametrize(
+        "verbosity, expected_level",
+        [
+            ("", logging.WARNING),
+            ("-v", logging.INFO),
+            ("--verbose", logging.INFO),
+            ("-vv", logging.DEBUG),
+            ("-v -v", logging.DEBUG),
+            ("-vvvvv", logging.DEBUG),
+            ("--verbose --verbose", logging.DEBUG),
+        ],
+    )
+    def test_given_log_levels_works_as_expected(
+        self,
+        mock_update: mock.MagicMock,
+        verbosity: str,
+        expected_level: int,
+        tmp_logger: logging.Logger,
+    ):
+        input_path = "abc.skops"
+        output_path = "def.skops"
+        args = [input_path, "--output-file", output_path, *verbosity.split()]
+
+        namespace, _ = _update.format_parser().parse_known_args(args)
+        _update.main(namespace, tmp_logger)
+        mock_update.assert_called_once_with(
+            input_file=pathlib.Path(input_path),
+            output_file=pathlib.Path(output_path),
+            inplace=False,
+            logger=tmp_logger,
+        )
+        assert tmp_logger.getEffectiveLevel() == expected_level