Skip to content

Commit

Permalink
Feature/ttest criterion (#570)
Browse files Browse the repository at this point in the history
* t test criterion for CausalTreeRegressor
* example for t test criterion
* black formatting fix
  • Loading branch information
volico authored Jul 8, 2023
1 parent a9661d9 commit 5632c53
Show file tree
Hide file tree
Showing 12 changed files with 354 additions and 200 deletions.
10 changes: 6 additions & 4 deletions causalml/inference/iv/drivlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,9 @@ def fit(
mask = (treatment_treat == group) | (
treatment_treat == self.control_name
)
mask_1, mask_0 = mask & (assignment_treat == 1), mask & (
assignment_treat == 0
mask_1, mask_0 = (
mask & (assignment_treat == 1),
mask & (assignment_treat == 0),
)
cur_p_1[group], _ = compute_propensity_score(
X=X_treat[mask_1],
Expand Down Expand Up @@ -232,8 +233,9 @@ def fit(
logger.info("Generate outcome regressions")
for group in self.t_groups:
mask = (treatment_out == group) | (treatment_out == self.control_name)
mask_1, mask_0 = mask & (assignment_out == 1), mask & (
assignment_out == 0
mask_1, mask_0 = (
mask & (assignment_out == 1),
mask & (assignment_out == 0),
)
self.models_mu_c[group][ifold].fit(X_out[mask_0], y_out[mask_0])
self.models_mu_t[group][ifold].fit(X_out[mask_1], y_out[mask_1])
Expand Down
5 changes: 1 addition & 4 deletions causalml/inference/meta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,7 @@ def clean_dict_keys(orig):
return {clean_xgboost_objective(k): v for (k, v) in orig.items()}

metric_mapping = clean_dict_keys(
{
"rank:pairwise": "auc",
"reg:squarederror": "rmse",
}
{"rank:pairwise": "auc", "reg:squarederror": "rmse"}
)

objective = clean_xgboost_objective(objective)
Expand Down
2 changes: 1 addition & 1 deletion causalml/inference/tree/causal/_criterion.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ cdef struct NodeInfo:
double y_sq_sum # the squared sum of outcomes
double tr_y_sq_sum # the squared sum of outcomes among treatment obs
double ct_y_sq_sum # the squared sum of outcomes among control obs

double split_metric # Additional split metric for t-test criterion

cdef struct SplitState:
NodeInfo node # current node state
Expand Down
114 changes: 111 additions & 3 deletions causalml/inference/tree/causal/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ cdef class CausalRegressionCriterion(RegressionCriterion):
memset(&self.sum_total[0], 0, self.n_outputs * sizeof(double))
self.sq_sum_total = 0.
self.eps = 1e-5
self.state.node = [0., 0., 0., 0., 0., 0., 0., 0.]
self.state.left = [0., 0., 0., 0., 0., 0., 0., 0.]
self.state.right = [0., 0., 0., 0., 0., 0., 0., 0.]
self.state.node = [0., 0., 0., 0., 0., 0., 0., 0., 1.]
self.state.left = [0., 0., 0., 0., 0., 0., 0., 0., 1.]
self.state.right = [0., 0., 0., 0., 0., 0., 0., 0., 1.]

for p in range(start, end):
i = samples[p]
Expand Down Expand Up @@ -382,3 +382,111 @@ cdef class CausalMSE(CausalRegressionCriterion):

impurity_left[0] += self.get_groups_penalty(self.state.left.tr_count, self.state.left.ct_count)
impurity_right[0] += self.get_groups_penalty(self.state.right.tr_count, self.state.right.ct_count)


cdef class TTest(CausalRegressionCriterion):
"""
TTest impurity criterion for Causal Tree based on "Su, Xiaogang, et al. (2009). Subgroup analysis via recursive partitioning."
"""
cdef double node_impurity(self) nogil:
cdef double impurity
cdef double node_tau
cdef double tr_var
cdef double ct_var

node_tau = self.get_tau(self.state.node)
tr_var = self.get_variance(
self.state.node.tr_y_sum,
self.state.node.tr_y_sq_sum,
self.state.node.tr_count
)
ct_var = self.get_variance(
self.state.node.ct_y_sum,
self.state.node.ct_y_sq_sum,
self.state.node.ct_count)
# T statistic of difference between treatment and control means
impurity = node_tau / (((tr_var / self.state.node.tr_count) + (ct_var / self.state.node.ct_count)) ** 0.5)

return impurity

cdef double get_tau(self, NodeInfo info) nogil:
return info.tr_y_sum / info.tr_count - info.ct_y_sum / info.ct_count

cdef double get_variance(self, double y_sum, double y_sq_sum, double count) nogil:
return y_sq_sum / count - (y_sum * y_sum) / (count * count)

cdef void children_impurity(self, double * impurity_left, double * impurity_right) nogil:
"""
Evaluate the impurity in children nodes, i.e. the impurity of the
left child (samples[start:pos]) and the impurity the right child
(samples[pos:end]).
"""
cdef double right_tr_var
cdef double right_ct_var
cdef double left_tr_var
cdef double left_ct_var
cdef double right_tau
cdef double left_tau
cdef double right_t_stat
cdef double left_t_stat
cdef double t_stat

right_tau = self.get_tau(self.state.right)
right_tr_var = self.get_variance(
self.state.right.tr_y_sum,
self.state.right.tr_y_sq_sum,
self.state.right.tr_count)
right_ct_var = self.get_variance(
self.state.right.ct_y_sum,
self.state.right.ct_y_sq_sum,
self.state.right.ct_count)

left_tau = self.get_tau(self.state.left)
left_tr_var = self.get_variance(
self.state.left.tr_y_sum,
self.state.left.tr_y_sq_sum,
self.state.left.tr_count)
left_ct_var = self.get_variance(
self.state.left.ct_y_sum,
self.state.left.ct_y_sq_sum,
self.state.left.ct_count)
pooled_var = ((self.state.right.tr_count - 1) / (
self.state.node.tr_count + self.state.node.ct_count - 4)) * right_tr_var + \
(self.state.right.ct_count - 1) / (
self.state.node.tr_count + self.state.node.ct_count - 4) * right_ct_var + \
(self.state.left.tr_count - 1) / (
self.state.node.tr_count + self.state.node.ct_count - 4) * left_tr_var + \
(self.state.left.ct_count - 1) / (
self.state.node.tr_count + self.state.node.ct_count - 4) * left_ct_var

# T statistic of difference between treatment and control means in left and right nodes
left_t_stat = left_tau / (
((left_ct_var / self.state.left.ct_count) + (left_tr_var / self.state.left.tr_count)) ** 0.5)
right_t_stat = right_tau / (
((right_ct_var / self.state.right.ct_count) + (right_tr_var / self.state.right.tr_count)) ** 0.5)

# Squared T statistic of difference between tau from left and right nodes.
t_stat = ((left_tau - right_tau) / ((pooled_var ** 0.5) * (
(1 / self.state.right.tr_count) + (1 / self.state.right.ct_count) + (
1 / self.state.left.tr_count) + (1 / self.state.left.ct_count)) ** 0.5)) ** 2

self.state.left.split_metric = t_stat+self.get_groups_penalty(self.state.node.tr_count,
self.state.node.ct_count)

impurity_left[0] = left_t_stat
impurity_right[0] = right_t_stat

cdef double impurity_improvement(self, double impurity_parent,
double impurity_left,
double impurity_right) nogil:
return self.state.left.split_metric

cdef double proxy_impurity_improvement(self) nogil:
"""Compute a proxy of the impurity reduction. In case of t statistic - proxy_impurity_improvement
is the same as impurity_improvement.
"""
cdef double impurity_left
cdef double impurity_right
self.children_impurity(&impurity_left, &impurity_right)

return self.state.left.split_metric
8 changes: 6 additions & 2 deletions causalml/inference/tree/causal/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
from sklearn.utils.validation import _check_sample_weight

from ._builder import DepthFirstCausalTreeBuilder, BestFirstCausalTreeBuilder
from ._criterion import StandardMSE, CausalMSE
from ._criterion import StandardMSE, CausalMSE, TTest

CAUSAL_TREES_CRITERIA = {"causal_mse": CausalMSE, "standard_mse": StandardMSE}
CAUSAL_TREES_CRITERIA = {
"causal_mse": CausalMSE,
"standard_mse": StandardMSE,
"t_test": TTest,
}
CRITERIA_REG.update(CAUSAL_TREES_CRITERIA)


Expand Down
4 changes: 1 addition & 3 deletions causalml/inference/tree/causal/causalforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,7 @@ def __init__(
"""
super().__init__(
base_estimator=CausalTreeRegressor(
control_name=control_name,
criterion=criterion,
groups_cnt=groups_cnt,
control_name=control_name, criterion=criterion, groups_cnt=groups_cnt
),
n_estimators=n_estimators,
estimator_params=(
Expand Down
9 changes: 2 additions & 7 deletions causalml/inference/tree/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,14 +309,9 @@ def plot_dist_tree_leaves_values(
"""
tree_leaves_mask = get_tree_leaves_mask(tree)
leaves_values = tree.tree_.value.reshape(
-1,
)[tree_leaves_mask]
leaves_values = tree.tree_.value.reshape(-1)[tree_leaves_mask]
fig, ax = plt.subplots(figsize=figsize)
sns.distplot(
leaves_values,
ax=ax,
)
sns.distplot(leaves_values, ax=ax)
plt.title(title, fontsize=fontsize)
plt.show()

Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@
# (source start file, target name, title, author, documentclass
# [howto/manual]).
latex_documents = [
("index", "causalml.tex", "causalml Documentation", "Someone at Uber", "manual"),
("index", "causalml.tex", "causalml Documentation", "Someone at Uber", "manual")
]

# The name of an image file (relative to this directory) to place at
Expand Down Expand Up @@ -269,7 +269,7 @@
"causalml",
"One line description of project.",
"Miscellaneous",
),
)
]

# Documents to append as an appendix to all manuals.
Expand Down
381 changes: 221 additions & 160 deletions examples/causal_trees_with_synthetic_data.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions tests/test_causal_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def prepare_data(self, generate_regression_data) -> tuple:
class TestCausalTreeRegressor(CausalTreeBase):
def prepare_causal_tree(self) -> CausalTreeRegressor:
ctree = CausalTreeRegressor(
control_name=self.control_name,
groups_cnt=True,
random_state=RANDOM_SEED,
control_name=self.control_name, groups_cnt=True, random_state=RANDOM_SEED
)
return ctree

Expand Down
8 changes: 1 addition & 7 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,13 +936,7 @@ def test_XGBRegressor_with_sample_weights(generate_regression_data):
# Check if XGBRRegressor successfully produces treatment effect estimation
# when sample_weight is passed
uplift_model = XGBRRegressor()
uplift_model.fit(
X=X,
p=e,
treatment=treatment,
y=y,
sample_weight=weights,
)
uplift_model.fit(X=X, p=e, treatment=treatment, y=y, sample_weight=weights)
tau_pred = uplift_model.predict(X=X)
assert len(tau_pred) == len(weights)

Expand Down
5 changes: 1 addition & 4 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@

def test_qini_score():
test_df = pd.DataFrame(
{
"y": [0, 0, 0, 0, 1, 0, 0, 1, 1, 1],
"w": [0] * 5 + [1] * 5,
}
{"y": [0, 0, 0, 0, 1, 0, 0, 1, 1, 1], "w": [0] * 5 + [1] * 5}
)

good_uplift = [_ / 10 for _ in range(0, 5)]
Expand Down

0 comments on commit 5632c53

Please sign in to comment.