-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
- Loading branch information
1 parent
311f524
commit e9f6e19
Showing
10 changed files
with
398 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from __future__ import annotations | ||
|
||
import argparse | ||
import logging | ||
import os | ||
import pathlib | ||
import pickle | ||
from typing import Optional | ||
|
||
from skops.cli._utils import get_log_level | ||
from skops.io import dumps, get_untrusted_types | ||
|
||
|
||
def _convert_file(
    input_file: os.PathLike,
    output_file: os.PathLike,
    logger: logging.Logger = logging.getLogger(),
) -> None:
    """Convert a pickled object on disk into a skops file.

    Called by the ``skops convert`` entrypoint. The object is unpickled from
    ``input_file``, serialized with ``dumps``, inspected with
    ``get_untrusted_types``, and the serialized data is written to
    ``output_file``.

    Parameters
    ----------
    input_file : os.PathLike
        Path of the pickle file to read.

    output_file : os.PathLike
        Path the converted skops data is written to.

    logger : logging.Logger
        Logger receiving the debug/info/warning messages emitted during the
        conversion. Defaults to the root logger.
    """
    model_name = pathlib.Path(input_file).stem
    logger.debug(f"Converting {model_name}")

    # NOTE: unpickling can execute arbitrary code embedded in the file, so
    # only pickles from a trusted source should be converted.
    with open(input_file, "rb") as pickle_stream:
        unpickled = pickle.load(pickle_stream)
        payload = dumps(unpickled)

    unknown_types = get_untrusted_types(data=payload)

    if unknown_types:
        listed = ", ".join(unknown_types)
        message = (
            f"While converting {input_file}, "
            "the following unknown types were found: "
            f"{listed}. "
            f"When loading {output_file} with skops.load, these types must be "
            "specified as 'trusted'"
        )
        logger.warning(message)
    else:
        logger.info(f"No unknown types found in {model_name}.")

    # The output is written even when unknown types were reported; the
    # warning above only informs the user what must be trusted at load time.
    with open(output_file, "wb") as sink:
        logger.debug(f"Writing to {output_file}")
        sink.write(payload)
|
||
|
||
def format_parser(
    parser: Optional[argparse.ArgumentParser] = None,
) -> argparse.ArgumentParser:
    """Register the arguments of the ``convert`` command on a parser.

    Parameters
    ----------
    parser : argparse.ArgumentParser, optional
        Parser (typically the ``convert`` sub-parser of the parent CLI) to add
        the arguments to. When omitted, a fresh parser is created.

    Returns
    -------
    argparse.ArgumentParser
        The parser that received the ``input``, ``-o/--output-file`` and
        ``-v/--verbose`` arguments.
    """
    # A standalone parser is built when none is supplied (used in tests).
    parser = parser or argparse.ArgumentParser()

    output_help = (
        "Specify the output file name for the converted skops file. "
        "If not provided, will default to using the same name as the input file, "
        "and saving to the current working directory with the suffix '.skops'."
    )
    verbose_help = (
        "Increases verbosity of logging. Can be used multiple times to increase "
        "verbosity further."
    )

    group = parser.add_argument_group("convert")
    group.add_argument("input", help="Path to an input file to convert. ")
    group.add_argument("-o", "--output-file", help=output_help, default=None)
    # Each occurrence of -v/--verbose increments ``loglevel`` by one.
    group.add_argument(
        "-v",
        "--verbose",
        help=verbose_help,
        action="count",
        dest="loglevel",
        default=0,
    )
    return parser
|
||
|
||
def main(
    parsed_args: argparse.Namespace,
) -> None:
    """Run the ``convert`` command from an already parsed namespace.

    Configures logging from the requested verbosity, resolves the output
    path, and delegates the conversion to ``_convert_file``.

    Parameters
    ----------
    parsed_args : argparse.Namespace
        Namespace providing ``input``, ``output_file`` and ``loglevel``, as
        produced by the parser set up in ``format_parser``.
    """
    destination = parsed_args.output_file
    source = parsed_args.input

    verbosity = get_log_level(parsed_args.loglevel)
    logging.basicConfig(format="%(levelname)-8s: %(message)s", level=verbosity)

    # No filename provided: fall back to "<input stem>.skops" in the current
    # working directory.
    destination = destination or (
        pathlib.Path.cwd() / f"{pathlib.Path(source).stem}.skops"
    )

    _convert_file(input_file=source, output_file=destination)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import logging | ||
