Grouping and other changes #273

Merged · 3 commits · Nov 11, 2020
79 changes: 44 additions & 35 deletions econml/_ortho_learner.py
@@ -28,7 +28,7 @@ class in this module implements the general logic in a very versatile way
 import copy
 from warnings import warn
 from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose, inverse_onehot,
-                        broadcast_unit_treatments, reshape_treatmentwise_effects,
+                        broadcast_unit_treatments, reshape_treatmentwise_effects, filter_none_kwargs,
                         StatsModelsLinearRegression, _EncoderWrapper)
 from sklearn.model_selection import KFold, StratifiedKFold, check_cv
 from sklearn.linear_model import LinearRegression, LassoCV
@@ -125,9 +125,11 @@ def predict(self, X, y, W=None):
     fitted_inds = []
     calculate_scores = hasattr(model, 'score')
 
+    # remove None arguments
+    kwargs = filter_none_kwargs(**kwargs)
+
     if folds is None:  # skip crossfitting
         model_list.append(clone(model, safe=False))
-        kwargs = {k: v for k, v in kwargs.items() if v is not None}
         model_list[0].fit(*args, **kwargs)
         nuisances = model_list[0].predict(*args, **kwargs)
         scores = model_list[0].score(*args, **kwargs) if calculate_scores else None
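
For reference, filter_none_kwargs is a small helper in econml.utilities; judging from the private _filter_none_kwargs method that this PR removes further down, its behavior is presumably equivalent to this sketch:

def filter_none_kwargs(**kwargs):
    # Drop keyword arguments whose value is None; keep everything else.
    # (Sketch inferred from the removed _filter_none_kwargs method below;
    # the actual utility lives in econml.utilities.)
    return {key: value for key, value in kwargs.items() if value is not None}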
@@ -155,8 +157,8 @@ def predict(self, X, y, W=None):
         args_train = tuple(var[train_idxs] if var is not None else None for var in args)
         args_test = tuple(var[test_idxs] if var is not None else None for var in args)
 
-        kwargs_train = {key: var[train_idxs] for key, var in kwargs.items() if var is not None}
-        kwargs_test = {key: var[test_idxs] for key, var in kwargs.items() if var is not None}
+        kwargs_train = {key: var[train_idxs] for key, var in kwargs.items()}
+        kwargs_test = {key: var[test_idxs] for key, var in kwargs.items()}
 
         model_list[idx].fit(*args_train, **kwargs_train)
@@ -459,9 +461,9 @@ def __init__(self, model_nuisance, model_final, *,
     def _asarray(A):
         return None if A is None else np.asarray(A)
 
-    def _check_input_dims(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None):
+    def _check_input_dims(self, Y, T, X=None, W=None, Z=None, *other_arrays):
         assert shape(Y)[0] == shape(T)[0], "Dimension mis-match!"
-        for arr in [X, W, Z, sample_weight, sample_var]:
+        for arr in [X, W, Z, *other_arrays]:
             assert (arr is None) or (arr.shape[0] == Y.shape[0]), "Dimension mismatch"
         self._d_x = X.shape[1:] if X is not None else None
         self._d_w = W.shape[1:] if W is not None else None
@@ -487,14 +489,7 @@ def _check_fitted_dims_w_z(self, W, Z):
     def _subinds_check_none(self, var, inds):
         return var[inds] if var is not None else None
 
-    def _filter_none_kwargs(self, **kwargs):
-        non_none_kwargs = {}
-        for key, value in kwargs.items():
-            if value is not None:
-                non_none_kwargs[key] = value
-        return non_none_kwargs
-
-    def _strata(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None):
+    def _strata(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None, groups=None):
         if self._discrete_instrument:
             Z = LabelEncoder().fit_transform(np.ravel(Z))
@@ -511,7 +506,7 @@ def _strata(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=N
             return None
 
     @BaseCateEstimator._wrap_fit
-    def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None, *, inference=None):
+    def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None, groups=None, *, inference=None):
         """
         Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.

Expand All @@ -531,6 +526,10 @@ def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None,
Weights for each samples
sample_var: optional (n,) vector or None (Default=None)
Sample variance for each sample
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the n_splits argument passed to this class's initializer
must support a 'groups' argument to its split method.
inference: string, :class:`.Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`).
@@ -539,10 +538,11 @@ def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None,
         -------
         self : _OrthoLearner instance
         """
-        Y, T, X, W, Z, sample_weight, sample_var = [self._asarray(A)
-                                                    for A in (Y, T, X, W, Z, sample_weight, sample_var)]
-        self._check_input_dims(Y, T, X, W, Z, sample_weight, sample_var)
-        nuisances, fitted_inds = self._fit_nuisances(Y, T, X, W, Z, sample_weight=sample_weight)
+        Y, T, X, W, Z, sample_weight, sample_var, groups = [self._asarray(A)
+                                                            for A in (Y, T, X, W, Z,
+                                                                      sample_weight, sample_var, groups)]
+        self._check_input_dims(Y, T, X, W, Z, sample_weight, sample_var, groups)
+        nuisances, fitted_inds = self._fit_nuisances(Y, T, X, W, Z, sample_weight=sample_weight, groups=groups)
         self._fit_final(self._subinds_check_none(Y, fitted_inds),
                         self._subinds_check_none(T, fitted_inds),
                         X=self._subinds_check_none(X, fitted_inds),
@@ -553,10 +553,10 @@ def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None,
                         sample_var=self._subinds_check_none(sample_var, fitted_inds))
         return self
 
-    def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
+    def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
         # use a binary array to get stratified split in case of discrete treatment
         stratify = self._discrete_treatment or self._discrete_instrument
-        strata = self._strata(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight)
+        strata = self._strata(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight, groups=groups)
         if strata is None:
             strata = T  # always safe to pass T as second arg to split even if we're not actually stratifying

Expand All @@ -579,16 +579,24 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
else:
splitter = check_cv(self._n_splits, [0], classifier=stratify)
# if check_cv produced a new KFold or StratifiedKFold object, we need to set shuffle and random_state
# TODO: ideally, we'd also infer whether we need a GroupKFold (if groups are passed)

Review comment (Collaborator):

Grouping is more important than stratification for valid inference, so I would prioritize grouping over stratification here: if groups are enabled, use GroupKFold; if not, use StratifiedKFold if strata is not None, else KFold.

We should also most probably raise a warning that "cross fitting was performed without treatment stratification because grouping was enabled."

Ultimately I feel we should just add our own stratified group k-fold that stratifies within each group, so that we really deliver the full version of our API.

Reply (Collaborator, PR author):

With small sample sizes, failure to stratify can cause first-stage model prediction to fail if no examples from one stratum make it into a training fold. I agree, though, that we ought to have a mechanism that supports both simultaneously; there is work in progress to add such a feature to sklearn natively.
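
As an illustration of the reviewer's proposal (not code from this PR, nor part of sklearn at the time), one minimal sketch of a group-respecting stratified splitter could stratify at the group level, assuming nonnegative integer-coded strata:

import numpy as np
from sklearn.model_selection import StratifiedKFold

class StratifiedGroupKFold:
    # Hypothetical sketch: keep each group's rows in a single fold while
    # balancing folds by each group's most common stratum.
    def __init__(self, n_splits=2, random_state=None):
        self.n_splits = n_splits
        self.random_state = random_state

    def split(self, X, y, groups):
        y = np.asarray(y).ravel().astype(int)  # assumes nonnegative integer strata
        groups = np.asarray(groups)
        unique_groups = np.unique(groups)
        # label each group by its most common stratum
        group_labels = np.array([np.bincount(y[groups == g]).argmax() for g in unique_groups])
        skf = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
        for _, test_g in skf.split(unique_groups.reshape(-1, 1), group_labels):
            test_mask = np.isin(groups, unique_groups[test_g])
            yield np.where(~test_mask)[0], np.where(test_mask)[0]

An instance of such a class could then be supplied as the n_splits initializer argument, since its split method accepts groups.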

+            # however, sklearn doesn't support both stratifying and grouping (see
+            # https://github.com/scikit-learn/scikit-learn/issues/13621), so for now the user needs to supply
+            # their own object that supports grouping if they want to use groups.
             if splitter != self._n_splits and isinstance(splitter, (KFold, StratifiedKFold)):
                 splitter.shuffle = True
                 splitter.random_state = self._random_state
 
         all_vars = [var if np.ndim(var) == 2 else var.reshape(-1, 1) for var in [Z, W, X] if var is not None]
-        if all_vars:
-            all_vars = np.hstack(all_vars)
-            folds = splitter.split(all_vars, strata)
+        to_split = np.hstack(all_vars) if all_vars else np.ones((T.shape[0], 1))
+
+        if groups is not None:
+            if isinstance(splitter, (KFold, StratifiedKFold)):
+                raise TypeError("Groups were passed to fit while using a KFold or StratifiedKFold splitter. "
+                                "Instead you must initialize this object with a splitter that can handle groups.")
+            folds = splitter.split(to_split, strata, groups=groups)
         else:
-            folds = splitter.split(np.ones((T.shape[0], 1)), strata)
+            folds = splitter.split(to_split, strata)
 
         if self._discrete_treatment:
             self._d_t = shape(T)[1:]
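
To make the splitter contract concrete, here is a small standalone illustration (plain sklearn usage, not code from this PR):

import numpy as np
from sklearn.model_selection import GroupKFold

X = np.arange(12).reshape(6, 2)
strata = np.array([0, 0, 1, 1, 0, 1])   # e.g., a discrete treatment
groups = np.array([0, 0, 1, 1, 2, 2])   # rows 0-1, 2-3, and 4-5 must stay together

# GroupKFold.split accepts a 'groups' argument, so an instance can be supplied
# as the estimator's n_splits argument; KFold.split has no such parameter,
# which is exactly why the code above raises a TypeError for (Stratified)KFold.
for train, test in GroupKFold(n_splits=3).split(X, strata, groups=groups):
    print(train, test)   # each group's rows land entirely in train or entirely in test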
@@ -597,21 +605,22 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
                                                   validate=False)
 
         nuisances, fitted_models, fitted_inds, scores = _crossfit(self._model_nuisance, folds,
-                                                                   Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight)
+                                                                   Y, T, X=X, W=W, Z=Z,
+                                                                   sample_weight=sample_weight, groups=groups)
         self._models_nuisance = fitted_models
         self.nuisance_scores_ = scores
         return nuisances, fitted_inds
 
     def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
-        self._model_final.fit(Y, T, **self._filter_none_kwargs(X=X, W=W, Z=Z,
-                                                               nuisances=nuisances, sample_weight=sample_weight,
-                                                               sample_var=sample_var))
+        self._model_final.fit(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
+                                                         nuisances=nuisances, sample_weight=sample_weight,
+                                                         sample_var=sample_var))
         self.score_ = None
         if hasattr(self._model_final, 'score'):
-            self.score_ = self._model_final.score(Y, T, **self._filter_none_kwargs(X=X, W=W, Z=Z,
-                                                                                   nuisances=nuisances,
-                                                                                   sample_weight=sample_weight,
-                                                                                   sample_var=sample_var))
+            self.score_ = self._model_final.score(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
+                                                                             nuisances=nuisances,
+                                                                             sample_weight=sample_weight,
+                                                                             sample_var=sample_var))
 
     def const_marginal_effect(self, X=None):
         self._check_fitted_dims(X)
@@ -679,7 +688,7 @@ def score(self, Y, T, X=None, W=None, Z=None):
             Z = self.z_transformer.transform(Z)
         n_splits = len(self._models_nuisance)
         for idx, mdl in enumerate(self._models_nuisance):
-            nuisance_temp = mdl.predict(Y, T, **self._filter_none_kwargs(X=X, W=W, Z=Z))
+            nuisance_temp = mdl.predict(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z))
             if not isinstance(nuisance_temp, tuple):
                 nuisance_temp = (nuisance_temp,)
@@ -692,7 +701,7 @@ def score(self, Y, T, X=None, W=None, Z=None):
         for it in range(len(nuisances)):
             nuisances[it] = np.mean(nuisances[it], axis=0)
 
-        return self._model_final.score(Y, T, **self._filter_none_kwargs(X=X, W=W, Z=Z, nuisances=nuisances))
+        return self._model_final.score(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z, nuisances=nuisances))
 
     @property
     def model_final(self):
28 changes: 18 additions & 10 deletions econml/_rlearner.py
@@ -28,7 +28,7 @@
 import numpy as np
 import copy
 from warnings import warn
-from .utilities import (shape, reshape, ndim, hstack)
+from .utilities import (shape, reshape, ndim, hstack, filter_none_kwargs)
 from sklearn.linear_model import LinearRegression
 from sklearn.base import clone
 from ._ortho_learner import _OrthoLearner
@@ -45,24 +45,26 @@ def __init__(self, model_y, model_t):
         self._model_y = clone(model_y, safe=False)
         self._model_t = clone(model_t, safe=False)
 
-    def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
+    def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
         assert Z is None, "Cannot accept instrument!"
-        self._model_t.fit(X, W, T, sample_weight=sample_weight)
-        self._model_y.fit(X, W, Y, sample_weight=sample_weight)
+        self._model_t.fit(X, W, T, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
+        self._model_y.fit(X, W, Y, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
         return self
 
-    def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
+    def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
         if hasattr(self._model_y, 'score'):
-            Y_score = self._model_y.score(X, W, Y, sample_weight=sample_weight)
+            # note that groups are not passed to score because they are only used for fitting
+            Y_score = self._model_y.score(X, W, Y, **filter_none_kwargs(sample_weight=sample_weight))
         else:
             Y_score = None
         if hasattr(self._model_t, 'score'):
-            T_score = self._model_t.score(X, W, T, sample_weight=sample_weight)
+            # note that groups are not passed to score because they are only used for fitting
+            T_score = self._model_t.score(X, W, T, **filter_none_kwargs(sample_weight=sample_weight))
         else:
             T_score = None
         return Y_score, T_score
 
-    def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
+    def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
         Y_pred = self._model_y.predict(X, W)
         T_pred = self._model_t.predict(X, W)
         if (X is None) and (W is None):  # In this case predict above returns a single row
@@ -282,7 +284,7 @@ def __init__(self, model_y, model_t, model_final,
                          n_splits=n_splits,
                          random_state=random_state)
 
-    def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, inference=None):
+    def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None, inference=None):
         """
         Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.

Expand All @@ -300,6 +302,10 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, infe
Weights for each samples
sample_var: optional(n,) vector or None (Default=None)
Sample variance for each sample
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the n_splits argument passed to this class's initializer
must support a 'groups' argument to its split method.
inference: string,:class:`.Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of:class:`.BootstrapInference`).
@@ -309,7 +315,9 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, infe
         self: _RLearner instance
         """
         # Replacing fit from _OrthoLearner, to enforce Z=None and improve the docstring
-        return super().fit(Y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var, inference=inference)
+        return super().fit(Y, T, X=X, W=W,
+                           sample_weight=sample_weight, sample_var=sample_var, groups=groups,
+                           inference=inference)
 
     def score(self, Y, T, X=None, W=None):
         """
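
Putting the new API together, a hedged end-to-end sketch of a grouped fit follows. The concrete LinearDML class, its n_splits keyword, and its default first-stage models are assumptions about the surrounding library at the time of this PR, so treat this as illustrative rather than verified:

import numpy as np
from sklearn.model_selection import GroupKFold
from econml.dml import LinearDML  # assumed concrete subclass of the learners changed here

rng = np.random.RandomState(123)
n, n_groups = 200, 50
groups = np.repeat(np.arange(n_groups), n // n_groups)  # e.g., repeated observations per unit
X = rng.normal(size=(n, 2))
T = rng.binomial(1, 0.5, size=n)
Y = 2.0 * T + X[:, 0] + rng.normal(size=n)

# Per the docstrings above: when groups are passed to fit, the splitter supplied
# via n_splits must accept a 'groups' argument to its split method.
est = LinearDML(discrete_treatment=True, n_splits=GroupKFold(n_splits=2))
est.fit(Y, T, X=X, groups=groups)
print(est.effect(X[:5]))  # CATE estimates for the first five rows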