Make skorch work with sklearn 1.6.0, attempt 2 #1078
Alternative to #1076

As described in that PR, skorch is currently not compatible with sklearn 1.6.0 or above. As per suggestion, instead of implementing `__sklearn_tags__`, this PR solves the issue by inheriting from `BaseEstimator`.

Related changes:

- It is important to set the correct order when inheriting from `BaseEstimator` and, say, `ClassifierMixin` (`BaseEstimator` should come last).
- As explained in #1076, using `GridSearchCV` with `y` being a torch tensor fails and two tests had to be adjusted.

Unrelated changes:

- Removed unnecessary imports from `callbacks/base.py`.
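To illustrate why the inheritance order matters, here is a minimal sketch using stand-in classes. `BaseEstimatorLike` and `ClassifierMixinLike` are hypothetical names, not sklearn's real classes, and a plain dict stands in for sklearn's tags object. The sketch assumes that the mixin refines the tags it receives via `super()`, while the base class supplies the defaults and does not delegate further.

```python
# Stand-in classes (hypothetical, NOT sklearn's actual implementation) that
# mimic cooperative tag resolution through the method resolution order (MRO).

class BaseEstimatorLike:
    """Plays the role of BaseEstimator: supplies default tags."""

    def __sklearn_tags__(self):
        # A dict is used here for simplicity; it is only an illustration.
        return {"estimator_type": None}


class ClassifierMixinLike:
    """Plays the role of ClassifierMixin: refines tags obtained from super()."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags["estimator_type"] = "classifier"
        return tags


class GoodNet(ClassifierMixinLike, BaseEstimatorLike):
    """Mixin first, base class last: the mixin's override is applied."""


class BadNet(BaseEstimatorLike, ClassifierMixinLike):
    """Base class first: its method wins and the mixin is never consulted."""


good_tags = GoodNet().__sklearn_tags__()
bad_tags = BadNet().__sklearn_tags__()

print(good_tags)
print(bad_tags)
```

With the mixin listed first, the classifier tag is set; with the base class listed first, the default tags are returned unchanged, because the base class method is found first in the MRO and does not call `super()`.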
@adrinjalali As per your suggestion, I started inheriting from
This is nice!
Two notes:
- scikit-learn now supports array API in many places, which means tensors remain tensors in cross validation etc. It's still experimental and in progress, but you might want to experiment with the config flag.
- You might want to have a `set_fit_request` like in this PR to properly enable metadata routing:
Okay, I'll close the other PR in favor of this one then.
I'm not sure if the error is directly related to the array API; I think it's more that sklearn is inconsistent in checking the input arrays. Here is an example showing that mixing numpy arrays and torch tensors works when fitting the estimator directly but fails with `GridSearchCV`:

```python
import torch
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
yt = torch.tensor(y)  # target as a torch tensor, features stay a numpy array

lr = LogisticRegression()
lr.fit(X, yt)  # works

gs = GridSearchCV(lr, param_grid={"C": [1, 10, 100]})
gs.fit(X, yt)  # fails
```

The error message is not very helpful:

(both devices' reprs are "cpu", but one is

The reason is the check from lines like this:

which causes an error in

When removing this check, the score can be calculated correctly, but of course there is no guarantee for that.
Probably, right now it's not officially supported (aka there are no tests for routing). But let's leave that for another PR.
LGTM
```diff
@@ -450,12 +451,13 @@ def test_grid_search_with_slds_works(
         gs = GridSearchCV(
             net, params, refit=False, cv=3, scoring='accuracy', error_score='raise'
         )
-        gs.fit(slds, y)  # does not raise
+        gs.fit(slds, to_numpy(y))  # does not raise
```
Not supporting a mixture of NumPy arrays and CPU PyTorch arrays feels like a regression from scikit-learn's side.
For me, the question is: is it an intended change or was it introduced accidentally?

If it's unintended, one way to resolve this would be here:

There would need to be another check:

- if one of the devices is an instance of `torch.device`
- and for this device, `device.type == "cpu"`
- and the other device is `"cpu"`

it should be allowed.
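The check described above could look roughly like the following sketch. It is not scikit-learn's code and not necessarily how the upstream fix works; `devices_match`, `_is_cpu` and `FakeTorchDevice` are hypothetical names. To stay runnable without torch installed, it duck-types on the `.type` attribute instead of testing `isinstance(device, torch.device)`.

```python
def _is_cpu(device):
    """Return True for the string "cpu" or for an object whose .type is "cpu"."""
    if isinstance(device, str):
        return device == "cpu"
    # Assumption: torch.device-like objects expose the device kind as `.type`.
    return getattr(device, "type", None) == "cpu"


def devices_match(device_a, device_b):
    """Treat two devices as compatible if they are equal or both refer to the CPU."""
    if device_a == device_b:
        return True
    return _is_cpu(device_a) and _is_cpu(device_b)


class FakeTorchDevice:
    """Hypothetical stand-in for torch.device, used only for this illustration."""

    def __init__(self, type_):
        self.type = type_

    def __repr__(self):
        return self.type


mixed_cpu = devices_match("cpu", FakeTorchDevice("cpu"))
cpu_vs_gpu = devices_match("cpu", FakeTorchDevice("cuda"))
print(mixed_cpu, cpu_vs_gpu)
```

With a real `torch.device("cpu")` in place of the stand-in, the same attribute lookup would apply, since `torch.device` exposes a `type` attribute.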
Looks like there is already a bug report and a PR to fix it in scikit-learn:
scikit-learn/scikit-learn#29107 (comment)
scikit-learn/scikit-learn#30454
..when sklearn > 1.6 is released.
Please welcome skorch 1.1.0 - a smaller release with a few fixes, a new notebook showcasing learning rate schedulers and mainly support for scikit-learn 1.6.0.

Full list of changes:

### Added

- Added a [notebook](https://github.com/skorch-dev/skorch/blob/master/notebooks/Learning_Rate_Scheduler.ipynb) that shows how to use Learning Rate Scheduler in skorch. (#1074)

### Changed

- All neural net classes now inherit from sklearn's [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). This is to support compatibility with sklearn 1.6.0 and above. Classification models additionally inherit from [`ClassifierMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.ClassifierMixin.html) and regressors from [`RegressorMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.RegressorMixin.html). (#1078)
- When using the `ReduceLROnPlateau` learning rate scheduler, we now record the learning rate in the net history (`net.history[:, 'event_lr']` by default). It is now also possible to step per batch, not only by epoch. (#1075)
- The learning rate scheduler `.simulate()` method now supports adding step args, which is useful when simulating policies such as `ReduceLROnPlateau` that expect metrics to base their schedule on. (#1077)
- Removed deprecated `skorch.callbacks.scoring.cache_net_infer` (#1088)

### Fixed

- Fix an issue with using `NeuralNetBinaryClassifier` with `torch.compile` (#1058)