Skip to content

Commit

Permalink
Added MlflowLogger terminate_after_train paramter, Added MlflowLogger…
Browse files Browse the repository at this point in the history
… test with real client and run
  • Loading branch information
cacharle committed Jun 9, 2021
1 parent e153d99 commit 4d82e1f
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 2 deletions.
2 changes: 1 addition & 1 deletion skorch/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@
'Unfreezer',
'WandbLogger',
'WarmRestartLR',
'MLflow',
'MlflowLogger',
]
11 changes: 10 additions & 1 deletion skorch/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,7 @@ def __init__(
run=None,
client=None,
create_artifact=True,
terminate_after_train=True,
log_on_batch_end=False,
log_on_epoch_end=True,
keys_ignored=None,
Expand All @@ -901,9 +902,10 @@ def __init__(
):
self.run = run
self.client = client
self.create_artifact = create_artifact
self.terminate_after_train = terminate_after_train
self.log_on_batch_end = log_on_batch_end
self.log_on_epoch_end = log_on_epoch_end
self.create_artifact = create_artifact
self.keys_ignored = keys_ignored
self.batch_suffix = batch_suffix
self.epoch_suffix = epoch_suffix
Expand Down Expand Up @@ -947,6 +949,13 @@ def _iteration_log(self, logs, suffix):
self.client.log_metric(self.run_id, key + suffix, logs[key])

def on_train_end(self, net, **kwargs):
try:
self._log_artifacts(net)
finally:
if self.terminate_after_train:
self.client.set_terminated(self.run_id)

def _log_artifacts(self, net):
if not self.create_artifact:
return
with tempfile.TemporaryDirectory(prefix='skorch_mlflow_logger_') as dirpath:
Expand Down
91 changes: 91 additions & 0 deletions skorch/tests/callbacks/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,55 @@ def test_epoch_batch_suffixes_defaults(
assert logger.batch_suffix_ == batch_suffix
assert logger.epoch_suffix_ == epoch_suffix

@pytest.mark.parametrize('batch_suffix', ['', '_foo'])
@pytest.mark.parametrize('epoch_suffix', ['', '_bar'])
def test_epoch_batch_custom_suffix(
self,
logger_cls,
mock_client,
mock_run,
batch_suffix,
epoch_suffix
):
logger = logger_cls(
mock_run,
mock_client,
log_on_batch_end=True,
log_on_epoch_end=True,
batch_suffix=batch_suffix,
epoch_suffix=epoch_suffix,
).initialize()
assert logger.batch_suffix_ == batch_suffix
assert logger.epoch_suffix_ == epoch_suffix

def test_dont_log_epoch_metrics(
self,
logger_cls,
mock_client,
mock_run,
net_cls,
classifier_module,
data
):
logger = logger_cls(
mock_run,
mock_client,
log_on_batch_end=True,
log_on_epoch_end=False,
batch_suffix='_batch',
epoch_suffix='_epoch',
)
net_cls(
classifier_module,
batch_size=10,
callbacks=[logger],
max_epochs=3,
).fit(*data)
assert all(
call[0][1].endswith('_batch')
for call in mock_client.log_metric.call_args_list
)

def test_artifact_filenames(self, net_fitted, mock_client):
keys = {call_args[0][1].name
for call_args in mock_client.log_artifact.call_args_list}
Expand All @@ -1099,3 +1148,45 @@ def test_dont_create_artifact(
max_epochs=3,
).fit(*data)
assert not mock_client.log_artifact.called

def test_run_terminated_automatically(self, net_fitted, mock_client):
assert mock_client.set_terminated.call_count == 1

def test_run_not_closed(
self,
net_cls,
classifier_module,
data,
logger_cls,
mock_run,
mock_client,
):
net_cls(
classifier_module,
callbacks=[
logger_cls(mock_run, mock_client, terminate_after_train=False)
],
max_epochs=2,
).fit(*data)
assert mock_client.set_terminated.call_count == 0

def test_fit_with_real_run_and_client(
self,
net_cls,
classifier_module,
data,
logger_cls,
tmp_path,
):
from mlflow.tracking import MlflowClient
client = MlflowClient(tracking_uri=tmp_path.as_uri())
experiment_name = 'foo'
experiment_id = client.create_experiment(experiment_name)
run = client.create_run(experiment_id)
logger = logger_cls(run, client, create_artifact=False)
net_cls(
classifier_module,
callbacks=[logger],
max_epochs=3,
).fit(*data)
assert os.listdir(tmp_path)

0 comments on commit 4d82e1f

Please sign in to comment.