|
||
|
||
def get_log_level(level: int = 0) -> int:
    """Translate a CLI verbosity count into a ``logging`` level.

    Parameters
    ----------
    level : int, default=0
        Number of times ``-v`` was specified on the command line.

    Returns
    -------
    int
        ``logging.WARNING`` for 0 (or any negative value), ``logging.INFO``
        for 1, and ``logging.DEBUG`` for 2 or more.
    """
    all_levels = (logging.WARNING, logging.INFO, logging.DEBUG)

    # Clamp into the valid index range. The upper bound must be
    # ``len(all_levels) - 1``: a verbosity equal to ``len(all_levels)``
    # (e.g. ``-vvv``) would otherwise index past the end of the sequence.
    index = max(0, min(level, len(all_levels) - 1))

    return all_levels[index]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import argparse | ||
|
||
import skops.cli._convert | ||
|
||
|
||
def main_cli(command_line_args=None):
    """Main command line interface entrypoint for all command line Skops methods.

    To add a new entrypoint:
        1. Create a new method to call that accepts a namespace
        2. Create a new subparser formatter to define the expected CL arguments
        3. Add those to the function map.

    Parameters
    ----------
    command_line_args : list of str, optional
        Arguments to parse. When ``None``, ``argparse`` reads them from
        ``sys.argv``.
    """
    entry_parser = argparse.ArgumentParser(
        prog="Skops",
        description="Main entrypoint for all command line Skops methods.",
        add_help=True,
    )

    # ``required=True`` makes argparse report a usage error when no
    # sub-command is given. Without it the namespace has no ``func``
    # attribute and the dispatch below fails with an AttributeError.
    subparsers = entry_parser.add_subparsers(
        title="Commands",
        description="Skops command to call",
        dest="cmd",
        help="Sub-commands help",
        required=True,
    )

    # function_map should map a command to
    #   method: the command to call (gets set to default 'func')
    #   format_parser: the function used to create a subparser for that command
    function_map = {
        "convert": {
            "method": skops.cli._convert.main,
            "format_parser": skops.cli._convert.format_parser,
        },
    }

    for func_name, values in function_map.items():
        # Add subparser for each function in func map,
        # and assigns default func to be "method" from function_map
        subparser = subparsers.add_parser(func_name)
        subparser.set_defaults(func=values["method"])
        values["format_parser"](subparser)

    # Parse arguments with arg parser for given function in function map,
    # Then call the matching method in the function_map with the argument namespace
    args = entry_parser.parse_args(command_line_args)
    args.func(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import logging | ||
import pathlib | ||
import pickle | ||
from unittest import mock | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from skops.cli import _convert | ||
from skops.io import load | ||
|
||
|
||
class MockUnsafeType:
    """Bare custom class the tests pickle as their "unsafe" object."""

    def __init__(self) -> None:
        # Deliberately stateless: instances carry no attributes.
        return None
|
||
|
||
class TestConvert:
    """Tests for ``_convert._convert_file`` using real pickle files in ``tmp_path``."""

    model_name = "some_model_name"

    @pytest.fixture
    def safe_obj(self):
        # ``np.array`` builds the 1-D array [1, 2, 3, 4]. ``np.ndarray([1, 2, 3, 4])``
        # would instead allocate an *uninitialized* array of shape (1, 2, 3, 4)
        # whose arbitrary contents (possibly NaN) make equality checks unreliable.
        return np.array([1, 2, 3, 4])

    @pytest.fixture
    def unsafe_obj(self):
        return MockUnsafeType()

    @pytest.fixture
    def pkl_path(self, tmp_path):
        return tmp_path / f"{self.model_name}.pkl"

    @pytest.fixture
    def skops_path(self, tmp_path):
        return tmp_path / f"{self.model_name}.skops"

    @pytest.fixture
    def write_safe_file(self, pkl_path, safe_obj):
        # Side-effect fixture: pickles the safe object to ``pkl_path``.
        with open(pkl_path, "wb") as f:
            pickle.dump(safe_obj, f)

    @pytest.fixture
    def write_unsafe_file(self, pkl_path, unsafe_obj):
        # Side-effect fixture: pickles the unsafe object to ``pkl_path``.
        with open(pkl_path, "wb") as f:
            pickle.dump(unsafe_obj, f)

    def test_base_case_works_as_expected(
        self, pkl_path, skops_path, write_safe_file, safe_obj
    ):
        mock_logger = mock.MagicMock()
        _convert._convert_file(pkl_path, skops_path, logger=mock_logger)
        persisted_obj = load(skops_path)
        assert np.array_equal(persisted_obj, safe_obj)

        # Check no warnings or errors raised
        mock_logger.warning.assert_not_called()
        mock_logger.error.assert_not_called()

    def test_unsafe_case_works_as_expected(
        self, pkl_path, skops_path, write_unsafe_file, caplog
    ):
        caplog.set_level(logging.WARNING)
        _convert._convert_file(pkl_path, skops_path)
        persisted_obj = load(skops_path, trusted=True)

        assert isinstance(persisted_obj, MockUnsafeType)

        # check logging has warned that an unsafe type was found
        assert MockUnsafeType.__name__ in caplog.text
|
||
|
||
class TestMain:
    """Tests for ``_convert.main`` with ``_convert_file`` mocked out."""

    @staticmethod
    def assert_called_correctly(
        mock_convert: mock.MagicMock,
        path,
        output_file=None,
    ):
        """Assert ``_convert_file`` was called once with the expected paths.

        When ``output_file`` is not given, the expected output is
        ``<stem of path>.skops`` in the current working directory.
        """
        if not output_file:
            output_file = pathlib.Path.cwd() / f"{pathlib.Path(path).stem}.skops"
        mock_convert.assert_called_once_with(input_file=path, output_file=output_file)

    @mock.patch("skops.cli._convert._convert_file")
    def test_base_works_as_expected(self, mock_convert: mock.MagicMock):
        path = "123.pkl"
        namespace, _ = _convert.format_parser().parse_known_args([path])

        _convert.main(namespace)
        self.assert_called_correctly(mock_convert, path)

    @mock.patch("skops.cli._convert._convert_file")
    @pytest.mark.parametrize(
        "input_path, output_file, expected_path",
        [
            ("abc.123", "some/file/path.out", "some/file/path.out"),
            ("abc.123", None, pathlib.Path.cwd() / "abc.skops"),
        ],
        ids=["Given an output path", "No output path"],
    )
    def test_with_output_dir_works_as_expected(
        self, mock_convert: mock.MagicMock, input_path, output_file, expected_path
    ):
        if output_file is not None:
            args = [input_path, "--output", output_file]
        else:
            args = [input_path]

        namespace, _ = _convert.format_parser().parse_known_args(args)

        _convert.main(namespace)
        self.assert_called_correctly(
            mock_convert, path=input_path, output_file=expected_path
        )

    @mock.patch("skops.cli._convert._convert_file")
    @pytest.mark.parametrize(
        "verbosity, expected_level",
        [
            ("", logging.WARNING),
            ("-v", logging.INFO),
            ("--verbose", logging.INFO),
            ("-vv", logging.DEBUG),
            ("-v -v", logging.DEBUG),
            ("-vvvvv", logging.DEBUG),
            ("--verbose --verbose", logging.DEBUG),
        ],
    )
    def test_given_log_levels_works_as_expected(
        self, mock_convert: mock.MagicMock, verbosity, expected_level
    ):
        input_path = "abc.def"
        output_path = "bde.skops"
        # Splat the flags so each one is its own argv entry; nesting the list
        # inside ``args`` would keep argparse from ever seeing the flags.
        args = [input_path, "--output", output_path, *verbosity.split()]

        namespace, _ = _convert.format_parser().parse_known_args(args)

        # ``logging.basicConfig`` does nothing once the root logger has
        # handlers, so the level it is asked to configure is checked directly.
        with mock.patch.object(_convert.logging, "basicConfig") as mock_config:
            _convert.main(namespace)

        self.assert_called_correctly(
            mock_convert, path=input_path, output_file=output_path
        )

        mock_config.assert_called_once()
        assert mock_config.call_args[1]["level"] == expected_level
Oops, something went wrong.