10
10
from warnings import warn
11
11
from .bootstrap import BootstrapEstimator
12
12
from .inference import BootstrapInference
13
- from .utilities import tensordot , ndim , reshape , shape
13
+ from .utilities import tensordot , ndim , reshape , shape , parse_final_model_params
14
14
from .inference import StatsModelsInference , StatsModelsInferenceDiscrete , LinearModelFinalInference ,\
15
15
LinearModelFinalInferenceDiscrete
16
16
@@ -383,7 +383,20 @@ def effect(self, X=None, *, T0=0, T1=1):
383
383
384
384
385
385
class LinearModelFinalCateEstimatorMixin (BaseCateEstimator ):
386
- """Base class for models where the final stage is a linear model."""
386
+ """
387
+ Base class for models where the final stage is a linear model.
388
+
389
+ Subclasses must expose a ``model_final`` attribute containing the model's
390
+ final stage model.
391
+
392
+ Attributes
393
+ ----------
394
+ bias_part_of_coef: bool
395
+ Whether the CATE model's intercept is contained in the final model's ``coef_`` rather
396
+ than as a separate ``intercept_``
397
+ """
398
+
399
+ bias_part_of_coef = False
387
400
388
401
@property
389
402
def coef_ (self ):
@@ -392,15 +405,17 @@ def coef_(self):
392
405
393
406
Returns
394
407
-------
395
- coef: (n_x * n_t,) or (n_y, n_x * n_t) array like
408
+ coef: (n_x,) or ( n_t, n_x ) or (n_y, n_t, n_x ) array like
396
409
Where n_x is the number of features that enter the final model (either the
397
410
dimension of X or the dimension of featurizer.fit_transform(X) if the CATE
398
411
estimator has a featurizer.), n_t is the number of treatments, n_y is
399
- the number of outcomes. The coefficient is flattened in a manner that
400
- the first block of n_x columns are the coefficients associated with treatment 0,
401
- the next n_x columns are the coefficients associated with treatment 1 etc .
412
+ the number of outcomes. Dimensions are omitted if the original input was
413
+ a vector and not a 2D array. For binary treatment the n_t dimension is
414
+ also omitted .
402
415
"""
403
- return self .model_final .coef_
416
+ return parse_final_model_params (self .model_final .coef_ , self .model_final .intercept_ ,
417
+ self ._d_y , self ._d_t , self ._d_t_in , self .bias_part_of_coef ,
418
+ self .fit_cate_intercept )[0 ]
404
419
405
420
@property
406
421
def intercept_ (self ):
@@ -409,9 +424,17 @@ def intercept_(self):
409
424
410
425
Returns
411
426
-------
412
- intercept: float or (n_y,) array like
427
+ intercept: float or (n_y,) or (n_y, n_t) array like
428
+ Where n_t is the number of treatments, n_y is
429
+ the number of outcomes. Dimensions are omitted if the original input was
430
+ a vector and not a 2D array. For binary treatment the n_t dimension is
431
+ also omitted.
413
432
"""
414
- return self .model_final .intercept_
433
+ if not self .fit_cate_intercept :
434
+ raise AttributeError ("No intercept was fitted!" )
435
+ return parse_final_model_params (self .model_final .coef_ , self .model_final .intercept_ ,
436
+ self ._d_y , self ._d_t , self ._d_t_in , self .bias_part_of_coef ,
437
+ self .fit_cate_intercept )[1 ]
415
438
416
439
@BaseCateEstimator ._defer_to_inference
417
440
def coef__interval (self , * , alpha = 0.1 ):
@@ -479,6 +502,12 @@ def _get_inference_options(self):
479
502
480
503
class LinearModelFinalCateEstimatorDiscreteMixin (BaseCateEstimator ):
481
504
# TODO Share some logic with non-discrete version
505
+ """
506
+ Base class for models where the final stage is a linear model.
507
+
508
+ Subclasses must expose a ``fitted_models_final`` attribute
509
+ returning an array of the fitted models for each non-control treatment
510
+ """
482
511
483
512
def coef_ (self , T ):
484
513
""" The coefficients in the linear model of the constant marginal treatment
@@ -498,7 +527,8 @@ def coef_(self, T):
498
527
"""
499
528
_ , T = self ._expand_treatments (None , T )
500
529
ind = (T @ np .arange (T .shape [1 ])).astype (int )[0 ]
501
- return self .fitted_models_final [ind ].coef_
530
+ all_coefs = self .fitted_models_final [ind ].coef_
531
+ return all_coefs
502
532
503
533
def intercept_ (self , T ):
504
534
""" The intercept in the linear model of the constant marginal treatment
0 commit comments