
Merge pull request #643 from skorch-dev/feature/add-tests-for-criterion-train-valid

Add tests for criterion being set to train/valid
BenjaminBossan authored Jun 2, 2020
2 parents a621d0f + 0ddd8be commit cbab769
Showing 2 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future (#620)
- The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead;`net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
- Set train/validation on criterion if it's a PyTorch module (#621)

### Fixed

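The new CHANGES.md entry above ("Set train/validation on criterion if it's a PyTorch module (#621)") is the behavior the tests below cover. As a minimal sketch of the idea, assuming a hypothetical `set_training` helper rather than skorch's actual implementation: when the net switches between its train and validation phases, a criterion that is an `nn.Module` is put into the matching mode, while a plain callable criterion is left untouched.

```python
import torch.nn as nn

def set_training(module, criterion, training=True):
    # Hypothetical sketch, not skorch's actual code: toggle the module's
    # train/eval mode, and do the same for the criterion only if it is a
    # torch.nn.Module (a plain callable criterion has no training flag).
    module.train(training)
    if isinstance(criterion, nn.Module):
        criterion.train(training)
```

Under that assumption, a criterion like `MyCriterion` in the test below sees `self.training == True` during the train step and `False` during the validation step, which is exactly what the `side_effect == [True, False]` assertion checks.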
32 changes: 32 additions & 0 deletions skorch/tests/test_net.py
@@ -2271,6 +2271,38 @@ def test_set_lr_at_runtime_sets_lr_pgroups(self, net_cls, module_cls, data):
        assert net.optimizer_.param_groups[0]['lr'] == lr_pgroup_0_new
        assert net.optimizer_.param_groups[1]['lr'] == lr_pgroup_1_new

    def test_criterion_training_set_correctly(self, net_cls, module_cls, data):
        # check that criterion's training attribute is set correctly

        X, y = data[0][:50], data[1][:50]  # don't need all the data
        side_effect = []

        class MyCriterion(nn.NLLLoss):
            """Criterion that records its training attribute"""
            def forward(self, *args, **kwargs):
                side_effect.append(self.training)
                return super().forward(*args, **kwargs)

        net = net_cls(module_cls, criterion=MyCriterion, max_epochs=1)
        net.fit(X, y)

        # called once with training=True for train step, once with
        # training=False for validation step
        assert side_effect == [True, False]

        net.partial_fit(X, y)
        # same logic as before
        assert side_effect == [True, False, True, False]

    def test_criterion_is_not_a_torch_module(self, net_cls, module_cls, data):
        X, y = data[0][:50], data[1][:50]  # don't need all the data

        def my_criterion():
            return torch.nn.functional.nll_loss

        net = net_cls(module_cls, criterion=my_criterion, max_epochs=1)
        net.fit(X, y)  # does not raise

    @pytest.mark.parametrize('acc_steps', [1, 2, 3, 5, 10])
    def test_gradient_accumulation(self, net_cls, module_cls, data, acc_steps):
        # Test if gradient accumulation technique is possible,
