Skip to content

Commit 8c28f52

Browse files
committed
Add fit_cate_intercept to DML, rework feature generation
1 parent 57327b4 commit 8c28f52

File tree

2 files changed

+66
-47
lines changed

2 files changed

+66
-47
lines changed

econml/dml.py

+54-35
Original file line numberDiff line numberDiff line change
@@ -124,18 +124,19 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
124124
The estimator for fitting the response residuals to the treatment residuals. Must implement
125125
`fit` and `predict` methods, and must be a linear model for correctness.
126126
127-
featurizer: transformer
128-
The transformer used to featurize the raw features when fitting the final model. Must implement
129-
a `fit_transform` method.
127+
featurizer: :term:`transformer`, optional, default None
128+
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
129+
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
130+
If featurizer=None, then CATE is trained on X.
130131
131132
linear_first_stages: bool
132133
Whether the first stage models are linear (in which case we will expand the features passed to
133134
`model_y` accordingly)
134135
135-
discrete_treatment: bool, optional (default is ``False``)
136+
discrete_treatment: bool, optional, default False
136137
Whether the treatment values should be treated as categorical, rather than continuous, quantities
137138
138-
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
139+
n_splits: int, cross-validation generator or an iterable, optional, default 2
139140
Determines the cross-validation splitting strategy.
140141
Possible inputs for cv are:
141142
@@ -161,7 +162,7 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
161162

