Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Added MLflow callback #770

Merged
Merged 9 commits
Jun 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `load_best` attribute to `Checkpoint` callback to automatically load state of the best result at the end of training
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including of criteria if applicable
- Added `MlflowLogger` callback for logging to Mlflow (#769)

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
'sklearn': ('http://scikit-learn.org/stable/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'python': ('https://docs.python.org/3', None),
'mlflow': ('https://mlflow.org/docs/latest/', None),
}

# Add any paths that contain templates here, relative to this directory.
Expand Down
1 change: 1 addition & 0 deletions skorch/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
'Initializer',
'LRScheduler',
'LoadInitState',
'MlflowLogger',
'NeptuneLogger',
'ParamMapper',
'PassthroughScoring',
Expand Down
184 changes: 183 additions & 1 deletion skorch/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
import time
import tempfile
from contextlib import suppress
from numbers import Number
from itertools import cycle
Expand All @@ -16,7 +17,7 @@
from skorch.callbacks import Callback

__all__ = ['EpochTimer', 'NeptuneLogger', 'WandbLogger', 'PrintLog', 'ProgressBar',
'TensorBoard', 'SacredLogger']
'TensorBoard', 'SacredLogger', 'MlflowLogger']


def filter_log_keys(keys, keys_ignored=None):
Expand Down Expand Up @@ -875,3 +876,184 @@ def on_epoch_end(self, net, **kwargs):

for key in filter_log_keys(epoch_logs.keys(), self.keys_ignored_):
self.experiment.log_scalar(key + self.epoch_suffix_, epoch_logs[key], epoch)


