|
39 | 39 | from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose, inverse_onehot,
|
40 | 40 | broadcast_unit_treatments, reshape_treatmentwise_effects,
|
41 | 41 | StatsModelsLinearRegression, LassoCVWrapper, check_high_dimensional)
|
42 |
| -from econml.sklearn_extensions.linear_model import MultiOutputDebiasedLasso |
| 42 | +from econml.sklearn_extensions.linear_model import MultiOutputDebiasedLasso, WeightedLassoCVWrapper |
43 | 43 | from sklearn.model_selection import KFold, StratifiedKFold, check_cv
|
44 |
| -from sklearn.linear_model import LinearRegression, LassoCV, ElasticNetCV |
| 44 | +from sklearn.linear_model import LinearRegression, LassoCV, LogisticRegressionCV, ElasticNetCV |
45 | 45 | from sklearn.preprocessing import (PolynomialFeatures, LabelEncoder, OneHotEncoder,
|
46 | 46 | FunctionTransformer)
|
47 | 47 | from sklearn.base import clone, TransformerMixin
|
|
52 | 52 | DebiasedLassoCateEstimatorMixin)
|
53 | 53 | from .inference import StatsModelsInference
|
54 | 54 | from ._rlearner import _RLearner
|
| 55 | +from .sklearn_extensions.model_selection import WeightedStratifiedKFold |
55 | 56 |
|
56 | 57 |
|
57 | 58 | class DMLCateEstimator(_RLearner):
|
@@ -116,9 +117,15 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
|
116 | 117 | The estimator for fitting the response to the features. Must implement
|
117 | 118 | `fit` and `predict` methods. Must be a linear model for correctness when linear_first_stages is ``True``.
|
118 | 119 |
|
119 |
| - model_t: estimator |
120 |
| - The estimator for fitting the treatment to the features. Must implement |
121 |
| - `fit` and `predict` methods. Must be a linear model for correctness when linear_first_stages is ``True``. |
| 120 | + model_t: estimator or 'auto' (default is 'auto') |
| 121 | + The estimator for fitting the treatment to the features. |
| 122 | + If estimator, it must implement `fit` and `predict` methods. Must be a linear model for correctness |
| 123 | + when linear_first_stages is ``True``; |
| 124 | + If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>` |
| 125 | + will be applied for discrete treatment, |
| 126 | + and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/ |
| 127 | + :class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>` |
| 128 | + will be applied for continuous treatment. |
122 | 129 |
|
123 | 130 | model_final: estimator
|
124 | 131 | The estimator for fitting the response residuals to the treatment residuals. Must implement
|
@@ -170,6 +177,12 @@ def __init__(self,
|
170 | 177 | # TODO: consider whether we need more care around stateful featurizers,
|
171 | 178 | # since we clone it and fit separate copies
|
172 | 179 |
|
| 180 | + if model_t == 'auto': |
| 181 | + if discrete_treatment: |
| 182 | + model_t = LogisticRegressionCV(cv=WeightedStratifiedKFold()) |
| 183 | + else: |
| 184 | + model_t = WeightedLassoCVWrapper() |
| 185 | + |
173 | 186 | class FirstStageWrapper:
|
174 | 187 | def __init__(self, model, is_Y):
|
175 | 188 | self._model = clone(model, safe=False)
|
@@ -284,13 +297,19 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
|
284 | 297 |
|
285 | 298 | Parameters
|
286 | 299 | ----------
|
287 |
| - model_y: estimator |
| 300 | + model_y: estimator, optional (default is :class:`WeightedLassoCVWrapper() |
| 301 | + <econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper>`) |
288 | 302 | The estimator for fitting the response to the features. Must implement
|
289 | 303 | `fit` and `predict` methods.
|
290 | 304 |
|
291 |
| - model_t: estimator |
292 |
| - The estimator for fitting the treatment to the features. Must implement |
293 |
| - `fit` and `predict` methods. |
| 305 | + model_t: estimator or 'auto', optional (default is 'auto') |
| 306 | + The estimator for fitting the treatment to the features. |
| 307 | + If estimator, it must implement `fit` and `predict` methods; |
| 308 | + If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>` |
| 309 | + will be applied for discrete treatment, |
| 310 | + and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/ |
| 311 | + :class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>` |
| 312 | + will be applied for continuous treatment. |
294 | 313 |
|
295 | 314 | featurizer: transformer, optional (default is \
|
296 | 315 | :class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
|
@@ -329,7 +348,7 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
|
329 | 348 | """
|
330 | 349 |
|
331 | 350 | def __init__(self,
|
332 |
| - model_y=LassoCV(), model_t=LassoCV(), |
| 351 | + model_y=WeightedLassoCVWrapper(), model_t='auto', |
333 | 352 | featurizer=PolynomialFeatures(degree=1, include_bias=True),
|
334 | 353 | linear_first_stages=True,
|
335 | 354 | discrete_treatment=False,
|
@@ -389,13 +408,20 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstim
|
389 | 408 |
|
390 | 409 | Parameters
|
391 | 410 | ----------
|
392 |
| - model_y: estimator |
| 411 | + model_y: estimator, optional (default is :class:`WeightedLassoCVWrapper() |
| 412 | + <econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper>`) |
393 | 413 | The estimator for fitting the response to the features. Must implement
|
394 | 414 | `fit` and `predict` methods.
|
395 | 415 |
|
396 |
| - model_t: estimator |
397 |
| - The estimator for fitting the treatment to the features. Must implement |
398 |
| - `fit` and `predict` methods, and must be a linear model for correctness. |
| 416 | + model_t: estimator or 'auto', optional (default is 'auto') |
| 417 | + The estimator for fitting the treatment to the features. |
| 418 | + If estimator, it must implement `fit` and `predict` methods, and must be a |
| 419 | + linear model for correctness; |
| 420 | + If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>` |
| 421 | + will be applied for discrete treatment, |
| 422 | + and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/ |
| 423 | + :class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>` |
| 424 | + will be applied for continuous treatment. |
399 | 425 |
|
400 | 426 | alpha: string | float, optional. Default='auto'.
|
401 | 427 | CATE L1 regularization applied through the debiased lasso in the final model.
|
@@ -446,7 +472,7 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstim
|
446 | 472 | """
|
447 | 473 |
|
448 | 474 | def __init__(self,
|
449 |
| - model_y=LassoCV(), model_t=LassoCV(), |
| 475 | + model_y=WeightedLassoCVWrapper(), model_t='auto', |
450 | 476 | alpha='auto',
|
451 | 477 | max_iter=1000,
|
452 | 478 | tol=1e-4,
|
@@ -511,13 +537,18 @@ class KernelDMLCateEstimator(DMLCateEstimator):
|
511 | 537 |
|
512 | 538 | Parameters
|
513 | 539 | ----------
|
514 |
| - model_y: estimator, optional (default is :class:`LassoCV() <sklearn.linear_model.LassoCV>`) |
| 540 | + model_y: estimator, optional (default is :class:`WeightedLassoCVWrapper() <econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper>`) |
515 | 541 | The estimator for fitting the response to the features. Must implement
|
516 | 542 | `fit` and `predict` methods.
|
517 | 543 |
|
518 |
| - model_t: estimator, optional (default is :class:`LassoCV() <sklearn.linear_model.LassoCV>`) |
519 |
| - The estimator for fitting the treatment to the features. Must implement |
520 |
| - `fit` and `predict` methods. |
| 544 | + model_t: estimator or 'auto', optional (default is 'auto') |
| 545 | + The estimator for fitting the treatment to the features. |
| 546 | + If estimator, it must implement `fit` and `predict` methods; |
| 547 | + If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>` |
| 548 | + will be applied for discrete treatment, |
| 549 | + and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/ |
| 550 | + :class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>` |
| 551 | + will be applied for continuous treatment. |
521 | 552 |
|
522 | 553 | dim: int, optional (default is 20)
|
523 | 554 | The number of random Fourier features to generate
|
@@ -551,7 +582,7 @@ class KernelDMLCateEstimator(DMLCateEstimator):
|
551 | 582 | by :mod:`np.random<numpy.random>`.
|
552 | 583 | """
|
553 | 584 |
|
554 |
| - def __init__(self, model_y=LassoCV(), model_t=LassoCV(), |
| 585 | + def __init__(self, model_y=WeightedLassoCVWrapper(), model_t='auto', |
555 | 586 | dim=20, bw=1.0, discrete_treatment=False, n_splits=2, random_state=None):
|
556 | 587 | class RandomFeatures(TransformerMixin):
|
557 | 588 | def __init__(self, random_state):
|
|
0 commit comments