From c0238decf2dacab363d13543acf21f363c30dff8 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 10 Apr 2023 12:33:59 +0200 Subject: [PATCH 01/26] added update entrypoint --- skops/cli/_update.py | 118 ++++++++++++++++++++++++++++++++++++++++ skops/cli/entrypoint.py | 5 ++ 2 files changed, 123 insertions(+) create mode 100644 skops/cli/_update.py diff --git a/skops/cli/_update.py b/skops/cli/_update.py new file mode 100644 index 00000000..cbd8a7d3 --- /dev/null +++ b/skops/cli/_update.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import argparse +import logging +import os +import pathlib + +# import pickle +from typing import Optional + +from skops.cli._utils import get_log_level + +# from skops.io import dumps, get_untrusted_types + + +def _update_file( + input_file: os.PathLike, + output_file: os.PathLike, + logger: logging.Logger = logging.getLogger(), +) -> None: + """Function that is called by ``skops update`` entrypoint. + + Loads a skops model from the input path, converts to the latest skops format, and saves to + output file. + + Parameters + ---------- + input_file : os.PathLike + Path of input .pkl model to load. + + output_file : os.PathLike + Path to save .skops model to. + + """ + # model_name = pathlib.Path(input_file).stem + + # logger.debug(f"Converting {model_name}") + + # with open(input_file, "rb") as f: + # obj = pickle.load(f) + # skops_dump = dumps(obj) + + # untrusted_types = get_untrusted_types(data=skops_dump) + + # if not untrusted_types: + # logger.info(f"No unknown types found in {model_name}.") + # else: + # untrusted_str = ", ".join(untrusted_types) + + # logger.warning( + # f"While converting {input_file}, " + # "the following unknown types were found: " + # f"{untrusted_str}. " + # f"When loading {output_file} with skops.load, these types must be " + # "specified as 'trusted'" + # ) + + # with open(output_file, "wb") as out_file: + # logger.debug(f"Writing to {output_file}") + # out_file.write(skops_dump) + raise NotImplementedError + + +def format_parser( + parser: Optional[argparse.ArgumentParser] = None, +) -> argparse.ArgumentParser: + """Adds arguments and help to parent CLI parser for the `update` method.""" + + if not parser: # used in tests + parser = argparse.ArgumentParser() + + parser_subgroup = parser.add_argument_group("update") + parser_subgroup.add_argument("input", help="Path to an input file to update. ") + + parser_subgroup.add_argument( + "-o", + "--output-file", + help=( + # TODO: decide what to do with this. Default name? or compulsory? + "Specify the output file name for the updated skops file. " + "If not provided, will default to using the same name as the input file, " + "and saving to the current working directory with the suffix '.skops'." + ), + default=None, + ) + parser_subgroup.add_argument( + "-v", + "--verbose", + help=( + "Increases verbosity of logging. Can be used multiple times to increase " + "verbosity further." + ), + action="count", + dest="loglevel", + default=0, + ) + return parser + + +def main( + parsed_args: argparse.Namespace, +) -> None: + output_file = parsed_args.output_file + input_file = parsed_args.input + + logging.basicConfig( + format="%(levelname)-8s: %(message)s", level=get_log_level(parsed_args.loglevel) + ) + + if not output_file: + # No filename provided, defaulting to base file path + file_name = pathlib.Path(input_file).stem + output_file = pathlib.Path.cwd() / f"{file_name}.skops" + + _update_file( + input_file=input_file, + output_file=output_file, + ) diff --git a/skops/cli/entrypoint.py b/skops/cli/entrypoint.py index c98ef8d3..e23e839c 100644 --- a/skops/cli/entrypoint.py +++ b/skops/cli/entrypoint.py @@ -1,6 +1,7 @@ import argparse import skops.cli._convert +import skops.cli._update def main_cli(command_line_args=None): @@ -32,6 +33,10 @@ def main_cli(command_line_args=None): "method": skops.cli._convert.main, "format_parser": skops.cli._convert.format_parser, }, + "update": { + "method": skops.cli._update.main, + "format_parser": skops.cli._update.format_parser, + }, } for func_name, values in function_map.items(): From c23b4d0d823e8aee3b46a478e36cae5bcec13f48 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 10 Apr 2023 17:16:24 +0200 Subject: [PATCH 02/26] minimal working cli without optionals --- skops/cli/_update.py | 74 ++++++++++---------------------------------- 1 file changed, 17 insertions(+), 57 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index cbd8a7d3..b943c91e 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -2,63 +2,35 @@ import argparse import logging -import os -import pathlib - -# import pickle -from typing import Optional +from pathlib import Path +from typing import Optional, Union from skops.cli._utils import get_log_level - -# from skops.io import dumps, get_untrusted_types +from skops.io import dump, load def _update_file( - input_file: os.PathLike, - output_file: os.PathLike, + input_file: Union[str, Path], + output_file: Union[str, Path], logger: logging.Logger = logging.getLogger(), ) -> None: """Function that is called by ``skops update`` entrypoint. - Loads a skops model from the input path, converts to the latest skops format, and saves to - output file. + Loads a skops model from the input path, converts to the current skops format, and + saves to output file. Parameters ---------- - input_file : os.PathLike - Path of input .pkl model to load. + input_file : Union[str, Path] + Path of input skops model to load. - output_file : os.PathLike - Path to save .skops model to. + output_file : Union[str, Path] + Path to save the updated skops model to. """ - # model_name = pathlib.Path(input_file).stem - - # logger.debug(f"Converting {model_name}") - - # with open(input_file, "rb") as f: - # obj = pickle.load(f) - # skops_dump = dumps(obj) - - # untrusted_types = get_untrusted_types(data=skops_dump) - - # if not untrusted_types: - # logger.info(f"No unknown types found in {model_name}.") - # else: - # untrusted_str = ", ".join(untrusted_types) - - # logger.warning( - # f"While converting {input_file}, " - # "the following unknown types were found: " - # f"{untrusted_str}. " - # f"When loading {output_file} with skops.load, these types must be " - # "specified as 'trusted'" - # ) - - # with open(output_file, "wb") as out_file: - # logger.debug(f"Writing to {output_file}") - # out_file.write(skops_dump) - raise NotImplementedError + input_model = load(input_file, trusted=True) + dump(input_model, output_file) + logger.debug(f"Updated skops file written in {output_file}") def format_parser( @@ -75,13 +47,7 @@ def format_parser( parser_subgroup.add_argument( "-o", "--output-file", - help=( - # TODO: decide what to do with this. Default name? or compulsory? - "Specify the output file name for the updated skops file. " - "If not provided, will default to using the same name as the input file, " - "and saving to the current working directory with the suffix '.skops'." - ), - default=None, + help="Specify the output file name for the updated skops file.", ) parser_subgroup.add_argument( "-v", @@ -100,18 +66,12 @@ def format_parser( def main( parsed_args: argparse.Namespace, ) -> None: - output_file = parsed_args.output_file - input_file = parsed_args.input + output_file = Path(parsed_args.output_file) + input_file = Path(parsed_args.input) logging.basicConfig( format="%(levelname)-8s: %(message)s", level=get_log_level(parsed_args.loglevel) ) - - if not output_file: - # No filename provided, defaulting to base file path - file_name = pathlib.Path(input_file).stem - output_file = pathlib.Path.cwd() / f"{file_name}.skops" - _update_file( input_file=input_file, output_file=output_file, From f0923c136f9173c3be9d1848c95107db8982d356 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 10 Apr 2023 17:20:49 +0200 Subject: [PATCH 03/26] added macos file to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8992bea1..7dc6c890 100644 --- a/.gitignore +++ b/.gitignore @@ -117,6 +117,8 @@ node_modules # Vim *.swp +# MacOS +.DS_Store exports trash From 04a1738c7321945d25d0477480a0178b01206cb7 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sat, 15 Apr 2023 18:16:13 +0200 Subject: [PATCH 04/26] added `update` cli documentation --- docs/persistence.rst | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/docs/persistence.rst b/docs/persistence.rst index 2ed2819e..c8c592ac 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -113,8 +113,13 @@ objects having references to functions such as ``numpy.sqrt``. Command Line Interface ###################### -Skops has a command line interface to convert scikit-learn models persisted with -``Pickle`` to ``Skops`` files. +Skops has a command line interface to: + +- convert scikit-learn models persisted with ``Pickle`` to ``Skops`` files. +- update ``Skops`` files to the latest version. + +``skops convert`` +----------------- To convert a file from the command line, use the ``skops convert`` entrypoint. @@ -134,6 +139,21 @@ For example, to convert all ``.pkl`` flies in the current directory: Further help for the different supported options can be found by calling ``skops convert --help`` in a terminal. +``skops update`` +---------------- + +To update a ``Skops`` file from the command line, use the ``skops update`` command. + +The below command is an example on how to create an updated version of a file +``my_model.skops`` and save it as ``my_model-updated.skops``: + +.. code-block:: console + + skops update my_model.skops -o my_model-updated.skops + +Further help for the different supported options can be found by calling +``skops update --help`` in a terminal. + Visualization ############# From 157170d7ff69bf75b77aa8f1f71c6d100a64ed8a Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 18 Apr 2023 20:07:23 +0200 Subject: [PATCH 05/26] added unittests for update entrypoint --- skops/cli/_update.py | 8 ++- skops/cli/tests/test_entrypoint.py | 20 ++++++ skops/cli/tests/test_update.py | 108 +++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 skops/cli/tests/test_update.py diff --git a/skops/cli/_update.py b/skops/cli/_update.py index b943c91e..5c1b24a1 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -65,14 +65,16 @@ def format_parser( def main( parsed_args: argparse.Namespace, + logger: logging.Logger = logging.getLogger(), ) -> None: output_file = Path(parsed_args.output_file) input_file = Path(parsed_args.input) - logging.basicConfig( - format="%(levelname)-8s: %(message)s", level=get_log_level(parsed_args.loglevel) - ) + logging.basicConfig(format="%(levelname)-8s: %(message)s") + logger.setLevel(level=get_log_level(parsed_args.loglevel)) + _update_file( input_file=input_file, output_file=output_file, + logger=logger, ) diff --git a/skops/cli/tests/test_entrypoint.py b/skops/cli/tests/test_entrypoint.py index 2fd3cb60..074f52a7 100644 --- a/skops/cli/tests/test_entrypoint.py +++ b/skops/cli/tests/test_entrypoint.py @@ -39,3 +39,23 @@ def test_convert_works_as_expected( ) assert caplog.at_level(logging.WARNING) + + @mock.patch("skops.cli._update._update_file") + def test_update_works_as_expected( + self, + update_file_mock: mock.MagicMock, + ): + """ + Intended as a unit test to make sure, + given 'update' as the first argument, + the parser is configured correctly + """ + + args = ["update", "abc.skops", "-o", "abc-new.skops"] + + main_cli(args) + update_file_mock.assert_called_once_with( + input_file=pathlib.Path("abc.skops"), + output_file=pathlib.Path("abc-new.skops"), + logger=mock.ANY, + ) diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py new file mode 100644 index 00000000..b9bb0985 --- /dev/null +++ b/skops/cli/tests/test_update.py @@ -0,0 +1,108 @@ +import logging +import pathlib +from unittest import mock + +import numpy as np +import pytest + +from skops.cli import _update +from skops.io import dump, load + + +class MockUnsafeType: + def __init__(self): + pass + + +class TestUpdate: + model_name = "some_model_name" + + @pytest.fixture + def safe_obj(self) -> np.ndarray: + return np.ndarray([1, 2, 3, 4]) + + @pytest.fixture + def skops_path(self, tmp_path: pathlib.Path) -> pathlib.Path: + return tmp_path / f"{self.model_name}.skops" + + @pytest.fixture + def new_skops_path(self, tmp_path: pathlib.Path) -> pathlib.Path: + return tmp_path / f"{self.model_name}-new.skops" + + @pytest.fixture + def dump_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray): + dump(safe_obj, skops_path) + + def test_base_case_works_as_expected( + self, + skops_path: pathlib.Path, + new_skops_path: pathlib.Path, + dump_file, + safe_obj, + ): + mock_logger = mock.MagicMock() + _update._update_file( + input_file=skops_path, output_file=new_skops_path, logger=mock_logger + ) + updated_obj = load(new_skops_path) + assert np.array_equal(updated_obj, safe_obj) + + # Check no warnings or errors raised + mock_logger.warning.assert_not_called() + mock_logger.error.assert_not_called() + + +class TestMain: + @pytest.fixture + def tmp_logger(self) -> logging.Logger: + return logging.getLogger() + + @mock.patch("skops.cli._update._update_file") + def test_base_works_as_expected( + self, mock_update: mock.MagicMock, tmp_logger: logging.Logger + ): + input_path = "abc.skops" + output_path = "abc-new.skops" + namespace, _ = _update.format_parser().parse_known_args( + [input_path, "-o", output_path] + ) + + _update.main(namespace, tmp_logger) + mock_update.assert_called_once_with( + input_file=pathlib.Path(input_path), + output_file=pathlib.Path(output_path), + logger=tmp_logger, + ) + + @mock.patch("skops.cli._update._update_file") + @pytest.mark.parametrize( + "verbosity, expected_level", + [ + ("", logging.WARNING), + ("-v", logging.INFO), + ("--verbose", logging.INFO), + ("-vv", logging.DEBUG), + ("-v -v", logging.DEBUG), + ("-vvvvv", logging.DEBUG), + ("--verbose --verbose", logging.DEBUG), + ], + ) + def test_given_log_levels_works_as_expected( + self, + mock_update: mock.MagicMock, + verbosity: str, + expected_level: int, + tmp_logger: logging.Logger, + ): + input_path = "abc.skops" + output_path = "def.skops" + args = [input_path, "--output", output_path, *verbosity.split()] + + namespace, _ = _update.format_parser().parse_known_args(args) + _update.main(namespace, tmp_logger) + mock_update.assert_called_once_with( + input_file=pathlib.Path(input_path), + output_file=pathlib.Path(output_path), + logger=tmp_logger, + ) + assert tmp_logger.getEffectiveLevel() == expected_level From 121c7ca37cc953078b11fbd67e691e978fa99b63 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 23 Apr 2023 14:36:55 +0200 Subject: [PATCH 06/26] adjusted logging message Co-authored-by: Benjamin Bossan --- skops/cli/_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 5c1b24a1..0018e547 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -30,7 +30,7 @@ def _update_file( """ input_model = load(input_file, trusted=True) dump(input_model, output_file) - logger.debug(f"Updated skops file written in {output_file}") + logger.debug(f"Updated skops file written to {output_file}") def format_parser( From 683281ddfff3204acf91af2fe5b2747e690d86c3 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 23 Apr 2023 14:37:22 +0200 Subject: [PATCH 07/26] removed redundant space Co-authored-by: Benjamin Bossan --- skops/cli/_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 0018e547..b2f64a77 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -42,7 +42,7 @@ def format_parser( parser = argparse.ArgumentParser() parser_subgroup = parser.add_argument_group("update") - parser_subgroup.add_argument("input", help="Path to an input file to update. ") + parser_subgroup.add_argument("input", help="Path to an input file to update.") parser_subgroup.add_argument( "-o", From 5c382daeab761ec36c60895f0583e59224a2d202 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 23 Apr 2023 14:52:31 +0200 Subject: [PATCH 08/26] adjusted narrow docstring in test --- skops/cli/tests/test_entrypoint.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/skops/cli/tests/test_entrypoint.py b/skops/cli/tests/test_entrypoint.py index 074f52a7..0b0f9c3d 100644 --- a/skops/cli/tests/test_entrypoint.py +++ b/skops/cli/tests/test_entrypoint.py @@ -46,9 +46,8 @@ def test_update_works_as_expected( update_file_mock: mock.MagicMock, ): """ - Intended as a unit test to make sure, - given 'update' as the first argument, - the parser is configured correctly + To make sure the parser is configured correctly, when 'update' + is the first argument. """ args = ["update", "abc.skops", "-o", "abc-new.skops"] From 7e8ef1f43de6dc9faaca0f608440b0efb9c938de Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 23 Apr 2023 17:33:45 +0200 Subject: [PATCH 09/26] added check for protocol number before update --- skops/cli/_update.py | 12 +++++++++--- skops/cli/_utils.py | 10 ++++++++++ skops/cli/tests/test_update.py | 36 +++++++++++++++++++++++++++------- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index b2f64a77..e9705e98 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -5,8 +5,9 @@ from pathlib import Path from typing import Optional, Union -from skops.cli._utils import get_log_level +from skops.cli._utils import get_log_level, load_schema from skops.io import dump, load +from skops.io._protocol import PROTOCOL def _update_file( @@ -29,8 +30,13 @@ def _update_file( """ input_model = load(input_file, trusted=True) - dump(input_model, output_file) - logger.debug(f"Updated skops file written to {output_file}") + input_file_schema = load_schema(input_file) + + if input_file_schema["protocol"] != PROTOCOL: + dump(input_model, output_file) + logger.info(f"Updated skops file written to {output_file}") + + logger.info(f"Input file is already up to date to the current protocol: {PROTOCOL}") def format_parser( diff --git a/skops/cli/_utils.py b/skops/cli/_utils.py index 55173532..fc8dd10a 100644 --- a/skops/cli/_utils.py +++ b/skops/cli/_utils.py @@ -1,4 +1,8 @@ +import json import logging +import pathlib +import zipfile +from typing import Any, Union def get_log_level(level: int = 0) -> int: @@ -13,3 +17,9 @@ def get_log_level(level: int = 0) -> int: level = 0 return all_levels[level] + + +def load_schema(skops_file_path: Union[str, pathlib.Path]) -> dict[str, Any]: + with zipfile.ZipFile(skops_file_path, "r") as zip_file: + schema = json.loads(zip_file.read("schema.json")) + return schema diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index b9bb0985..09e3d650 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -1,17 +1,13 @@ import logging import pathlib +from functools import partial from unittest import mock import numpy as np import pytest from skops.cli import _update -from skops.io import dump, load - - -class MockUnsafeType: - def __init__(self): - pass +from skops.io import _persist, dump, load class TestUpdate: @@ -31,13 +27,24 @@ def new_skops_path(self, tmp_path: pathlib.Path) -> pathlib.Path: @pytest.fixture def dump_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray): + """Dump an object using the current protocol version.""" + dump(safe_obj, skops_path) + + @pytest.fixture + @mock.patch( + "skops.io._persist.SaveContext", partial(_persist.SaveContext, protocol=0) + ) + def dump_old_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray): + """Dump an object using an old protocol version so that the file needs + updating. + """ dump(safe_obj, skops_path) def test_base_case_works_as_expected( self, skops_path: pathlib.Path, new_skops_path: pathlib.Path, - dump_file, + dump_old_file, safe_obj, ): mock_logger = mock.MagicMock() @@ -45,12 +52,27 @@ def test_base_case_works_as_expected( input_file=skops_path, output_file=new_skops_path, logger=mock_logger ) updated_obj = load(new_skops_path) + assert np.array_equal(updated_obj, safe_obj) # Check no warnings or errors raised mock_logger.warning.assert_not_called() mock_logger.error.assert_not_called() + def test_no_update( + self, + skops_path: pathlib.Path, + new_skops_path: pathlib.Path, + dump_file, + ): + mock_logger = mock.MagicMock() + _update._update_file( + input_file=skops_path, output_file=new_skops_path, logger=mock_logger + ) + + with pytest.raises(FileNotFoundError): + load(new_skops_path) + class TestMain: @pytest.fixture From 3153180b0a218685513418a1cd0c09bad00b50fc Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 23 Apr 2023 17:47:00 +0200 Subject: [PATCH 10/26] removed load_schema --- skops/cli/_update.py | 7 +++++-- skops/cli/_utils.py | 10 ---------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index e9705e98..456eb2e4 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -1,11 +1,13 @@ from __future__ import annotations import argparse +import json import logging +import zipfile from pathlib import Path from typing import Optional, Union -from skops.cli._utils import get_log_level, load_schema +from skops.cli._utils import get_log_level from skops.io import dump, load from skops.io._protocol import PROTOCOL @@ -30,7 +32,8 @@ def _update_file( """ input_model = load(input_file, trusted=True) - input_file_schema = load_schema(input_file) + with zipfile.ZipFile(input_file, "r") as zip_file: + input_file_schema = json.loads(zip_file.read("schema.json")) if input_file_schema["protocol"] != PROTOCOL: dump(input_model, output_file) diff --git a/skops/cli/_utils.py b/skops/cli/_utils.py index fc8dd10a..55173532 100644 --- a/skops/cli/_utils.py +++ b/skops/cli/_utils.py @@ -1,8 +1,4 @@ -import json import logging -import pathlib -import zipfile -from typing import Any, Union def get_log_level(level: int = 0) -> int: @@ -17,9 +13,3 @@ def get_log_level(level: int = 0) -> int: level = 0 return all_levels[level] - - -def load_schema(skops_file_path: Union[str, pathlib.Path]) -> dict[str, Any]: - with zipfile.ZipFile(skops_file_path, "r") as zip_file: - schema = json.loads(zip_file.read("schema.json")) - return schema From 0ca412807b841cf8e0d17978b5414836dddc6ab6 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 23 Apr 2023 17:51:58 +0200 Subject: [PATCH 11/26] added change description to changelog --- docs/changes.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changes.rst b/docs/changes.rst index 295b876b..a7ad85aa 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -15,6 +15,8 @@ v0.7 - `compression` and `compresslevel` from :class:`~zipfile.ZipFile` are now exposed to the user via :func:`.io.dumps` and :func:`.io.dump`. :pr:`345` by `Adrin Jalali`_. +- Add the CLI command to update Skops files to the latest Skops persistence format. + (:func:`.cli._update.main`). :pr:`333` by :user:`Edoardo Abati ` v0.6 ---- From e5d190d0e37d193e7f5303ac311eb0ef53b12dce Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sat, 29 Apr 2023 11:04:25 +0200 Subject: [PATCH 12/26] updated logging message when file not updated --- skops/cli/_update.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 456eb2e4..9fa8addb 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -38,8 +38,11 @@ def _update_file( if input_file_schema["protocol"] != PROTOCOL: dump(input_model, output_file) logger.info(f"Updated skops file written to {output_file}") - - logger.info(f"Input file is already up to date to the current protocol: {PROTOCOL}") + else: + logger.info( + "File was not updated because already up to date with the current protocol:" + f" {PROTOCOL}" + ) def format_parser( From 987eeec0ee9f047186b7bd9fa270081b3d6169eb Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sat, 29 Apr 2023 11:27:58 +0200 Subject: [PATCH 13/26] testing all logging messages --- skops/cli/tests/test_update.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index 09e3d650..415b3c32 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -55,9 +55,13 @@ def test_base_case_works_as_expected( assert np.array_equal(updated_obj, safe_obj) - # Check no warnings or errors raised + # Check logging messages + mock_logger.info.assert_called_once_with( + f"Updated skops file written to {new_skops_path}" + ) mock_logger.warning.assert_not_called() mock_logger.error.assert_not_called() + mock_logger.debug.assert_not_called() def test_no_update( self, From aa1f145d2047b239ba98513275bde9818f910927 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sat, 29 Apr 2023 11:31:46 +0200 Subject: [PATCH 14/26] updated test_no_update based on feedback --- skops/cli/tests/test_update.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index 415b3c32..d6ab317e 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -7,7 +7,7 @@ import pytest from skops.cli import _update -from skops.io import _persist, dump, load +from skops.io import _persist, _protocol, dump, load class TestUpdate: @@ -73,9 +73,11 @@ def test_no_update( _update._update_file( input_file=skops_path, output_file=new_skops_path, logger=mock_logger ) - - with pytest.raises(FileNotFoundError): - load(new_skops_path) + mock_logger.info.assert_called_once_with( + "File was not updated because already up to date with the current protocol:" + f" {_protocol.PROTOCOL}" + ) + assert not new_skops_path.exists() class TestMain: From c681d58dcc1f83bd678988d4663bd37611858ddb Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sat, 29 Apr 2023 13:59:20 +0200 Subject: [PATCH 15/26] added inplace arg and only diff as default --- skops/cli/_update.py | 48 ++++++++++++++++++---- skops/cli/tests/test_update.py | 75 ++++++++++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 16 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 9fa8addb..dda26ee5 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -5,7 +5,6 @@ import logging import zipfile from pathlib import Path -from typing import Optional, Union from skops.cli._utils import get_log_level from skops.io import dump, load @@ -13,8 +12,9 @@ def _update_file( - input_file: Union[str, Path], - output_file: Union[str, Path], + input_file: str | Path, + output_file: str | Path | None = None, + inplace: bool = False, logger: logging.Logger = logging.getLogger(), ) -> None: """Function that is called by ``skops update`` entrypoint. @@ -24,20 +24,42 @@ def _update_file( Parameters ---------- - input_file : Union[str, Path] + input_file : str, or Path Path of input skops model to load. - output_file : Union[str, Path] + output_file : str, or Path, default=None Path to save the updated skops model to. + inplace : bool, default=False + Whether to update and overwrite the input file in place. + + logger : logging.Logger, default=logging.getLogger() + Logger to use for logging. """ + if inplace: + if output_file is None: + output_file = input_file + else: + raise ValueError( + "Cannot specify both an output file path and the inplace flag. Please" + " choose whether you want to create a new file or overwrite the input" + " file." + ) + input_model = load(input_file, trusted=True) with zipfile.ZipFile(input_file, "r") as zip_file: input_file_schema = json.loads(zip_file.read("schema.json")) if input_file_schema["protocol"] != PROTOCOL: - dump(input_model, output_file) - logger.info(f"Updated skops file written to {output_file}") + if output_file is not None: + dump(input_model, output_file) + logger.info(f"Updated skops file written to {output_file}") + else: + logger.info( + f"File can be updated to the current protocol: {PROTOCOL}. Please" + " specify an output file path or use the inplace flag to create the" + " updated Skops file." + ) else: logger.info( "File was not updated because already up to date with the current protocol:" @@ -46,7 +68,7 @@ def _update_file( def format_parser( - parser: Optional[argparse.ArgumentParser] = None, + parser: argparse.ArgumentParser | None = None, ) -> argparse.ArgumentParser: """Adds arguments and help to parent CLI parser for the `update` method.""" @@ -60,6 +82,12 @@ def format_parser( "-o", "--output-file", help="Specify the output file name for the updated skops file.", + default=None, + ) + parser_subgroup.add_argument( + "--inplace", + help="Update and overwrite the input file in place.", + action="store_true", ) parser_subgroup.add_argument( "-v", @@ -79,8 +107,9 @@ def main( parsed_args: argparse.Namespace, logger: logging.Logger = logging.getLogger(), ) -> None: - output_file = Path(parsed_args.output_file) + output_file = Path(parsed_args.output_file) if parsed_args.output_file else None input_file = Path(parsed_args.input) + inplace = parsed_args.inplace logging.basicConfig(format="%(levelname)-8s: %(message)s") logger.setLevel(level=get_log_level(parsed_args.loglevel)) @@ -88,5 +117,6 @@ def main( _update_file( input_file=input_file, output_file=output_file, + inplace=inplace, logger=logger, ) diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index d6ab317e..51ad1e51 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -40,24 +40,54 @@ def dump_old_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray): """ dump(safe_obj, skops_path) + @pytest.mark.parametrize("inplace", [True, False]) def test_base_case_works_as_expected( self, skops_path: pathlib.Path, new_skops_path: pathlib.Path, + inplace: bool, dump_old_file, safe_obj, ): mock_logger = mock.MagicMock() + expected_output_file = new_skops_path if not inplace else skops_path + _update._update_file( - input_file=skops_path, output_file=new_skops_path, logger=mock_logger + input_file=skops_path, + output_file=new_skops_path if not inplace else None, + inplace=inplace, + logger=mock_logger, ) - updated_obj = load(new_skops_path) + updated_obj = load(expected_output_file) assert np.array_equal(updated_obj, safe_obj) # Check logging messages mock_logger.info.assert_called_once_with( - f"Updated skops file written to {new_skops_path}" + f"Updated skops file written to {expected_output_file}" + ) + mock_logger.warning.assert_not_called() + mock_logger.error.assert_not_called() + mock_logger.debug.assert_not_called() + + @mock.patch("skops.cli._update.dump") + def test_only_diff( + self, mock_dump: mock.MagicMock, skops_path: pathlib.Path, dump_old_file + ): + mock_logger = mock.MagicMock() + _update._update_file( + input_file=skops_path, + output_file=None, + inplace=False, + logger=mock_logger, + ) + mock_dump.assert_not_called() + + # Check logging messages + mock_logger.info.assert_called_once_with( + f"File can be updated to the current protocol: {_protocol.PROTOCOL}. Please" + " specify an output file path or use the inplace flag to create the" + " updated Skops file." ) mock_logger.warning.assert_not_called() mock_logger.error.assert_not_called() @@ -71,7 +101,10 @@ def test_no_update( ): mock_logger = mock.MagicMock() _update._update_file( - input_file=skops_path, output_file=new_skops_path, logger=mock_logger + input_file=skops_path, + output_file=new_skops_path, + inplace=False, + logger=mock_logger, ) mock_logger.info.assert_called_once_with( "File was not updated because already up to date with the current protocol:" @@ -79,26 +112,53 @@ def test_no_update( ) assert not new_skops_path.exists() + def test_raises_valueerror( + self, skops_path: pathlib.Path, new_skops_path: pathlib.Path, dump_file + ): + with pytest.raises(ValueError): + _update._update_file( + input_file=skops_path, output_file=new_skops_path, inplace=True + ) + class TestMain: @pytest.fixture def tmp_logger(self) -> logging.Logger: return logging.getLogger() + @pytest.mark.parametrize("output_flag", ["-o", "--output"]) @mock.patch("skops.cli._update._update_file") - def test_base_works_as_expected( - self, mock_update: mock.MagicMock, tmp_logger: logging.Logger + def test_output_argument( + self, mock_update: mock.MagicMock, output_flag: str, tmp_logger: logging.Logger ): input_path = "abc.skops" output_path = "abc-new.skops" namespace, _ = _update.format_parser().parse_known_args( - [input_path, "-o", output_path] + [input_path, output_flag, output_path] ) _update.main(namespace, tmp_logger) mock_update.assert_called_once_with( input_file=pathlib.Path(input_path), output_file=pathlib.Path(output_path), + inplace=False, + logger=tmp_logger, + ) + + @mock.patch("skops.cli._update._update_file") + def test_inplace_argument( + self, mock_update: mock.MagicMock, tmp_logger: logging.Logger + ): + input_path = "abc.skops" + namespace, _ = _update.format_parser().parse_known_args( + [input_path, "--inplace"] + ) + + _update.main(namespace, tmp_logger) + mock_update.assert_called_once_with( + input_file=pathlib.Path(input_path), + output_file=None, + inplace=True, logger=tmp_logger, ) @@ -131,6 +191,7 @@ def test_given_log_levels_works_as_expected( mock_update.assert_called_once_with( input_file=pathlib.Path(input_path), output_file=pathlib.Path(output_path), + inplace=False, logger=tmp_logger, ) assert tmp_logger.getEffectiveLevel() == expected_level From 8f6248e2fbcb78be1cc191ba5c79730e04abf422 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sat, 29 Apr 2023 14:05:02 +0200 Subject: [PATCH 16/26] updated entrypoint test --- skops/cli/tests/test_entrypoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skops/cli/tests/test_entrypoint.py b/skops/cli/tests/test_entrypoint.py index 0b0f9c3d..d463ffab 100644 --- a/skops/cli/tests/test_entrypoint.py +++ b/skops/cli/tests/test_entrypoint.py @@ -56,5 +56,6 @@ def test_update_works_as_expected( update_file_mock.assert_called_once_with( input_file=pathlib.Path("abc.skops"), output_file=pathlib.Path("abc-new.skops"), + inplace=False, logger=mock.ANY, ) From dcb55c1a4ed97ea37001d9d15797d6c8464529c0 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 21 May 2023 21:35:03 +0200 Subject: [PATCH 17/26] changed heading subsections --- docs/persistence.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/persistence.rst b/docs/persistence.rst index f6757299..06ee9af5 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -135,7 +135,7 @@ Skops has a command line interface to: - update ``Skops`` files to the latest version. ``skops convert`` ------------------ +~~~~~~~~~~~~~~~~~ To convert a file from the command line, use the ``skops convert`` entrypoint. @@ -156,7 +156,7 @@ Further help for the different supported options can be found by calling ``skops convert --help`` in a terminal. ``skops update`` ----------------- +~~~~~~~~~~~~~~~~ To update a ``Skops`` file from the command line, use the ``skops update`` command. From 7846077e3cf9448d2eb2a1147f37bbe3776b72bb Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 21 May 2023 21:39:11 +0200 Subject: [PATCH 18/26] changed name of test --- skops/cli/tests/test_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index 51ad1e51..d3f7b7f9 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -112,7 +112,7 @@ def test_no_update( ) assert not new_skops_path.exists() - def test_raises_valueerror( + def test_error_with_output_file_and_inplace( self, skops_path: pathlib.Path, new_skops_path: pathlib.Path, dump_file ): with pytest.raises(ValueError): From 854e0baaf12a315aa381399ee4694fd79350c967 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 23 May 2023 21:46:04 +0200 Subject: [PATCH 19/26] updated docstring with inplace --- skops/cli/_update.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index dda26ee5..54d7cb8f 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -19,8 +19,8 @@ def _update_file( ) -> None: """Function that is called by ``skops update`` entrypoint. - Loads a skops model from the input path, converts to the current skops format, and - saves to output file. + Loads a skops model from the input path, updates it to the current skops format, and + saves to an output file. It will overwrite the input file if `inplace` is True. Parameters ---------- From a4be07bd4e45ff890329c060fa97deeed2235a08 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 23 May 2023 22:10:52 +0200 Subject: [PATCH 20/26] refactored _update_file --- skops/cli/_update.py | 24 +++++++++++++----------- skops/cli/tests/test_update.py | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 54d7cb8f..6183d063 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -50,21 +50,23 @@ def _update_file( with zipfile.ZipFile(input_file, "r") as zip_file: input_file_schema = json.loads(zip_file.read("schema.json")) - if input_file_schema["protocol"] != PROTOCOL: - if output_file is not None: - dump(input_model, output_file) - logger.info(f"Updated skops file written to {output_file}") - else: - logger.info( - f"File can be updated to the current protocol: {PROTOCOL}. Please" - " specify an output file path or use the inplace flag to create the" - " updated Skops file." - ) - else: + if input_file_schema["protocol"] == PROTOCOL: logger.info( "File was not updated because already up to date with the current protocol:" f" {PROTOCOL}" ) + return None + + if output_file is None: + logger.info( + f"File can be updated to the current protocol: {PROTOCOL}. Please" + " specify an output file path or use the `inplace` flag to create the" + " updated Skops file." + ) + return None + + dump(input_model, output_file) + logger.info(f"Updated skops file written to {output_file}") def format_parser( diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index d3f7b7f9..573ddc35 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -86,7 +86,7 @@ def test_only_diff( # Check logging messages mock_logger.info.assert_called_once_with( f"File can be updated to the current protocol: {_protocol.PROTOCOL}. Please" - " specify an output file path or use the inplace flag to create the" + " specify an output file path or use the `inplace` flag to create the" " updated Skops file." ) mock_logger.warning.assert_not_called() From 734ed43788fa85059e4d70daa5b9e9b69bac3dab Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 23 May 2023 22:16:46 +0200 Subject: [PATCH 21/26] added edge case for newer protocol --- skops/cli/_update.py | 7 +++++++ skops/cli/tests/test_update.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 6183d063..98732c8f 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -57,6 +57,13 @@ def _update_file( ) return None + if input_file_schema["protocol"] > PROTOCOL: + logger.info( + "File cannot be updated because its protocol is more recent than the " + f"current protocol: {PROTOCOL}" + ) + return None + if output_file is None: logger.info( f"File can be updated to the current protocol: {PROTOCOL}. Please" diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index 573ddc35..84984d11 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -40,6 +40,17 @@ def dump_old_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray): """ dump(safe_obj, skops_path) + @pytest.fixture + @mock.patch( + "skops.io._persist.SaveContext", + partial(_persist.SaveContext, protocol=_protocol.PROTOCOL + 1), + ) + def dump_new_file(self, skops_path: pathlib.Path, safe_obj: np.ndarray): + """Dump an object using an new protocol version so that the file cannot be + updated. + """ + dump(safe_obj, skops_path) + @pytest.mark.parametrize("inplace", [True, False]) def test_base_case_works_as_expected( self, @@ -93,7 +104,7 @@ def test_only_diff( mock_logger.error.assert_not_called() mock_logger.debug.assert_not_called() - def test_no_update( + def test_no_update_same_protocol( self, skops_path: pathlib.Path, new_skops_path: pathlib.Path, @@ -112,6 +123,25 @@ def test_no_update( ) assert not new_skops_path.exists() + def test_no_update_newer_protocol( + self, + skops_path: pathlib.Path, + new_skops_path: pathlib.Path, + dump_new_file, + ): + mock_logger = mock.MagicMock() + _update._update_file( + input_file=skops_path, + output_file=new_skops_path, + inplace=False, + logger=mock_logger, + ) + mock_logger.info.assert_called_once_with( + "File cannot be updated because its protocol is more recent than the " + f"current protocol: {_protocol.PROTOCOL}" + ) + assert not new_skops_path.exists() + def test_error_with_output_file_and_inplace( self, skops_path: pathlib.Path, new_skops_path: pathlib.Path, dump_file ): From 50ab3a46215970966e19f9eeac689320386edfb8 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 23 May 2023 22:29:37 +0200 Subject: [PATCH 22/26] updated the docs to explain when skops updates a file --- docs/persistence.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/persistence.rst b/docs/persistence.rst index 06ee9af5..3aedec44 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -159,6 +159,7 @@ Further help for the different supported options can be found by calling ~~~~~~~~~~~~~~~~ To update a ``Skops`` file from the command line, use the ``skops update`` command. +Skops will check the protocol version of the file to determine if it needs to be updated to the current version. The below command is an example on how to create an updated version of a file ``my_model.skops`` and save it as ``my_model-updated.skops``: From b44dca493e64debdff6767fa6bec74617bebb0dd Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Fri, 26 May 2023 15:15:55 +0200 Subject: [PATCH 23/26] changed level of certain log messages --- skops/cli/_update.py | 6 +++--- skops/cli/tests/test_update.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 98732c8f..d1754b5a 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -51,21 +51,21 @@ def _update_file( input_file_schema = json.loads(zip_file.read("schema.json")) if input_file_schema["protocol"] == PROTOCOL: - logger.info( + logger.warning( "File was not updated because already up to date with the current protocol:" f" {PROTOCOL}" ) return None if input_file_schema["protocol"] > PROTOCOL: - logger.info( + logger.warning( "File cannot be updated because its protocol is more recent than the " f"current protocol: {PROTOCOL}" ) return None if output_file is None: - logger.info( + logger.warning( f"File can be updated to the current protocol: {PROTOCOL}. Please" " specify an output file path or use the `inplace` flag to create the" " updated Skops file." diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index 84984d11..0169294e 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -95,12 +95,12 @@ def test_only_diff( mock_dump.assert_not_called() # Check logging messages - mock_logger.info.assert_called_once_with( + mock_logger.warning.assert_called_once_with( f"File can be updated to the current protocol: {_protocol.PROTOCOL}. Please" " specify an output file path or use the `inplace` flag to create the" " updated Skops file." ) - mock_logger.warning.assert_not_called() + mock_logger.info.assert_not_called() mock_logger.error.assert_not_called() mock_logger.debug.assert_not_called() @@ -117,7 +117,7 @@ def test_no_update_same_protocol( inplace=False, logger=mock_logger, ) - mock_logger.info.assert_called_once_with( + mock_logger.warning.assert_called_once_with( "File was not updated because already up to date with the current protocol:" f" {_protocol.PROTOCOL}" ) @@ -136,7 +136,7 @@ def test_no_update_newer_protocol( inplace=False, logger=mock_logger, ) - mock_logger.info.assert_called_once_with( + mock_logger.warning.assert_called_once_with( "File cannot be updated because its protocol is more recent than the " f"current protocol: {_protocol.PROTOCOL}" ) From 3a30721bbdbcc72fa17abc36624c580c51a49943 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Sun, 18 Jun 2023 15:32:38 +0200 Subject: [PATCH 24/26] saving update skops to tmp file before saving to output --- skops/cli/_update.py | 7 ++++++- skops/cli/tests/test_update.py | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index d1754b5a..729dfcf9 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -3,6 +3,7 @@ import argparse import json import logging +import os import zipfile from pathlib import Path @@ -72,7 +73,11 @@ def _update_file( ) return None - dump(input_model, output_file) + tmp_output_file = f"{output_file}.tmp" + logger.debug(f"Writing updated skops file to temporary path: {tmp_output_file}") + dump(input_model, tmp_output_file) + logger.debug(f"Moving updated skops file to output path: {output_file}") + os.replace(tmp_output_file, output_file) logger.info(f"Updated skops file written to {output_file}") diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index 0169294e..c3a05262 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -79,7 +79,9 @@ def test_base_case_works_as_expected( ) mock_logger.warning.assert_not_called() mock_logger.error.assert_not_called() - mock_logger.debug.assert_not_called() + mock_logger.debug.assert_called_with( + f"Moving updated skops file to output path: {expected_output_file}" + ) @mock.patch("skops.cli._update.dump") def test_only_diff( From 68e5e675d51dd232caa18d059e1d8ff4cac219f8 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 20 Jun 2023 22:58:32 +0200 Subject: [PATCH 25/26] using tempfile to create the temporary updated file --- skops/cli/_update.py | 12 ++++++------ skops/cli/tests/test_update.py | 4 +--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index 729dfcf9..e3335642 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -3,7 +3,8 @@ import argparse import json import logging -import os +import shutil +import tempfile import zipfile from pathlib import Path @@ -73,11 +74,10 @@ def _update_file( ) return None - tmp_output_file = f"{output_file}.tmp" - logger.debug(f"Writing updated skops file to temporary path: {tmp_output_file}") - dump(input_model, tmp_output_file) - logger.debug(f"Moving updated skops file to output path: {output_file}") - os.replace(tmp_output_file, output_file) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_output_file = Path(tmp_dir) / f"{output_file}.tmp" + dump(input_model, tmp_output_file) + shutil.move(tmp_output_file, output_file) logger.info(f"Updated skops file written to {output_file}") diff --git a/skops/cli/tests/test_update.py b/skops/cli/tests/test_update.py index c3a05262..0169294e 100644 --- a/skops/cli/tests/test_update.py +++ b/skops/cli/tests/test_update.py @@ -79,9 +79,7 @@ def test_base_case_works_as_expected( ) mock_logger.warning.assert_not_called() mock_logger.error.assert_not_called() - mock_logger.debug.assert_called_with( - f"Moving updated skops file to output path: {expected_output_file}" - ) + mock_logger.debug.assert_not_called() @mock.patch("skops.cli._update.dump") def test_only_diff( From 1136741e7d451e10ad38a321539c30a04c1bf252 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 21 Jun 2023 09:46:34 +0200 Subject: [PATCH 26/26] fixed types --- skops/cli/_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/cli/_update.py b/skops/cli/_update.py index e3335642..c648e37b 100644 --- a/skops/cli/_update.py +++ b/skops/cli/_update.py @@ -77,7 +77,7 @@ def _update_file( with tempfile.TemporaryDirectory() as tmp_dir: tmp_output_file = Path(tmp_dir) / f"{output_file}.tmp" dump(input_model, tmp_output_file) - shutil.move(tmp_output_file, output_file) + shutil.move(str(tmp_output_file), str(output_file)) logger.info(f"Updated skops file written to {output_file}")