Skip to content

Commit 9b6c939

Browse files
authored
Add fit_cate_intercept to DML, rework feature generation (#174)
Add fit_cate_intercept to DML, rework feature generation
1 parent 818c832 commit 9b6c939

9 files changed

+380
-114
lines changed

econml/cate_estimator.py

+40-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from warnings import warn
1111
from .bootstrap import BootstrapEstimator
1212
from .inference import BootstrapInference
13-
from .utilities import tensordot, ndim, reshape, shape
13+
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params
1414
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
1515
LinearModelFinalInferenceDiscrete
1616

@@ -383,7 +383,20 @@ def effect(self, X=None, *, T0=0, T1=1):
383383

384384

385385
class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
386-
"""Base class for models where the final stage is a linear model."""
386+
"""
387+
Base class for models where the final stage is a linear model.
388+
389+
Subclasses must expose a ``model_final`` attribute containing the model's
390+
final stage model.
391+
392+
Attributes
393+
----------
394+
bias_part_of_coef: bool
395+
Whether the CATE model's intercept is contained in the final model's ``coef_`` rather
396+
than as a separate ``intercept_``
397+
"""
398+
399+
bias_part_of_coef = False
387400

388401
@property
389402
def coef_(self):
@@ -392,15 +405,17 @@ def coef_(self):
392405
393406
Returns
394407
-------
395-
coef: (n_x * n_t,) or (n_y, n_x * n_t) array like
408+
coef: (n_x,) or (n_t, n_x) or (n_y, n_t, n_x) array like
396409
Where n_x is the number of features that enter the final model (either the
397410
dimension of X or the dimension of featurizer.fit_transform(X) if the CATE
398411
estimator has a featurizer.), n_t is the number of treatments, n_y is
399-
the number of outcomes. The coefficient is flattened in a manner that
400-
the first block of n_x columns are the coefficients associated with treatment 0,
401-
the next n_x columns are the coefficients associated with treatment 1 etc.
412+
the number of outcomes. Dimensions are omitted if the original input was
413+
a vector and not a 2D array. For binary treatment the n_t dimension is
414+
also omitted.
402415
"""
403-
return self.model_final.coef_
416+
return parse_final_model_params(self.model_final.coef_, self.model_final.intercept_,
417+
self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef,
418+
self.fit_cate_intercept)[0]
404419

405420
@property
406421
def intercept_(self):
@@ -409,9 +424,17 @@ def intercept_(self):
409424
410425
Returns
411426
-------
412-
intercept: float or (n_y,) array like
427+
intercept: float or (n_y,) or (n_y, n_t) array like
428+
Where n_t is the number of treatments, n_y is
429+
the number of outcomes. Dimensions are omitted if the original input was
430+
a vector and not a 2D array. For binary treatment the n_t dimension is
431+
also omitted.
413432
"""
414-
return self.model_final.intercept_
433+
if not self.fit_cate_intercept:
434+
raise AttributeError("No intercept was fitted!")
435+
return parse_final_model_params(self.model_final.coef_, self.model_final.intercept_,
436+
self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef,
437+
self.fit_cate_intercept)[1]
415438

416439
@BaseCateEstimator._defer_to_inference
417440
def coef__interval(self, *, alpha=0.1):
@@ -479,6 +502,12 @@ def _get_inference_options(self):
479502

480503
class LinearModelFinalCateEstimatorDiscreteMixin(BaseCateEstimator):
481504
# TODO Share some logic with non-discrete version
505+
"""
506+
Base class for models where the final stage is a linear model.
507+
508+
Subclasses must expose a ``fitted_models_final`` attribute
509+
returning an array of the fitted models for each non-control treatment
510+
"""
482511

483512
def coef_(self, T):
484513
""" The coefficients in the linear model of the constant marginal treatment
@@ -498,7 +527,8 @@ def coef_(self, T):
498527
"""
499528
_, T = self._expand_treatments(None, T)
500529
ind = (T @ np.arange(T.shape[1])).astype(int)[0]
501-
return self.fitted_models_final[ind].coef_
530+
all_coefs = self.fitted_models_final[ind].coef_
531+
return all_coefs
502532

503533
def intercept_(self, T):
504534
""" The intercept in the linear model of the constant marginal treatment

0 commit comments

Comments
 (0)