Skip to content

Commit

Permalink
Causal trees update (#522)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-pv authored Aug 21, 2022
1 parent 73b7cd3 commit c82d636
Show file tree
Hide file tree
Showing 17 changed files with 3,854 additions and 434 deletions.
5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ Before submitting a PR, make sure the change to pass all tests and test coverage
$ pytest -vs tests/ --cov causalml/
```

You can also run tests via make:
```bash
$ make test
```


## Submission :tada:

Expand Down
21 changes: 21 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Developer shortcuts for building, installing, testing and cleaning the package.
# NOTE: every recipe line must begin with a literal TAB character; make rejects
# space-indented or unindented recipes with a "missing separator" error.

# Force-rebuild the extension modules in place, starting from a clean tree.
.PHONY: build_ext
build_ext: clean
	python setup.py build_ext --force --inplace

# Build a wheel once the extension modules have been rebuilt.
.PHONY: build
build: build_ext
	python setup.py bdist_wheel

# Install the package from the current directory with pip.
.PHONY: install
install: build_ext
	pip install .

# Run the test suite with coverage for causalml/, then clean the setup.py
# build output. make stops at the first failing recipe line, so the cleanup
# step only runs when pytest exits successfully.
.PHONY: test
test: build_ext
	pytest -vs --cov causalml/
	python setup.py clean --all

# Remove the setup.py build output plus the build, dist and egg-info directories.
.PHONY: clean
clean:
	python setup.py clean --all
	rm -rf ./build ./dist ./causalml.egg-info
21 changes: 1 addition & 20 deletions causalml/dataset/regression.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import logging

import numpy as np
from scipy.special import expit, logit


logger = logging.getLogger("causalml")


def synthetic_data(mode=1, n=1000, p=5, sigma=1.0, adj=0.0):
""" Synthetic data in Nie X. and Wager S. (2018) 'Quasi-Oracle Estimation of Heterogeneous Treatment Effects'
Args:
mode (int, optional): mode of the simulation: \
1 for difficult nuisance components and an easy treatment effect. \
Expand All @@ -21,10 +20,8 @@ def synthetic_data(mode=1, n=1000, p=5, sigma=1.0, adj=0.0):
sigma (float): standard deviation of the error term
adj (float): adjustment term for the distribution of propensity, e. Higher values shift the distribution to 0.
It does not apply to mode == 2 or 3.
Returns:
(tuple): Synthetically generated samples with the following outputs:
- y ((n,)-array): outcome variable.
- X ((n,p)-ndarray): independent variables.
- w ((n,)-array): treatment flag with value 0 or 1.
Expand All @@ -50,16 +47,13 @@ def synthetic_data(mode=1, n=1000, p=5, sigma=1.0, adj=0.0):
def simulate_nuisance_and_easy_treatment(n=1000, p=5, sigma=1.0, adj=0.0):
"""Synthetic data with difficult nuisance components and an easy treatment effect
From Setup A in Nie X. and Wager S. (2018) 'Quasi-Oracle Estimation of Heterogeneous Treatment Effects'
Args:
n (int, optional): number of observations
p (int, optional): number of covariates (>=5)
sigma (float): standard deviation of the error term
adj (float): adjustment term for the distribution of propensity, e. Higher values shift the distribution to 0.
Returns:
(tuple): Synthetically generated samples with the following outputs:
- y ((n,)-array): outcome variable.
- X ((n,p)-ndarray): independent variables.
- w ((n,)-array): treatment flag with value 0 or 1.
Expand Down Expand Up @@ -92,17 +86,13 @@ def simulate_nuisance_and_easy_treatment(n=1000, p=5, sigma=1.0, adj=0.0):
def simulate_randomized_trial(n=1000, p=5, sigma=1.0, adj=0.0):
"""Synthetic data of a randomized trial
From Setup B in Nie X. and Wager S. (2018) 'Quasi-Oracle Estimation of Heterogeneous Treatment Effects'
Args:
n (int, optional): number of observations
p (int, optional): number of covariates (>=5)
sigma (float): standard deviation of the error term
adj (float): no effect. added for consistency
Returns:
(tuple): Synthetically generated samples with the following outputs:
- y ((n,)-array): outcome variable.
- X ((n,p)-ndarray): independent variables.
- w ((n,)-array): treatment flag with value 0 or 1.
Expand All @@ -127,16 +117,13 @@ def simulate_randomized_trial(n=1000, p=5, sigma=1.0, adj=0.0):
def simulate_easy_propensity_difficult_baseline(n=1000, p=5, sigma=1.0, adj=0.0):
"""Synthetic data with easy propensity and a difficult baseline
From Setup C in Nie X. and Wager S. (2018) 'Quasi-Oracle Estimation of Heterogeneous Treatment Effects'
Args:
n (int, optional): number of observations
p (int, optional): number of covariates (>=3)
sigma (float): standard deviation of the error term
adj (float): no effect. added for consistency
Returns:
(tuple): Synthetically generated samples with the following outputs:
- y ((n,)-array): outcome variable.
- X ((n,p)-ndarray): independent variables.
- w ((n,)-array): treatment flag with value 0 or 1.
Expand All @@ -159,16 +146,13 @@ def simulate_easy_propensity_difficult_baseline(n=1000, p=5, sigma=1.0, adj=0.0)
def simulate_unrelated_treatment_control(n=1000, p=5, sigma=1.0, adj=0.0):
"""Synthetic data with unrelated treatment and control groups.
From Setup D in Nie X. and Wager S. (2018) 'Quasi-Oracle Estimation of Heterogeneous Treatment Effects'
Args:
n (int, optional): number of observations
p (int, optional): number of covariates (>=3)
sigma (float): standard deviation of the error term
adj (float): adjustment term for the distribution of propensity, e. Higher values shift the distribution to 0.
Returns:
(tuple): Synthetically generated samples with the following outputs:
- y ((n,)-array): outcome variable.
- X ((n,p)-ndarray): independent variables.
- w ((n,)-array): treatment flag with value 0 or 1.
Expand Down Expand Up @@ -197,16 +181,13 @@ def simulate_unrelated_treatment_control(n=1000, p=5, sigma=1.0, adj=0.0):
def simulate_hidden_confounder(n=10000, p=5, sigma=1.0, adj=0.0):
"""Synthetic dataset with a hidden confounder biasing treatment.
From Louizos et al. (2018) "Causal Effect Inference with Deep Latent-Variable Models"
Args:
n (int, optional): number of observations
p (int, optional): number of covariates (>=3)
sigma (float): standard deviation of the error term
adj (float): no effect. added for consistency
Returns:
(tuple): Synthetically generated samples with the following outputs:
- y ((n,)-array): outcome variable.
- X ((n,p)-ndarray): independent variables.
- w ((n,)-array): treatment flag with value 0 or 1.
Expand Down
2 changes: 1 addition & 1 deletion causalml/dataset/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
BaseSRegressor,
BaseTRegressor,
)
from causalml.inference.tree import CausalTreeRegressor
from causalml.inference.tree.causal.causaltree import CausalTreeRegressor
from causalml.propensity import ElasticNetPropensityModel
from causalml.metrics import plot_gain, get_cumgain

Expand Down
5 changes: 3 additions & 2 deletions causalml/inference/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .causal.causaltree import CausalTreeRegressor, CausalRandomForestRegressor
from .plot import uplift_tree_string, uplift_tree_plot, plot_dist_tree_leaves_values
from .uplift import DecisionTree, UpliftTreeClassifier, UpliftRandomForestClassifier
from .causaltree import CausalMSE, CausalTreeRegressor
from .plot import uplift_tree_string, uplift_tree_plot
from .utils import (
cat_group,
cat_transform,
cv_fold_index,
cat_continuous,
kpi_transform,
get_tree_leaves_mask,
)
Empty file.
Loading

0 comments on commit c82d636

Please sign in to comment.