@@ -124,18 +124,19 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
124
124
The estimator for fitting the response residuals to the treatment residuals. Must implement
125
125
`fit` and `predict` methods, and must be a linear model for correctness.
126
126
127
- featurizer: transformer
128
- The transformer used to featurize the raw features when fitting the final model. Must implement
129
- a `fit_transform` method.
127
+ featurizer: :term:`transformer`, optional, default None
128
+ Must support fit_transform and transform. Used to create composite features in the final CATE regression.
129
+ It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
130
+ If featurizer=None, then CATE is trained on X.
130
131
131
132
linear_first_stages: bool
132
133
Whether the first stage models are linear (in which case we will expand the features passed to
133
134
`model_y` accordingly)
134
135
135
- discrete_treatment: bool, optional ( default is `` False``)
136
+ discrete_treatment: bool, optional, default False
136
137
Whether the treatment values should be treated as categorical, rather than continuous, quantities
137
138
138
- n_splits: int, cross-validation generator or an iterable, optional (Default=2)
139
+ n_splits: int, cross-validation generator or an iterable, optional, default 2
139
140
Determines the cross-validation splitting strategy.
140
141
Possible inputs for cv are:
141
142
@@ -161,7 +162,7 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
161
162
162
163
def __init__ (self ,
163
164
model_y , model_t , model_final ,
164
- featurizer ,
165
+ featurizer = None ,
165
166
linear_first_stages = False ,
166
167
discrete_treatment = False ,
167
168
n_splits = 2 ,
@@ -177,22 +178,23 @@ def __init__(self, model, is_Y):
177
178
self ._is_Y = is_Y
178
179
179
180
def _combine (self , X , W , n_samples , fitting = True ):
181
+ no_x = X is None
182
+ if no_x :
183
+ X = np .ones ((n_samples , 1 ))
184
+ if W is None :
185
+ W = np .empty ((n_samples , 0 ))
186
+ XW = hstack ([X , W ])
180
187
if self ._is_Y and linear_first_stages :
181
- if X is not None :
182
- F = self ._featurizer .fit_transform (X ) if fitting else self ._featurizer .transform (X )
188
+ if no_x :
189
+ return XW
190
+
191
+ if self ._featurizer is None :
192
+ F = X
183
193
else :
184
- X = np .ones ((n_samples , 1 ))
185
- F = np .ones ((n_samples , 1 ))
186
- if W is None :
187
- W = np .empty ((n_samples , 0 ))
188
- XW = hstack ([X , W ])
189
- return cross_product (XW , hstack ([np .ones ((shape (XW )[0 ], 1 )), F , W ]))
194
+ F = self ._featurizer .fit_transform (X ) if fitting else self ._featurizer .transform (X )
195
+ return cross_product (XW , hstack ([np .ones ((shape (XW )[0 ], 1 )), F ]))
190
196
else :
191
- if X is None :
192
- X = np .ones ((n_samples , 1 ))
193
- if W is None :
194
- W = np .empty ((n_samples , 0 ))
195
- return hstack ([X , W ])
197
+ return XW
196
198
197
199
def fit (self , X , W , Target , sample_weight = None ):
198
200
if (not self ._is_Y ) and discrete_treatment :
@@ -221,12 +223,21 @@ def __init__(self):
221
223
self ._model = clone (model_final , safe = False )
222
224
self ._featurizer = clone (featurizer , safe = False )
223
225
226
+ def _combine (self , X , T , fitting = True ):
227
+ if X is not None :
228
+ if self ._featurizer is not None :
229
+ F = self ._featurizer .fit_transform (X ) if fitting else self ._featurizer .transform (X )
230
+ else :
231
+ F = X
232
+ else :
233
+ F = np .ones ((T .shape [0 ], 1 ))
234
+ return cross_product (F , T )
235
+
224
236
def fit (self , X , T_res , Y_res , sample_weight = None , sample_var = None ):
225
237
# Track training dimensions to see if Y or T is a vector instead of a 2-dimensional array
226
238
self ._d_t = shape (T_res )[1 :]
227
239
self ._d_y = shape (Y_res )[1 :]
228
- F = self ._featurizer .fit_transform (X ) if X is not None else np .ones ((T_res .shape [0 ], 1 ))
229
- fts = cross_product (F , T_res )
240
+ fts = self ._combine (X , T_res )
230
241
if sample_weight is not None :
231
242
if sample_var is not None :
232
243
self ._model .fit (fts ,
@@ -246,9 +257,9 @@ def fit(self, X, T_res, Y_res, sample_weight=None, sample_var=None):
246
257
self ._intercept = intercept
247
258
248
259
def predict (self , X ):
249
- F = self . _featurizer . transform ( X ) if X is not None else np .ones ((1 , 1 ))
250
- F , T = broadcast_unit_treatments ( F , self ._d_t [0 ] if self ._d_t else 1 )
251
- prediction = self ._model .predict (cross_product ( F , T ))
260
+ X2 , T = broadcast_unit_treatments ( X if X is not None else np .empty ((1 , 0 )),
261
+ self ._d_t [0 ] if self ._d_t else 1 )
262
+ prediction = self ._model .predict (self . _combine ( None if X is None else X2 , T , fitting = False ))
252
263
if self ._intercept is not None :
253
264
prediction -= self ._intercept
254
265
return reshape_treatmentwise_effects (prediction ,
@@ -292,10 +303,13 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
292
303
The estimator for fitting the treatment to the features. Must implement
293
304
`fit` and `predict` methods.
294
305
295
- featurizer: transformer, optional (default is \
296
- :class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
297
- The transformer used to featurize the raw features when fitting the final model. Must implement
298
- a `fit_transform` method.
306
+ featurizer: :term:`transformer`, optional, default None
307
+ Must support fit_transform and transform. Used to create composite features in the final CATE regression.
308
+ It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
309
+ If featurizer=None, then CATE is trained on X.
310
+
311
+ fit_cate_intercept: bool, optional, default True
312
+ Whether the linear CATE model should have a constant term.
299
313
300
314
linear_first_stages: bool
301
315
Whether the first stage models are linear (in which case we will expand the features passed to
@@ -330,14 +344,15 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
330
344
331
345
def __init__ (self ,
332
346
model_y = LassoCV (), model_t = LassoCV (),
333
- featurizer = PolynomialFeatures (degree = 1 , include_bias = True ),
347
+ featurizer = None ,
348
+ fit_cate_intercept = True ,
334
349
linear_first_stages = True ,
335
350
discrete_treatment = False ,
336
351
n_splits = 2 ,
337
352
random_state = None ):
338
353
super ().__init__ (model_y = model_y ,
339
354
model_t = model_t ,
340
- model_final = StatsModelsLinearRegression (fit_intercept = False ),
355
+ model_final = StatsModelsLinearRegression (fit_intercept = fit_cate_intercept ),
341
356
featurizer = featurizer ,
342
357
linear_first_stages = linear_first_stages ,
343
358
discrete_treatment = discrete_treatment ,
@@ -410,10 +425,13 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstim
410
425
dual gap for optimality and continues until it is smaller
411
426
than ``tol``.
412
427
413
- featurizer: transformer, optional
414
- (default is :class:`PolynomialFeatures(degree=1, include_bias=True) <sklearn.preprocessing.PolynomialFeatures>`)
415
- The transformer used to featurize the raw features when fitting the final model. Must implement
416
- a `fit_transform` method.
428
+ featurizer: :term:`transformer`, optional, default None
429
+ Must support fit_transform and transform. Used to create composite features in the final CATE regression.
430
+ It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X).
431
+ If featurizer=None, then CATE is trained on X.
432
+
433
+ fit_cate_intercept: bool, optional, default True
434
+ Whether the linear CATE model should have a constant term.
417
435
418
436
linear_first_stages: bool
419
437
Whether the first stage models are linear (in which case we will expand the features passed to
@@ -450,7 +468,8 @@ def __init__(self,
450
468
alpha = 'auto' ,
451
469
max_iter = 1000 ,
452
470
tol = 1e-4 ,
453
- featurizer = PolynomialFeatures (degree = 1 , include_bias = True ),
471
+ featurizer = None ,
472
+ fit_cate_intercept = True ,
454
473
linear_first_stages = True ,
455
474
discrete_treatment = False ,
456
475
n_splits = 2 ,
0 commit comments