class MlflowLogger(Callback):
    """Logs results from history and artifact to Mlflow

    "MLflow is an open source platform for managing
    the end-to-end machine learning lifecycle" (:doc:`mlflow:index`)

    Use this callback to automatically log your metrics
    and create/log artifacts to mlflow.

    The best way to log additional information is to log directly to the
    experiment object or subclass the ``on_*`` methods.

    To use this logger, you first have to install Mlflow:

    .. code-block::

        $ pip install mlflow

    Examples
    --------

    Mlflow :doc:`fluent API <mlflow:python_api/mlflow>`:

    >>> import mlflow
    >>> net = NeuralNetClassifier(net, callbacks=[MlflowLogger()])
    >>> with mlflow.start_run():
    ...     net.fit(X, y)

    Custom :py:class:`run <mlflow.entities.Run>` and
    :py:class:`client <mlflow.tracking.MlflowClient>`:

    >>> from mlflow.tracking import MlflowClient
    >>> client = MlflowClient()
    >>> experiment = client.get_experiment_by_name('Default')
    >>> run = client.create_run(experiment.experiment_id)
    >>> net = NeuralNetClassifier(..., callbacks=[MlflowLogger(run, client)])
    >>> net.fit(X, y)

    Parameters
    ----------

    run : mlflow.entities.Run (default=None)
      Instantiated :py:class:`mlflow.entities.Run` class.
      By default (if set to ``None``),
      :py:func:`mlflow.active_run` is used to get the current run.

    client : mlflow.tracking.MlflowClient (default=None)
      Instantiated :py:class:`mlflow.tracking.MlflowClient` class.
      By default (if set to ``None``),
      ``MlflowClient()`` is used, which by default has:

      - the tracking URI set by :py:func:`mlflow.set_tracking_uri`
      - the registry URI set by :py:func:`mlflow.set_registry_uri`

    create_artifact : bool (default=True)
      Whether to create artifacts for the network's
      params, optimizer, criterion and history.
      See :ref:`save_load`

    terminate_after_train : bool (default=True)
      Whether to terminate the ``Run`` object once training finishes.

    log_on_batch_end : bool (default=False)
      Whether to log loss and other metrics on batch level.

    log_on_epoch_end : bool (default=True)
      Whether to log loss and other metrics on epoch level.

    batch_suffix : str (default=None)
      A string that will be appended to all logged keys. By default (if set to
      ``None``) ``'_batch'`` is used if batch and epoch logging are both enabled
      and no suffix is used otherwise.

    epoch_suffix : str (default=None)
      A string that will be appended to all logged keys. By default (if set to
      ``None``) ``'_epoch'`` is used if batch and epoch logging are both enabled
      and no suffix is used otherwise.

    keys_ignored : str or list of str (default=None)
      Key or list of keys that should not be logged to Mlflow. Note that in
      addition to the keys provided by the user, keys such as those starting
      with ``'event_'`` or ending on ``'_best'`` are ignored by default.
    """
    def __init__(
            self,
            run=None,
            client=None,
            create_artifact=True,
            terminate_after_train=True,
            log_on_batch_end=False,
            log_on_epoch_end=True,
            batch_suffix=None,
            epoch_suffix=None,
            keys_ignored=None,
    ):
        self.run = run
        self.client = client
        self.create_artifact = create_artifact
        self.terminate_after_train = terminate_after_train
        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end
        self.batch_suffix = batch_suffix
        self.epoch_suffix = epoch_suffix
        self.keys_ignored = keys_ignored

    def initialize(self):
        """Resolve the run, client, ignored keys and key suffixes.

        Mlflow is imported lazily here so that the callback can be
        constructed without mlflow being installed.
        """
        self.run_ = self.run
        if self.run_ is None:
            # Fall back to the run started via the fluent API
            # (``mlflow.start_run()``), if any.
            import mlflow
            self.run_ = mlflow.active_run()
        self.client_ = self.client
        if self.client_ is None:
            from mlflow.tracking import MlflowClient
            self.client_ = MlflowClient()
        keys_ignored = self.keys_ignored
        if isinstance(keys_ignored, str):
            keys_ignored = [keys_ignored]
        self.keys_ignored_ = set(keys_ignored or [])
        # 'batches' holds nested per-batch dicts, not scalar metrics.
        self.keys_ignored_.add('batches')
        self.batch_suffix_ = self._init_suffix(self.batch_suffix, '_batch')
        self.epoch_suffix_ = self._init_suffix(self.epoch_suffix, '_epoch')
        return self

    def _init_suffix(self, suffix, default):
        # Only disambiguate with a suffix when both batch- and
        # epoch-level logging would otherwise write the same keys.
        if suffix is not None:
            return suffix
        return default if self.log_on_batch_end and self.log_on_epoch_end else ''

    def on_train_begin(self, net, **kwargs):
        # Global batch counter used as the mlflow ``step`` for batch metrics.
        self._batch_count = 0

    def on_batch_end(self, net, training, **kwargs):
        if not self.log_on_batch_end:
            return
        self._batch_count += 1
        batch_logs = net.history[-1]['batches'][-1]
        self._iteration_log(batch_logs, self.batch_suffix_, self._batch_count)

    def on_epoch_end(self, net, **kwargs):
        if not self.log_on_epoch_end:
            return
        epoch_logs = net.history[-1]
        self._iteration_log(epoch_logs, self.epoch_suffix_, len(net.history))

    def _iteration_log(self, logs, suffix, step):
        # Log every numeric history entry that is not explicitly ignored.
        for key in filter_log_keys(logs.keys(), self.keys_ignored_):
            self.client_.log_metric(
                self.run_.info.run_id,
                key + suffix,
                logs[key],
                step=step,
            )

    def on_train_end(self, net, **kwargs):
        # Terminate the run even if artifact logging fails, so runs are
        # not left dangling in the 'RUNNING' state.
        try:
            self._log_artifacts(net)
        finally:
            if self.terminate_after_train:
                self.client_.set_terminated(self.run_.info.run_id)

    def _log_artifacts(self, net):
        """Save net params/optimizer/criterion/history and upload to mlflow."""
        if not self.create_artifact:
            return
        with tempfile.TemporaryDirectory(prefix='skorch_mlflow_logger_') as dirpath:
            dirpath = Path(dirpath)
            params_filepath = dirpath / 'params.pth'
            optimizer_filepath = dirpath / 'optimizer.pth'
            criterion_filepath = dirpath / 'criterion.pth'
            history_filepath = dirpath / 'history.json'
            net.save_params(
                f_params=params_filepath,
                f_optimizer=optimizer_filepath,
                f_criterion=criterion_filepath,
                f_history=history_filepath,
            )
            # Coerce to str: older mlflow client versions expect a string
            # local path, not a pathlib.Path.
            self.client_.log_artifact(self.run_.info.run_id, str(params_filepath))
            self.client_.log_artifact(self.run_.info.run_id, str(optimizer_filepath))
            self.client_.log_artifact(self.run_.info.run_id, str(criterion_filepath))
            self.client_.log_artifact(self.run_.info.run_id, str(history_filepath))
Loading