Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vasilis/orf speed #316

Merged
merged 31 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b2391b1
orf speedup by moving pointwise effect outside of class
vasilismsr Nov 13, 2020
7595609
re-run ORF notebook
vasilismsr Nov 13, 2020
ce10c48
re-run forest learners notebook
vasilismsr Nov 13, 2020
59dce4f
linting
vasilismsr Nov 13, 2020
b7e01e2
removed unused variable and re-run noteoboks
vasilismsr Nov 13, 2020
8c3c561
incorporated weight computation as part of spawining of the parallel …
vasilismsr Nov 14, 2020
e974429
changed orf test inference to n_jobs=1 so that coverage levels can wo…
vasilismsr Nov 14, 2020
e18ea7c
removed the one hot encoding from the nuisance and parameter estimato…
vasilismsr Nov 14, 2020
cca85f9
fixed bug in code that removes first stage cross-fitting. Replaced po…
vasilismsr Nov 14, 2020
6ea2664
re-run notebooks
vasilismsr Nov 14, 2020
a310782
fixed discrete treatment expansion. fixed test_orf.
vasilismsr Nov 14, 2020
d999732
added the option of global residulization to the continuous treatment…
vasilismsr Nov 15, 2020
092175b
linting on orf. added global_res flag in orf notebook. added examples…
vasilismsr Nov 15, 2020
3913aba
added causal forest module that is enabled via the global residualiza…
vasilismsr Nov 15, 2020
3433452
removed cross ref in cross_val_predict docstring
vasilismsr Nov 15, 2020
5d35ed0
added inference to BLBInference so that all ortho forests and causal …
vasilismsr Nov 15, 2020
59cc2b7
orthoforest notebook added inference.
vasilismsr Nov 15, 2020
b84c1e8
fixed shape api in inference of ortho forest
vasilismsr Nov 16, 2020
41b1ff4
replaced super(class, obj) with just super()
vasilismsr Nov 16, 2020
7510cad
fixes
vasilismsr Nov 16, 2020
ff1d430
fixes
vasilismsr Nov 16, 2020
4ff41d5
attribute sklearn cross_val_predict
vasilismsr Nov 16, 2020
e42ab0a
lintig
vasilismsr Nov 16, 2020
ff89699
changed ORF names to DMLOrthoForest and DROrthoForest and deprecated …
vasilismsr Nov 16, 2020
d190fa2
Merge branch 'master' into vasilis/orf_speed
vsyrgkanis Nov 17, 2020
37bb3ce
using stratified kfold in dmlorthoforest when discrete_treatment=True…
vasilismsr Nov 17, 2020
3641c08
addressing @moprescu review comments
vasilismsr Nov 17, 2020
5d4b53f
Merge branch 'master' into vasilis/orf_speed
vsyrgkanis Nov 17, 2020
2e8ad53
updated notebooks
vasilismsr Nov 17, 2020
332aa96
renmed notebook
vasilismsr Nov 17, 2020
f2689ac
made changes to put all logic in get_conforming_residuals
vasilismsr Nov 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions econml/causal_forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from .ortho_forest import DMLOrthoForest
from .utilities import LassoCVWrapper
from sklearn.linear_model import LogisticRegressionCV


class CausalForest(DMLOrthoForest):
"""CausalForest for continuous treatments. To apply to discrete
treatments, first one-hot-encode your treatments and then pass the one-hot-encoding.

Parameters
----------
n_trees : integer, optional (default=500)
Number of causal estimators in the forest.

min_leaf_size : integer, optional (default=10)
The minimum number of samples in a leaf.

max_depth : integer, optional (default=10)
The maximum number of splits to be performed when expanding the tree.

subsample_ratio : float, optional (default=0.7)
The ratio of the total sample to be used when training a causal tree.
Values greater than 1.0 will be considered equal to 1.0.

lambda_reg : float, optional (default=0.01)
The regularization coefficient in the ell_2 penalty imposed on the
locally linear part of the second stage fit. This is not applied to
the local intercept, only to the coefficient of the linear component.

model_T : estimator, optional (default=sklearn.linear_model.LassoCV(cv=3))
The estimator for residualizing the continuous treatment.
Must implement `fit` and `predict` methods.

model_Y : estimator, optional (default=sklearn.linear_model.LassoCV(cv=3)
The estimator for residualizing the outcome. Must implement
`fit` and `predict` methods.

cv : int, cross-validation generator or an iterable, optional (default=2)
The specification of the cv splitter to be used for cross-fitting, when constructing
the global residuals of Y and T.

discrete_treatment : bool, optional (default=False)
Whether the treatment should be treated as categorical. If True, then the treatment T is
one-hot-encoded and the model_T is treated as a classifier that must have a predict_proba
method.

categories : array like or 'auto', optional (default='auto')
A list of pre-specified treatment categories. If 'auto' then categories are automatically
recognized at fit time.

n_jobs : int, optional (default=-1)
The number of jobs to run in parallel for both :meth:`fit` and :meth:`effect`.
``-1`` means using all processors. Since OrthoForest methods are
computationally heavy, it is recommended to set `n_jobs` to -1.

random_state : int, :class:`~numpy.random.mtrand.RandomState` instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If :class:`~numpy.random.mtrand.RandomState` instance, random_state is the random number generator;
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.

"""

def __init__(self,
n_trees=500,
min_leaf_size=10,
max_depth=10,
subsample_ratio=0.7,
lambda_reg=0.01,
model_T='auto',
model_Y=LassoCVWrapper(cv=3),
cv=2,
discrete_treatment=False,
categories='auto',
n_jobs=-1,
random_state=None):
super().__init__(n_trees=n_trees,
min_leaf_size=min_leaf_size,
max_depth=max_depth,
subsample_ratio=subsample_ratio,
bootstrap=False,
lambda_reg=lambda_reg,
model_T=model_T,
model_Y=model_Y,
model_T_final=None,
model_Y_final=None,
global_residualization=True,
global_res_cv=cv,
discrete_treatment=discrete_treatment,
categories=categories,
n_jobs=n_jobs,
random_state=random_state)
Loading