Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

automate the first stage model T and update DML notebook #172

Merged
merged 11 commits into from
Nov 21, 2019
17 changes: 12 additions & 5 deletions econml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
from warnings import warn
from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose, inverse_onehot,
broadcast_unit_treatments, reshape_treatmentwise_effects,
StatsModelsLinearRegression, LassoCVWrapper, check_high_dimensional)
StatsModelsLinearRegression, LassoCVWrapper, check_high_dimensional, WeightedLassoCVWrapper)
from econml.sklearn_extensions.linear_model import MultiOutputDebiasedLasso
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.linear_model import LinearRegression, LassoCV, ElasticNetCV
from sklearn.linear_model import LinearRegression, LassoCV, LogisticRegressionCV, ElasticNetCV
from sklearn.preprocessing import (PolynomialFeatures, LabelEncoder, OneHotEncoder,
FunctionTransformer)
from sklearn.base import clone, TransformerMixin
Expand All @@ -52,6 +52,7 @@
DebiasedLassoCateEstimatorMixin)
from .inference import StatsModelsInference
from ._rlearner import _RLearner
from .sklearn_extensions.model_selection import WeightedStratifiedKFold


class DMLCateEstimator(_RLearner):
Expand Down Expand Up @@ -170,6 +171,12 @@ def __init__(self,
# TODO: consider whether we need more care around stateful featurizers,
# since we clone it and fit separate copies

if model_t == 'auto':
if discrete_treatment:
model_t = LogisticRegressionCV(cv=WeightedStratifiedKFold())
else:
model_t = WeightedLassoCVWrapper()

class FirstStageWrapper:
def __init__(self, model, is_Y):
self._model = clone(model, safe=False)
Expand Down Expand Up @@ -329,7 +336,7 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
"""

def __init__(self,
model_y=LassoCV(), model_t=LassoCV(),
model_y=LassoCV(), model_t='auto',
featurizer=PolynomialFeatures(degree=1, include_bias=True),
linear_first_stages=True,
discrete_treatment=False,
Expand Down Expand Up @@ -446,7 +453,7 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstim
"""

def __init__(self,
model_y=LassoCV(), model_t=LassoCV(),
model_y=LassoCV(), model_t='auto',
alpha='auto',
max_iter=1000,
tol=1e-4,
Expand Down Expand Up @@ -551,7 +558,7 @@ class KernelDMLCateEstimator(DMLCateEstimator):
by :mod:`np.random<numpy.random>`.
"""

def __init__(self, model_y=LassoCV(), model_t=LassoCV(),
def __init__(self, model_y=LassoCV(), model_t='auto',
dim=20, bw=1.0, discrete_treatment=False, n_splits=2, random_state=None):
class RandomFeatures(TransformerMixin):
def __init__(self, random_state):
Expand Down
2 changes: 1 addition & 1 deletion econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def make_random(is_discrete, d):
all_infs.append(BootstrapInference(1))

for est, multi, infs in [(LinearDMLCateEstimator(model_y=Lasso(),
model_t=model_t,
model_t='auto',
discrete_treatment=is_discrete),
False,
all_infs),
Expand Down
40 changes: 40 additions & 0 deletions econml/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sklearn.model_selection._split import _CVIterableWrapper, CV_WARNING
from sklearn.utils.multiclass import type_of_target
import numbers
from .sklearn_extensions.linear_model import WeightedLassoCV, WeightedMultiTaskLassoCV

MAX_RAND_SEED = np.iinfo(np.int32).max

Expand Down Expand Up @@ -1871,3 +1872,42 @@ def fit(self, X, Y):
def predict(self, X):
predictions = self.model.predict(X)
return reshape(predictions, (-1, 1)) if self.needs_unravel else predictions


class WeightedLassoCVWrapper:
    """Delegate to :class:`WeightedLassoCV` or :class:`WeightedMultiTaskLassoCV` based on the target's shape.

    A target with more than one column is fit with the multi-task model.
    A 1-D target, or a 2-D target with a single column, is fit with the
    plain weighted lasso; in the single-column 2-D case the target is
    flattened before fitting and predictions are reshaped back to a
    column on the way out.
    """

    def __init__(self, *args, **kwargs):
        # Stored verbatim and forwarded to whichever underlying model fit() selects.
        self.args = args
        self.kwargs = kwargs

    def fit(self, X, y):
        """Fit the appropriate weighted lasso model and mirror its fitted attributes."""
        assert shape(X)[0] == shape(y)[0]
        assert ndim(y) <= 2
        self.needs_unravel = False
        is_multi_column = ndim(y) == 2 and shape(y)[1] > 1
        if is_multi_column:
            self.model = WeightedMultiTaskLassoCV(*self.args, **self.kwargs)
        else:
            if ndim(y) == 2 and shape(y)[1] == 1:
                # WeightedLassoCV expects a 1-D target; remember to restore
                # the column shape when predicting.
                y = np.ravel(y)
                self.needs_unravel = True
            self.model = WeightedLassoCV(*self.args, **self.kwargs)
        self.model.fit(X, y)
        # Expose the fitted model's learned attributes on the wrapper so it
        # quacks like a fitted lasso estimator.
        for attr in ('intercept_', 'coef_', 'alpha_', 'alphas_', 'n_iter_'):
            setattr(self, attr, getattr(self.model, attr))
        return self

    def predict(self, X):
        """Predict with the fitted model, restoring a column shape if the target was flattened."""
        preds = self.model.predict(X)
        if self.needs_unravel:
            return reshape(preds, (-1, 1))
        return preds

    def score(self, X, y, sample_weight=None):
        """Return the underlying model's R^2 score."""
        return self.model.score(X, y, sample_weight)
46 changes: 23 additions & 23 deletions notebooks/Double Machine Learning Examples.ipynb

Large diffs are not rendered by default.