Add tests for criterion being set to train/valid #643

Merged 1 commit on Jun 2, 2020
1 change: 1 addition & 0 deletions CHANGES.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
- Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future release (#620)
- The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`. Therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but an uninitialized one; if you want a copy of a fitted net, use `copy.deepcopy` instead. `net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change (see the sketch after this file's diff). (#521, #527)
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
- Set train/validation on criterion if it's a PyTorch module (#621)

### Fixed

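To illustrate the `net.get_params` entry above, here is a minimal sketch (not part of this diff); `MyModule`, `X`, and `y` are hypothetical placeholders:

import copy

from sklearn.base import clone
from skorch import NeuralNetClassifier

net = NeuralNetClassifier(MyModule)  # MyModule: any PyTorch classifier module
net.fit(X, y)                        # X, y: any classification data

# clone() relies on get_params(), which no longer includes learned
# attributes such as module_, so the clone comes back uninitialized:
net_clone = clone(net)  # must be fit again before predicting

# To keep the learned state, deep-copy the fitted net instead:
net_copy = copy.deepcopy(net)
net_copy.predict(X)     # works without refitting
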
32 changes: 32 additions & 0 deletions skorch/tests/test_net.py
@@ -2271,6 +2271,38 @@ def test_set_lr_at_runtime_sets_lr_pgroups(self, net_cls, module_cls, data):
assert net.optimizer_.param_groups[0]['lr'] == lr_pgroup_0_new
assert net.optimizer_.param_groups[1]['lr'] == lr_pgroup_1_new

def test_criterion_training_set_correctly(self, net_cls, module_cls, data):
# check that criterion's training attribute is set correctly

X, y = data[0][:50], data[1][:50] # don't need all the data
side_effect = []

class MyCriterion(nn.NLLLoss):
"""Criterion that records its training attribute"""
def forward(self, *args, **kwargs):
side_effect.append(self.training)
return super().forward(*args, **kwargs)

net = net_cls(module_cls, criterion=MyCriterion, max_epochs=1)
net.fit(X, y)

# called once with training=True for train step, once with
# training=False for validation step
assert side_effect == [True, False]

net.partial_fit(X, y)
# same logic as before
assert side_effect == [True, False, True, False]

def test_criterion_is_not_a_torch_module(self, net_cls, module_cls, data):
X, y = data[0][:50], data[1][:50] # don't need all the data

def my_criterion():
return torch.nn.functional.nll_loss

net = net_cls(module_cls, criterion=my_criterion, max_epochs=1)
net.fit(X, y) # does not raise

@pytest.mark.parametrize('acc_steps', [1, 2, 3, 5, 10])
def test_gradient_accumulation(self, net_cls, module_cls, data, acc_steps):
# Test if gradient accumulation technique is possible,
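
As a usage note on the behavior exercised by `test_criterion_training_set_correctly`: since the net now switches the criterion between train and eval mode, a custom criterion can branch on `self.training`. A hypothetical sketch (the class name and penalty term are illustrative, not from this PR, assuming NLLLoss-style log-probability inputs as in the test above):

import torch
from torch import nn

class PenalizedNLLLoss(nn.NLLLoss):
    """NLLLoss with a confidence penalty applied only in training mode."""
    def __init__(self, *args, alpha=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha

    def forward(self, log_probs, targets):
        loss = super().forward(log_probs, targets)
        if self.training:  # True during train steps, False during validation
            # Entropy of the predicted distribution; subtracting it
            # discourages over-confident predictions, but only on
            # training batches.
            entropy = -(log_probs.exp() * log_probs).sum(dim=1).mean()
            loss = loss - self.alpha * entropy
        return loss

Passed to the net as criterion=PenalizedNLLLoss, the penalty would apply only while training, so the reported validation loss stays a plain NLL and remains comparable across epochs.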