Vasilis/drlearner #137

Merged · 51 commits · Nov 11, 2019
Commits
f31d86a
separating drlearner
vasilismsr Nov 6, 2019
7c4d76d
better docstring
vasilismsr Nov 6, 2019
6e2c972
notebook
vasilismsr Nov 6, 2019
b9d5238
drlearner documentation finalized.
vasilismsr Nov 6, 2019
ad72eab
removed doubly robust learner from metalearners
vasilismsr Nov 6, 2019
868b5c1
removed drlearner test from metalearners and started new test file fo…
vasilismsr Nov 6, 2019
a293f0f
removed drlearner test from metalearners and started new test file fo…
vasilismsr Nov 6, 2019
c63d6b3
changed tests in dml to adhere to the CATE API. Some calls to effect …
vasilismsr Nov 6, 2019
d2494b6
added exhaustive tests for drlearner. Fixed bugs from corner cases
vasilismsr Nov 7, 2019
e75fd48
bug in docstring regarding how split is called to generate crossfit f…
vasilismsr Nov 7, 2019
fbc3f1d
changed bootstrap tests to conform with keyword only effect and effec…
vasilismsr Nov 7, 2019
3397d95
bootstrap tests, fixing bugs related to positional arguments T0, T1
vasilismsr Nov 7, 2019
b13581e
linting
vasilismsr Nov 7, 2019
e4c355d
more tests for drlearner and small bugs
vasilismsr Nov 7, 2019
8611577
reverted back to coef_ and intercept_. Fixed docstring. Added exhaust…
vasilismsr Nov 7, 2019
0dcea06
added checks of fitted dims for W and Z during scoring in _OrthoLearn…
vasilismsr Nov 7, 2019
7c55fb0
testing notebook
vasilismsr Nov 7, 2019
b5c0a07
linting
vasilismsr Nov 7, 2019
59d9fb1
docstring fix regarding n_splits in multiple places. Fixed notebook f…
vasilismsr Nov 7, 2019
d3274ff
changed statsmodels inference input properties to reflect that we onl…
vasilismsr Nov 7, 2019
988c53d
small change in test drlearner
vasilismsr Nov 7, 2019
e48f7de
docstring for linear drlearner
vasilismsr Nov 7, 2019
95fa08f
improved example in docstring of lineardrlearner
vasilismsr Nov 7, 2019
ef0dff5
adding more slacks to some drlearner coverage tests
vasilismsr Nov 7, 2019
07a79e2
Merge branch 'master' into vasilis/drlearner
vasilismsr Nov 7, 2019
d2b7835
linting
vasilismsr Nov 7, 2019
0474015
linting
vasilismsr Nov 7, 2019
f6d1505
removed OrthoLearner testing notebook
vasilismsr Nov 7, 2019
cb60dc4
comments on fit score dimension mismatch
vasilismsr Nov 8, 2019
9768a76
removed leftover print statement from cate estimator
vasilismsr Nov 8, 2019
c315539
replaced :code: with
vasilismsr Nov 8, 2019
51e0595
added :meth:
vasilismsr Nov 8, 2019
f377e95
added :meth:
vasilismsr Nov 8, 2019
aa21807
removed + for string concat
vasilismsr Nov 8, 2019
a3b7b52
removed redundant test code
vasilismsr Nov 8, 2019
b00052c
added comment on overlapping tests between DML and DRLearner
vasilismsr Nov 8, 2019
cd597f4
removed redundant rand_sol code in test_drlearner
vasilismsr Nov 8, 2019
42fae99
removed + operator for string concat
vasilismsr Nov 8, 2019
8de555b
added TODO for allowing for 2d y of shape (n,1) and also added test t…
vasilismsr Nov 8, 2019
084561d
removed redundant adding and subtracting in statsmodelscateestimator
vasilismsr Nov 8, 2019
d606f2a
changed :attr: to :meth:
vasilismsr Nov 8, 2019
7803c2e
added TODO so that we merge functionality between statsmodelsinferenc…
vasilismsr Nov 8, 2019
610db58
replaced printing with subTests
vasilismsr Nov 8, 2019
68d5180
linting
vasilismsr Nov 8, 2019
fdd536a
fixed docstring in dml. Added utility function of inverse_onehot enco…
vasilismsr Nov 9, 2019
bfb1ff8
removed replacing None weights with np.ones in drlearner scoring sinc…
vasilismsr Nov 9, 2019
179c2cf
typo in error msg
vasilismsr Nov 9, 2019
9c4271e
made statsmodelslinearregression be child of BaseEstimator
vasilismsr Nov 9, 2019
8ea37c8
added comment on code design choice in model_final of drlearner, rela…
vasilismsr Nov 9, 2019
36ddc55
put docstrings in methods and removed them from attributes
vasilismsr Nov 9, 2019
c1549c8
Merge branch 'master' into vasilis/drlearner
vasilismsr Nov 10, 2019
1 change: 1 addition & 0 deletions doc/reference.rst
@@ -9,6 +9,7 @@ Public Module Reference
econml.deepiv
econml.dgp
econml.dml
econml.drlearner
econml.inference
econml.ortho_forest
econml.selective_regularization
2 changes: 1 addition & 1 deletion econml/_ortho_learner.py
@@ -524,7 +524,7 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
return super().const_marginal_effect_interval(X, alpha=alpha)
const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__

-    def effect_interval(self, X=None, T0=0, T1=1, *, alpha=0.1):
+    def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
self._check_fitted_dims(X)
return super().effect_interval(X, T0=T0, T1=T1, alpha=alpha)
effect_interval.__doc__ = LinearCateEstimator.effect_interval.__doc__
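The change above makes `T0` and `T1` keyword-only in `effect_interval` (the bare `*` moves before them). A minimal standalone sketch of what that enforces at call sites — a hypothetical function mimicking the new signature, not econml's actual class:

```python
def effect_interval(X=None, *, T0=0, T1=1, alpha=0.1):
    # parameters after the bare * must be passed by keyword
    return (T0, T1, alpha)

print(effect_interval(None, T0=0, T1=1))  # keyword call: accepted
try:
    effect_interval(None, 0, 1)           # old positional call style
except TypeError:
    print("positional T0/T1 now raise TypeError")
```

This is why several commits in the PR update the bootstrap and DML tests: any caller passing `T0`/`T1` positionally breaks under the new signature.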
154 changes: 152 additions & 2 deletions econml/cate_estimator.py
@@ -11,7 +11,7 @@
from .bootstrap import BootstrapEstimator
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape
-from .inference import StatsModelsInference
+from .inference import StatsModelsInference, StatsModelsInferenceDiscrete


class BaseCateEstimator(metaclass=abc.ABCMeta):
@@ -173,10 +173,50 @@ def call(self, *args, **kwargs):

@_defer_to_inference
def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\tau(X, T0, T1)` produced
by the model. Available only when :code:`inference` is not :code:`None`, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix
Features for each sample
T0: optional (m, d_t) matrix or vector of length m (Default=0)
Base treatments for each sample
T1: optional (m, d_t) matrix or vector of length m (Default=1)
Target treatments for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of effect(X, T0, T1), type of effect(X, T0, T1))
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass

@_defer_to_inference
def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\partial \\tau(T, X)` produced
by the model. Available only when :code:`inference` is not :code:`None`, when
calling the fit method.

Parameters
----------
T: (m, d_t) matrix
Base treatments for each sample
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of marginal_effect(T, X), type of marginal_effect(T, X))
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass
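The docstrings above describe the reported interval as the alpha/2, 1-alpha/2 confidence interval. A hedged illustration of that convention for a single normal-approximation estimate (toy point estimate and standard error, not econml's code path):

```python
from statistics import NormalDist

alpha = 0.1                              # 90% overall confidence
point, stderr = 2.0, 0.5                 # hypothetical estimate and standard error
z = NormalDist().inv_cdf(1 - alpha / 2)  # critical value, ~1.645 for alpha=0.1
lower, upper = point - z * stderr, point + z * stderr
# the interval spans the alpha/2 to 1-alpha/2 quantiles of the estimate
```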


@@ -279,9 +319,27 @@ def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
return tuple(np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
for eff in effs)
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
by the model. Available only when :code:`inference` is not :code:`None`, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of const_marginal_effect(X), type of const_marginal_effect(X))
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass
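For intuition on how :math:`\theta(X)` connects the two interval methods: for a linear CATE estimator, the effect is the constant marginal effect contracted with the treatment difference. A toy numpy sketch with hypothetical numbers (illustrating the relationship, not the library's internals):

```python
import numpy as np

theta = np.array([1.5, -0.5])  # hypothetical const marginal effects, d_t = 2
T0 = np.array([0.0, 0.0])      # base treatment
T1 = np.array([1.0, 1.0])      # target treatment
effect = theta @ (T1 - T0)     # tau(X, T0, T1) under linearity, here 1.0
```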


@@ -314,7 +372,7 @@ def _expand_treatments(self, X=None, *Ts):
return (X,) + tuple(outTs)

# override effect to set defaults, which works with the new definition of _expand_treatments
-    def effect(self, X=None, T0=0, T1=1):
+    def effect(self, X=None, *, T0=0, T1=1):
# NOTE: don't explicitly expand treatments here, because it's done in the super call
return super().effect(X, T0=T0, T1=T1)
effect.__doc__ = BaseCateEstimator.effect.__doc__
@@ -348,3 +406,95 @@ def coef__interval(self, *, alpha=0.1):
@BaseCateEstimator._defer_to_inference
def intercept__interval(self, *, alpha=0.1):
pass


class StatsModelsCateEstimatorDiscreteMixin(BaseCateEstimator):

def _get_inference_options(self):
# add statsmodels to parent's options
options = super()._get_inference_options()
options.update(statsmodels=StatsModelsInferenceDiscrete)
return options

@property
@abc.abstractmethod
def statsmodels(self):
pass

def coef_(self, T):
""" The coefficients in the linear model of the constant marginal treatment
effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.

Returns
-------
coef_: (n_x,) or (n_y, n_x) array like
Where n_x is the number of features that enter the final model (either the
dimension of X or the dimension of featurizer.fit_transform(X) if the CATE
estimator has a featurizer.)
"""
_, T = self._expand_treatments(None, T)
ind = (T @ np.arange(1, T.shape[1] + 1)).astype(int)[0] - 1
return self.statsmodels[ind].coef_
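The `ind` expression above inverts a one-hot treatment encoding: dotting the row with `1..d_t` recovers the 1-based category, and subtracting 1 gives the index of the fitted per-treatment model. A standalone check of that arithmetic:

```python
import numpy as np

# one-hot row selecting the second of three non-baseline treatments
T = np.array([[0, 1, 0]])
ind = (T @ np.arange(1, T.shape[1] + 1)).astype(int)[0] - 1
# ind is 1, i.e. the second fitted final model
```

One commit in this PR (fdd536a) factors a similar inverse-one-hot computation into a utility function.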

def intercept_(self, T):
""" The intercept in the linear model of the constant marginal treatment
effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the intercept.

Returns
-------
intercept_: float or (n_y,) array like
"""
_, T = self._expand_treatments(None, T)
ind = (T @ np.arange(1, T.shape[1] + 1)).astype(int)[0] - 1
return self.statsmodels[ind].intercept_

@BaseCateEstimator._defer_to_inference
def coef__interval(self, T, *, alpha=0.1):
""" The confidence interval for the coefficients in the linear model of the
constant marginal treatment effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper: tuple(type of coef_, type of coef_)
The lower and upper bounds of the confidence interval for each quantity.
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__interval(self, T, *, alpha=0.1):
""" The confidence interval for the intercept in the linear model of the
constant marginal treatment effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the intercept.
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper: tuple(type of intercept_, type of intercept_)
The lower and upper bounds of the confidence interval.
"""
pass