Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preparation for more correct get params change #527

Merged
merged 13 commits into from
Oct 10, 2019

Conversation

BenjaminBossan
Copy link
Collaborator

Introduce changes and tests in preparation for new get_params

The new behavior of get_params will be to not returned any "learned"
attributes such as "module_".

This PR implements the new behavior but doesn't switch to it yet to
give users time to adjust their code. This is a breaking change but it
is necessary since it is the "correct" behavior; the old one could
introduce subtle bugs in rare situations (e.g. GridSearchCV with a
net that has warm_start=True).

The PR also includes tests that are currently failing but that are
passing under the new behavior. When switching to the new behavior,
all tests, including these new ones, should pass (they currently
xfail).

Relates to #521

The new behavior of get_params will be to not returned any "learned"
attributes such as "module_".

This PR implements the new behavior but doesn't switch to it yet to
give users time to adjust their code. This is a breaking change but it
is necessary since it is the "correct" behavior; the old one could
introduce subtle bugs in rare situations (e.g. `GridSearchCV` with a
net that has `warm_start=True`).

The PR also includes tests that are currently failing but that are
passing under the new behavior. When switching to the new behavior,
all tests, including these new ones, should pass (they currently
xfail).

Relates to #512
Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the net_cuda.pkl udpated?

Should we raise a FutureWarning when someone tries to access a attribute that ends in _ with get_params?

skorch/net.py Outdated
# '_'). Once the transition period has passed, remove the old
# code and use the new one instead.
return (k for k in self.__dict__
if not k.endswith('_') and k != 'history')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

history is... weird parameter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should make history a property and store the values in self.history_?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made this change.

@BenjaminBossan
Copy link
Collaborator Author

Why is the net_cuda.pkl udpated?

Accidental check in, I reset the file.

Should we raise a FutureWarning when someone tries to access a attribute that ends in _ with get_params?

How would you achieve this? Return a custom dictionary that checks the key on __getitem__?

I would be afraid that such a warning would trigger in unexpected places, e.g. when someone uses GridSearchCV and doesn't do anything wrong, he or she could be confused by the warning. I'm all in favor of warning users more prominently about the upcoming change but I wouldn't want to create unnecessary confusion.

BenjaminBossan and others added 7 commits September 25, 2019 21:46
Add setter and getter methods for net.history. That way, history now
ends on '_' like all other parameters that are not provided directly
by the user.
Cloning now raises an error because history_ is passed to the new
instance but is not actually set.

This will no longer be relevant once we move to the new get_params
behavior.
BenjaminBossan and others added 2 commits October 10, 2019 22:19
Co-Authored-By: ottonemo <marian.tietz@ottogroup.com>
Co-Authored-By: ottonemo <marian.tietz@ottogroup.com>
@BenjaminBossan BenjaminBossan merged commit 161f28d into master Oct 10, 2019
@BenjaminBossan BenjaminBossan deleted the preparation-for-more-correct-get_params-change branch October 13, 2019 11:39
BenjaminBossan added a commit that referenced this pull request Aug 30, 2020
This release of skorch contains a few minor improvements and some nice additions. As always, we fixed a few bugs and improved the documentation. Our [learning rate scheduler](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.LRScheduler) now optionally logs learning rate changes to the history; moreover, it now allows the user to choose whether an update step should be made after each batch or each epoch.

If you always longed for a metric that would just use whatever is defined by your criterion, look no further than [`loss_scoring`](https://skorch.readthedocs.io/en/latest/scoring.html#skorch.scoring.loss_scoring). Also, skorch now allows you to easily change the kind of nonlinearity to apply to the module's output when `predict` and `predict_proba` are called, by passing the `predict_nonlinearity` argument.

Besides these changes, we improved the customization potential of skorch. First of all, the `criterion` is now set to `train` or `valid`, depending on the phase -- this is useful if the criterion should act differently during training and validation. Next we made it easier to add custom modules, optimizers, and criteria to your neural net; this should facilitate implementing architectures like GANs. Consult the [docs](https://skorch.readthedocs.io/en/latest/user/neuralnet.html#subclassing-neuralnet) for more on this. Conveniently, [`net.save_params`](https://skorch.readthedocs.io/en/latest/net.html#skorch.net.NeuralNet.save_params) can now persist arbitrary attributes, including those custom modules.
As always, these improvements wouldn't have been possible without the community. Please keep asking questions, raising issues, and proposing new features. We are especially grateful to those community members, old and new, who contributed via PRs:

```
Aaron Berk
guybuk
kqf
Michał Słapek
Scott Sievert
Yann Dubois
Zhao Meng
```

Here is the full list of all changes:

### Added

- Added the `event_name` argument for `LRScheduler` for optional recording of LR changes inside `net.history`. NOTE: Supported only in Pytorch>=1.4
- Make it easier to add custom modules or optimizers to a neural net class by automatically registering them where necessary and by making them available to set_params
- Added the `step_every` argument for `LRScheduler` to set whether the scheduler step should be taken on every epoch or on every batch.
- Added the `scoring` module with `loss_scoring` function, which computes the net's loss (using `get_loss`) on provided input data.
- Added a parameter `predict_nonlinearity` to `NeuralNet` which allows users to control the nonlinearity to be applied to the module output when calling `predict` and `predict_proba` (#637, #661)
- Added the possibility to save the criterion with `save_params` and with checkpoint callbacks
- Added the possibility to save custom modules with `save_params` and with checkpoint callbacks

### Changed

- Removed support for schedulers with a `batch_step()` method in `LRScheduler`.
- Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future (#620)
- The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead;`net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
- Set train/validation on criterion if it's a PyTorch module (#621)
- Don't pass `y=None` to `NeuralNet.train_split` to enable the direct use of split functions without positional `y` in their signatures. This is useful when working with unsupervised data (#605).
- `to_numpy` is now able to unpack dicts and lists/tuples (#657, #658)
- When using `CrossEntropyLoss`, softmax is now automatically applied to the output when calling `predict` or `predict_proba`

### Fixed

- Fixed a bug where `CyclicLR` scheduler would update during both training and validation rather than just during training.
- Fixed a bug introduced by moving the `optimizer.zero_grad()` call outside of the train step function, making it incompatible with LBFGS and other optimizers that call the train step several times per batch (#636)
- Fixed pickling of the `ProgressBar` callback (#656)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants