Skip to content

Commit 818c832

Browse files
heimengqiMiruna Oprescu
authored and
Miruna Oprescu
committed
automate the first stage model T and update DML notebook (#172)
* automate the first stage model T and update DML notebook * Changed model defaults in ORF and fixed a bug in WeightedKFold
1 parent 607e0ea commit 818c832

8 files changed

+196
-125
lines changed

econml/dml.py

+51-20
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose, inverse_onehot,
4040
broadcast_unit_treatments, reshape_treatmentwise_effects,
4141
StatsModelsLinearRegression, LassoCVWrapper, check_high_dimensional)
42-
from econml.sklearn_extensions.linear_model import MultiOutputDebiasedLasso
42+
from econml.sklearn_extensions.linear_model import MultiOutputDebiasedLasso, WeightedLassoCVWrapper
4343
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
44-
from sklearn.linear_model import LinearRegression, LassoCV, ElasticNetCV
44+
from sklearn.linear_model import LinearRegression, LassoCV, LogisticRegressionCV, ElasticNetCV
4545
from sklearn.preprocessing import (PolynomialFeatures, LabelEncoder, OneHotEncoder,
4646
FunctionTransformer)
4747
from sklearn.base import clone, TransformerMixin
@@ -52,6 +52,7 @@
5252
DebiasedLassoCateEstimatorMixin)
5353
from .inference import StatsModelsInference
5454
from ._rlearner import _RLearner
55+
from .sklearn_extensions.model_selection import WeightedStratifiedKFold
5556

5657