162163
def __init__(self,
163164
model_y, model_t, model_final,
164-
featurizer,
165+
featurizer=None,
165166
linear_first_stages=False,
166167
discrete_treatment=False,
167168
n_splits=2,
@@ -177,22 +178,23 @@ def __init__(self, model, is_Y):
177178
self._is_Y = is_Y
178179

179180
def _combine(self, X, W, n_samples, fitting=True):
181+
no_x = X is None
182+
if no_x:
183+
X = np.ones((n_samples, 1))
184+
if W is None:
185+
W = np.empty((n_samples, 0))
186+
XW = hstack([X, W])
180187
if self._is_Y and linear_first_stages:
181-
if X is not None:
182-
F = self._featurizer.fit_transform(X) if fitting else self._featurizer.transform(X)
188+
if no_x:
189+
return XW
190+
191+
if self._featurizer is None:
192+
F = X
183193
else:
184-
X = np.ones((n_samples, 1))
185-
F = np.ones((n_samples, 1))
186-
if W is None:
187-
W = np.empty((n_samples, 0))
188-
XW = hstack([X, W])
189-
return cross_product(XW, hstack([np.ones((shape(XW)[0], 1)), F, W]))
194+
F = self._featurizer.fit_transform(X) if fitting else self._featurizer.transform(X)
195+
return cross_product(XW, hstack([np.ones((shape(XW)[0], 1)), F]))
190196
else:
191-
if X is None:
192-
X = np.ones((n_samples, 1))
193-
if W is None:
194-
W = np.empty((n_samples, 0))
195-
return hstack([X, W])
197+
return XW
196198

197199
def fit(self, X, W, Target, sample_weight=None):
198200
if (not self._is_Y) and discrete_treatment:
@@ -221,12 +223,21 @@ def __init__(self):
221223
self._model = clone(model_final, safe=False)
222224
self._featurizer = clone(featurizer, safe=False)
223225

226+
def _combine(self, X, T, fitting=True):
227+
if X is not None:
228+
if self._featurizer is not None:
229+
F = self._featurizer.fit_transform(X) if fitting else self._featurizer.transform(X)
230+
else:
231+
F = X
232+
else:
233+
F = np.ones((T.shape[0], 1))
234+
return cross_product(F, T)
235+
224236
def fit(self, X, T_res, Y_res, sample_weight=None, sample_var=None):
225237
# Track training dimensions to see if Y or T is a vector instead of a 2-dimensional array
226238
self._d_t = shape(T_res)[1:]
227239
self._d_y = shape(Y_res)[1:]
228-
F = self._featurizer.fit_transform(X) if X is not None else np.ones((T_res.shape[0], 1))
229-
fts = cross_product(F, T_res)
240+
fts = self._combine(X, T_res)
230241
if sample_weight is not None:
231242
if sample_var is not None:
232243
self._model.fit(fts,
@@ -246,9 +257,9 @@ def fit(self, X, T_res, Y_res, sample_weight=None, sample_var=None):
246257
self._intercept = intercept
247258

248259
def predict(self, X):
249-
F = self._featurizer.transform(X) if X is not None else np.ones((1, 1))
250-
F, T = broadcast_unit_treatments(F, self._d_t[0] if self._d_t else 1)
251-
prediction = self._model.predict(cross_product(F, T))
260+
X2, T = broadcast_unit_treatments(X if X is not None else np.empty((1, 0)),
261+
self._d_t[0] if self._d_t else 1)
262+
prediction = self._model.predict(self._combine(None if X is None else X2, T, fitting=False))
252263
if self._intercept is not None:
253264
prediction -= self._intercept
254265
return reshape_treatmentwise_effects(prediction,
@@ -292,10 +303,13 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
292303
The estimator for fitting the treatment to the features. Must implement
293304
`fit` and `predict` methods.
294305
295-
featurizer: transformer, optional (default is \
296-
:class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
297-
The transformer used to featurize the raw features when fitting the final model. Must implement
298-
a `fit_transform` method.
306+
featurizer : :term:`transformer`, optional, default None
307+
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
308+
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
309+
If featurizer=None, then CATE is trained on X.
310+
311+
fit_cate_intercept : bool, optional, default True
312+
Whether the linear CATE model should have a constant term.
299313
300314
linear_first_stages: bool
301315
Whether the first stage models are linear (in which case we will expand the features passed to
@@ -330,14 +344,15 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
330344

331345
def __init__(self,
332346
model_y=LassoCV(), model_t=LassoCV(),
333-
featurizer=PolynomialFeatures(degree=1, include_bias=True),
347+
featurizer=None,
348+
fit_cate_intercept=True,
334349
linear_first_stages=True,
335350
discrete_treatment=False,
336351
n_splits=2,
337352
random_state=None):
338353
super().__init__(model_y=model_y,
339354
model_t=model_t,
340-
model_final=StatsModelsLinearRegression(fit_intercept=False),
355+
model_final=StatsModelsLinearRegression(fit_intercept=fit_cate_intercept),
341356
featurizer=featurizer,
342357
linear_first_stages=linear_first_stages,
343358
discrete_treatment=discrete_treatment,
@@ -410,10 +425,13 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstim
410425
dual gap for optimality and continues until it is smaller
411426
than ``tol``.
412427
413-
featurizer: transformer, optional
414-
(default is :class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
415-
The transformer used to featurize the raw features when fitting the final model. Must implement
416-
a `fit_transform` method.
428+
featurizer : :term:`transformer`, optional, default None
429+
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
430+
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
431+
If featurizer=None, then CATE is trained on X.
432+
433+
fit_cate_intercept : bool, optional, default True
434+
Whether the linear CATE model should have a constant term.
417435
418436
linear_first_stages: bool
419437
Whether the first stage models are linear (in which case we will expand the features passed to
@@ -450,7 +468,8 @@ def __init__(self,
450468
alpha='auto',
451469
max_iter=1000,
452470
tol=1e-4,
453-
featurizer=PolynomialFeatures(degree=1, include_bias=True),
471+
featurizer=None,
472+
fit_cate_intercept=True,
454473
linear_first_stages=True,
455474
discrete_treatment=False,
456475
n_splits=2,

econml/drlearner.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@ class takes as input the parameter `model_regressor``, which is an arbitrary sci
114114
mono-task model and a separate clone of the model is trained for each outcome. Then predict(X) of the t-th
115115
clone will be the CATE of the t-th lexicographically ordered treatment compared to the baseline.
116116
117-
multitask_model_final : optional bool (default=False)
117+
multitask_model_final : bool, optional, default False
118118
Whether the model_final should be treated as a multi-task model. See description of model_final.
119119
120-
featurizer : sklearn featurizer or None
120+
featurizer : :term:`transformer`, optional, default None
121121
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
122122
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
123123
If featurizer=None, then CATE is trained on X.
124124
125-
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
125+
n_splits: int, cross-validation generator or an iterable, optional, default 2
126126
Determines the cross-validation splitting strategy.
127127
Possible inputs for cv are:
128128
@@ -535,15 +535,15 @@ class LinearDRLearner(StatsModelsCateEstimatorDiscreteMixin, DRLearner):
535535
`predict` methods. If different models per treatment arm are desired, see the
536536
:class:`~econml.utilities.MultiModelWrapper` helper class.
537537
538-
featurizer : sklearn featurizer or None
538+
featurizer : :term:`transformer`, optional, default None
539539
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
540540
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
541541
If featurizer=None, then CATE is trained on X.
542542
543-
fit_cate_intercept : bool, optional (Default=True)
543+
fit_cate_intercept : bool, optional, default True
544544
Whether the linear CATE model should have a constant term.
545545
546-
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
546+
n_splits: int, cross-validation generator or an iterable, optional, default 2
547547
Determines the cross-validation splitting strategy.
548548
Possible inputs for cv are:
549549
@@ -711,28 +711,28 @@ class SparseLinearDRLearner(DebiasedLassoCateEstimatorDiscreteMixin, DRLearner):
711711
`predict` methods. If different models per treatment arm are desired, see the
712712
:class:`~econml.utilities.MultiModelWrapper` helper class.
713713
714-
featurizer : sklearn featurizer or None
714+
featurizer : :term:`transformer`, optional, default None
715715
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
716716
It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
717717
If featurizer=None, then CATE is trained on X.
718718
719-
fit_cate_intercept : bool, optional (Default=True)
719+
fit_cate_intercept : bool, optional, default True
720720
Whether the linear CATE model should have a constant term.
721721
722-
alpha: string | float, optional. Default='auto'.
722+
alpha: string | float, optional, default 'auto'
723723
CATE L1 regularization applied through the debiased lasso in the final model.
724724
'auto' corresponds to a CV form of the :class:`DebiasedLasso`.
725725
726-
max_iter : int, optional, default=1000
726+
max_iter : int, optional, default 1000
727727
The maximum number of iterations in the Debiased Lasso
728728
729-
tol : float, optional, default=1e-4
729+
tol : float, optional, default 1e-4
730730
The tolerance for the optimization: if the updates are
731731
smaller than ``tol``, the optimization code checks the
732732
dual gap for optimality and continues until it is smaller
733733
than ``tol``.
734734
735-
n_splits: int, cross-validation generator or an iterable, optional (Default=2)
735+
n_splits: int, cross-validation generator or an iterable, optional, default 2
736736
Determines the cross-validation splitting strategy.
737737
Possible inputs for cv are:
738738

0 commit comments

Comments
 (0)