5758
class DMLCateEstimator(_RLearner):
@@ -116,9 +117,15 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
116117
The estimator for fitting the response to the features. Must implement
117118
`fit` and `predict` methods. Must be a linear model for correctness when linear_first_stages is ``True``.
118119
119-
model_t: estimator
120-
The estimator for fitting the treatment to the features. Must implement
121-
`fit` and `predict` methods. Must be a linear model for correctness when linear_first_stages is ``True``.
120+
model_t: estimator or 'auto' (default is 'auto')
121+
The estimator for fitting the treatment to the features.
122+
If estimator, it must implement `fit` and `predict` methods. Must be a linear model for correctness
123+
when linear_first_stages is ``True``;
124+
If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>`
125+
will be applied for discrete treatment,
126+
and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/
127+
:class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>`
128+
will be applied for continuous treatment.
122129
123130
model_final: estimator
124131
The estimator for fitting the response residuals to the treatment residuals. Must implement
@@ -170,6 +177,12 @@ def __init__(self,
170177
# TODO: consider whether we need more care around stateful featurizers,
171178
# since we clone it and fit separate copies
172179

180+
if model_t == 'auto':
181+
if discrete_treatment:
182+
model_t = LogisticRegressionCV(cv=WeightedStratifiedKFold())
183+
else:
184+
model_t = WeightedLassoCVWrapper()
185+
173186
class FirstStageWrapper:
174187
def __init__(self, model, is_Y):
175188
self._model = clone(model, safe=False)
@@ -284,13 +297,19 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
284297
285298
Parameters
286299
----------
287-
model_y: estimator
300+
model_y: estimator, optional (default is :class:`WeightedLassoCVWrapper()
301+
<econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper>`)
288302
The estimator for fitting the response to the features. Must implement
289303
`fit` and `predict` methods.
290304
291-
model_t: estimator
292-
The estimator for fitting the treatment to the features. Must implement
293-
`fit` and `predict` methods.
305+
model_t: estimator or 'auto', optional (default is 'auto')
306+
The estimator for fitting the treatment to the features.
307+
If estimator, it must implement `fit` and `predict` methods;
308+
If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>`
309+
will be applied for discrete treatment,
310+
and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/
311+
:class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>`
312+
will be applied for continuous treatment.
294313
295314
featurizer: transformer, optional (default is \
296315
:class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
@@ -329,7 +348,7 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
329348
"""
330349

331350
def __init__(self,
332-
model_y=LassoCV(), model_t=LassoCV(),
351+
model_y=WeightedLassoCVWrapper(), model_t='auto',
333352
featurizer=PolynomialFeatures(degree=1, include_bias=True),
334353
linear_first_stages=True,
335354
discrete_treatment=False,
@@ -389,13 +408,20 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstim
389408
390409
Parameters
391410
----------
392-
model_y: estimator
411+
model_y: estimator, optional (default is :class:`WeightedLassoCVWrapper()
412+
<econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper>`)
393413
The estimator for fitting the response to the features. Must implement
394414
`fit` and `predict` methods.
395415
396-
model_t: estimator
397-
The estimator for fitting the treatment to the features. Must implement
398-
`fit` and `predict` methods, and must be a linear model for correctness.
416+
model_t: estimator or 'auto', optional (default is 'auto')
417+
The estimator for fitting the treatment to the features.
418+
If estimator, it must implement `fit` and `predict` methods, and must be a
419+
linear model for correctness;
420+
If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>`
421+
will be applied for discrete treatment,
422+
and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/
423+
:class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>`
424+
will be applied for continuous treatment.
399425
400426
alpha: string | float, optional. Default='auto'.
401427
CATE L1 regularization applied through the debiased lasso in the final model.
@@ -446,7 +472,7 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstim
446472
"""
447473

448474
def __init__(self,
449-
model_y=LassoCV(), model_t=LassoCV(),
475+
model_y=WeightedLassoCVWrapper(), model_t='auto',
450476
alpha='auto',
451477
max_iter=1000,
452478
tol=1e-4,
@@ -511,13 +537,18 @@ class KernelDMLCateEstimator(DMLCateEstimator):
511537
512538
Parameters
513539
----------
514-
model_y: estimator, optional (default is :class:`LassoCV() <sklearn.linear_model.LassoCV>`)
540+
model_y: estimator, optional (default is :class:`<econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper>`)
515541
The estimator for fitting the response to the features. Must implement
516542
`fit` and `predict` methods.
517543
518-
model_t: estimator, optional (default is :class:`LassoCV() <sklearn.linear_model.LassoCV>`)
519-
The estimator for fitting the treatment to the features. Must implement
520-
`fit` and `predict` methods.
544+
model_t: estimator or 'auto', optional (default is 'auto')
545+
The estimator for fitting the treatment to the features.
546+
If estimator, it must implement `fit` and `predict` methods;
547+
If 'auto', :class:`LogisticRegressionCV() <sklearn.linear_model.LogisticRegressionCV>`
548+
will be applied for discrete treatment,
549+
and :class:`WeightedLassoCV() <econml.sklearn_extensions.linear_model.WeightedLassoCV>`/
550+
:class:`WeightedMultitaskLassoCV() <econml.sklearn_extensions.linear_model.WeightedMultitaskLassoCV>`
551+
will be applied for continuous treatment.
521552
522553
dim: int, optional (default is 20)
523554
The number of random Fourier features to generate
@@ -551,7 +582,7 @@ class KernelDMLCateEstimator(DMLCateEstimator):
551582
by :mod:`np.random<numpy.random>`.
552583
"""
553584

554-
def __init__(self, model_y=LassoCV(), model_t=LassoCV(),
585+
def __init__(self, model_y=WeightedLassoCVWrapper(), model_t='auto',
555586
dim=20, bw=1.0, discrete_treatment=False, n_splits=2, random_state=None):
556587
class RandomFeatures(TransformerMixin):
557588
def __init__(self, random_state):

econml/ortho_forest.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
from sklearn.pipeline import Pipeline
3434
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, PolynomialFeatures, FunctionTransformer
3535
from sklearn.utils import check_random_state, check_array, column_or_1d
36+
from .sklearn_extensions.linear_model import WeightedLassoCVWrapper
3637
from .cate_estimator import BaseCateEstimator, LinearCateEstimator, TreatmentExpansionMixin
3738
from .causal_tree import CausalTree
38-
from .utilities import reshape, reshape_Y_T, MAX_RAND_SEED, check_inputs, WeightedModelWrapper, cross_product
39+
from .utilities import reshape, reshape_Y_T, MAX_RAND_SEED, check_inputs, cross_product
3940

4041

4142
def _build_tree_in_parallel(Y, T, X, W,
@@ -399,8 +400,8 @@ def __init__(self,
399400
subsample_ratio=0.7,
400401
bootstrap=False,
401402
lambda_reg=0.01,
402-
model_T=WeightedModelWrapper(LassoCV(cv=3)),
403-
model_Y=WeightedModelWrapper(LassoCV(cv=3)),
403+
model_T=WeightedLassoCVWrapper(cv=3),
404+
model_Y=WeightedLassoCVWrapper(cv=3),
404405
model_T_final=None,
405406
model_Y_final=None,
406407
n_jobs=-1,
@@ -627,7 +628,7 @@ def __init__(self,
627628
lambda_reg=0.01,
628629
propensity_model=LogisticRegression(penalty='l1', solver='saga',
629630
multi_class='auto'), # saga solver supports l1
630-
model_Y=WeightedModelWrapper(LassoCV(cv=3)),
631+
model_Y=WeightedLassoCVWrapper(cv=3),
631632
propensity_model_final=None,
632633
model_Y_final=None,
633634
n_jobs=-1,

econml/sklearn_extensions/linear_model.py

+38
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections.abc import Iterable
1010
from scipy.stats import norm
1111
from econml.sklearn_extensions.model_selection import WeightedKFold, WeightedStratifiedKFold
12+
from econml.utilities import ndim, shape, reshape
1213
from sklearn.linear_model import LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLasso
1314
from sklearn.model_selection import KFold, StratifiedKFold
1415
from sklearn.model_selection._split import _CVIterableWrapper, CV_WARNING
@@ -1048,3 +1049,40 @@ def _set_attribute(self, attribute_name, condition=True, default=None):
10481049
else:
10491050
attribute_value = default
10501051
setattr(self, attribute_name, attribute_value)
1052+
1053+
1054+
class WeightedLassoCVWrapper:
1055+
"""Helper class to wrap either WeightedLassoCV or WeightedMultiTaskLassoCV depending on the shape of the target."""
1056+
1057+
def __init__(self, *args, **kwargs):
1058+
self.args = args
1059+
self.kwargs = kwargs
1060+
1061+
def fit(self, X, y, sample_weight=None):
1062+
self.needs_unravel = False
1063+
if ndim(y) == 2 and shape(y)[1] > 1:
1064+
self.model = WeightedMultiTaskLassoCV(*self.args, **self.kwargs)
1065+
else:
1066+
if ndim(y) == 2 and shape(y)[1] == 1:
1067+
y = np.ravel(y)
1068+
self.needs_unravel = True
1069+
self.model = WeightedLassoCV(*self.args, **self.kwargs)
1070+
self.model.fit(X, y, sample_weight)
1071+
# set intercept_ attribute
1072+
self.intercept_ = self.model.intercept_
1073+
# set coef_ attribute
1074+
self.coef_ = self.model.coef_
1075+
# set alpha_ attribute
1076+
self.alpha_ = self.model.alpha_
1077+
# set alphas_ attribute
1078+
self.alphas_ = self.model.alphas_
1079+
# set n_iter_ attribute
1080+
self.n_iter_ = self.model.n_iter_
1081+
return self
1082+
1083+
def predict(self, X):
1084+
predictions = self.model.predict(X)
1085+
return reshape(predictions, (-1, 1)) if self.needs_unravel else predictions
1086+
1087+
def score(self, X, y, sample_weight=None):
1088+
return self.model.score(X, y, sample_weight)

econml/sklearn_extensions/model_selection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _split_weighted_sample(self, X, y, sample_weight, is_stratified=False):
3030
return self._get_folds_from_splits(splits, X.shape[0])
3131
# Record all splits in case the stratification by weight yeilds a worse partition
3232
all_splits.append(splits)
33-
max_deviation = np.abs(weight_fracs - 1 / self.n_splits)
33+
max_deviation = np.max(np.abs(weight_fracs - 1 / self.n_splits))
3434
max_deviations.append(max_deviation)
3535
# Reseed random generator and try again
3636
kfold_model.shuffle = True
@@ -57,7 +57,7 @@ def _split_weighted_sample(self, X, y, sample_weight, is_stratified=False):
5757
# Did not find a good split
5858
# Record the devaiation for the weight-stratified split to compare with KFold splits
5959
all_splits.append(stratified_weight_splits)
60-
max_deviation = np.abs(weight_fracs - 1 / self.n_splits)
60+
max_deviation = np.max(np.abs(weight_fracs - 1 / self.n_splits))
6161
max_deviations.append(max_deviation)
6262
# Return most weight-balanced partition
6363
min_deviation_index = np.argmin(max_deviations)

econml/tests/test_dml.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def make_random(is_discrete, d):
7676
all_infs.append(BootstrapInference(1))
7777

7878
for est, multi, infs in [(LinearDMLCateEstimator(model_y=Lasso(),
79-
model_t=model_t,
79+
model_t='auto',
8080
discrete_treatment=is_discrete),
8181
False,
8282
all_infs),
@@ -149,8 +149,8 @@ def test_can_use_vectors(self):
149149
def test_can_use_sample_weights(self):
150150
"""Test that we can pass sample weights to an estimator."""
151151
dmls = [
152-
LinearDMLCateEstimator(LinearRegression(), LinearRegression(), featurizer=FunctionTransformer()),
153-
SparseLinearDMLCateEstimator(LinearRegression(), LinearRegression(), featurizer=FunctionTransformer())
152+
LinearDMLCateEstimator(LinearRegression(), 'auto', featurizer=FunctionTransformer()),
153+
SparseLinearDMLCateEstimator(LinearRegression(), 'auto', featurizer=FunctionTransformer())
154154
]
155155
for dml in dmls:
156156
dml.fit(np.array([1, 2, 3, 1, 2, 3]), np.array([1, 2, 3, 1, 2, 3]),

econml/tests/test_orf.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from sklearn.linear_model import LinearRegression, Lasso, LassoCV, LogisticRegression, LogisticRegressionCV
1111
from sklearn.multioutput import MultiOutputRegressor
1212
from sklearn.pipeline import Pipeline
13-
from econml.ortho_forest import ContinuousTreatmentOrthoForest, DiscreteTreatmentOrthoForest, \
14-
WeightedModelWrapper
13+
from econml.ortho_forest import ContinuousTreatmentOrthoForest, DiscreteTreatmentOrthoForest
14+
from econml.sklearn_extensions.linear_model import WeightedLassoCVWrapper
1515

1616

1717
class TestOrthoForest(unittest.TestCase):
@@ -53,8 +53,8 @@ def test_continuous_treatments(self):
5353
est = ContinuousTreatmentOrthoForest(n_jobs=4, n_trees=10,
5454
model_T=Lasso(),
5555
model_Y=Lasso(),
56-
model_T_final=WeightedModelWrapper(LassoCV(), sample_type="weighted"),
57-
model_Y_final=WeightedModelWrapper(LassoCV(), sample_type="weighted"))
56+
model_T_final=WeightedLassoCVWrapper(),
57+
model_Y_final=WeightedLassoCVWrapper())
5858
# Test inputs for continuous treatments
5959
# --> Check that one can pass in regular lists
6060
est.fit(list(Y), list(T), list(TestOrthoForest.X), list(TestOrthoForest.W))
@@ -69,8 +69,8 @@ def test_continuous_treatments(self):
6969
max_depth=50, subsample_ratio=0.30, bootstrap=False, n_jobs=4,
7070
model_T=Lasso(alpha=0.024),
7171
model_Y=Lasso(alpha=0.024),
72-
model_T_final=WeightedModelWrapper(LassoCV(), sample_type="weighted"),
73-
model_Y_final=WeightedModelWrapper(LassoCV(), sample_type="weighted"))
72+
model_T_final=WeightedLassoCVWrapper(),
73+
model_Y_final=WeightedLassoCVWrapper())
7474
est.fit(Y, T, TestOrthoForest.X, TestOrthoForest.W)
7575
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.5)
7676
# Test continuous treatments without controls
@@ -94,7 +94,7 @@ def test_binary_treatments(self):
9494
est = DiscreteTreatmentOrthoForest(n_trees=10, n_jobs=4,
9595
propensity_model=LogisticRegression(), model_Y=Lasso(),
9696
propensity_model_final=LogisticRegressionCV(penalty='l1', solver='saga'),
97-
model_Y_final=WeightedModelWrapper(LassoCV(), sample_type="weighted"))
97+
model_Y_final=WeightedLassoCVWrapper())
9898
# Test inputs for binary treatments
9999
# --> Check that one can pass in regular lists
100100
est.fit(list(Y), list(T), list(TestOrthoForest.X), list(TestOrthoForest.W))
@@ -118,7 +118,7 @@ def test_binary_treatments(self):
118118
propensity_model=LogisticRegression(C=1 / 0.024, penalty='l1'),
119119
model_Y=Lasso(alpha=0.024),
120120
propensity_model_final=LogisticRegressionCV(penalty='l1', solver='saga'),
121-
model_Y_final=WeightedModelWrapper(LassoCV(), sample_type="weighted"))
121+
model_Y_final=WeightedLassoCVWrapper())
122122
est.fit(Y, T, TestOrthoForest.X, TestOrthoForest.W)
123123
self._test_te(est, TestOrthoForest.expected_exp_te, tol=0.7, treatment_type='discrete')
124124
# Test binary treatments without controls
@@ -146,9 +146,8 @@ def test_multiple_treatments(self):
146146
max_depth=50, subsample_ratio=0.30, bootstrap=False, n_jobs=4,
147147
model_T=MultiOutputRegressor(Lasso(alpha=0.024)),
148148
model_Y=Lasso(alpha=0.024),
149-
model_T_final=WeightedModelWrapper(
150-
MultiOutputRegressor(LassoCV()), sample_type="weighted"),
151-
model_Y_final=WeightedModelWrapper(LassoCV(), sample_type="weighted"))
149+
model_T_final=WeightedLassoCVWrapper(),
150+
model_Y_final=WeightedLassoCVWrapper())
152151
est.fit(Y, T, TestOrthoForest.X, TestOrthoForest.W)
153152
expected_te = np.array([TestOrthoForest.expected_exp_te, TestOrthoForest.expected_const_te]).T
154153
self._test_te(est, expected_te, tol=0.5, treatment_type='multi')

0 commit comments

Comments
 (0)