diff --git a/README.md b/README.md index 8c3fedad..13d4a5a7 100644 --- a/README.md +++ b/README.md @@ -30,20 +30,24 @@ The package currently supports the following methods * **Tree-based algorithms** * Uplift tree/random forests on KL divergence, Euclidean Distance, and Chi-Square [[2]](#Literature) * Uplift tree/random forests on Contextual Treatment Selection [[3]](#Literature) - * Causal Tree [[4]](#Literature) - Work-in-progress + * Uplift tree/random forests on DDP [[4]](#Literature) + * Uplift tree/random forests on IDDP [[5]](#Literature) + * Interaction Tree [[6]](#Literature) + * Causal Inference Tree [[7]](#Literature) + * Causal Tree [[8]](#Literature) - Work-in-progress * **Meta-learner algorithms** - * S-learner [[5]](#Literature) - * T-learner [[5]](#Literature) - * X-learner [[5]](#Literature) - * R-learner [[6]](#Literature) - * Doubly Robust (DR) learner [[7]](#Literature) - * TMLE learner [[8]](#Literature) + * S-learner [[9]](#Literature) + * T-learner [[9]](#Literature) + * X-learner [[9]](#Literature) + * R-learner [[10]](#Literature) + * Doubly Robust (DR) learner [[11]](#Literature) + * TMLE learner [[12]](#Literature) * **Instrumental variables algorithms** * 2-Stage Least Squares (2SLS) - * Doubly Robust (DR) IV [[9]](#Literature) + * Doubly Robust (DR) IV [[13]](#Literature) * **Neural-network-based algorithms** - * CEVAE [[10]](#Literature) - * DragonNet [[11]](#Literature) - with `causalml[tf]` installation (see [Installation](#installation)) + * CEVAE [[14]](#Literature) + * DragonNet [[15]](#Literature) - with `causalml[tf]` installation (see [Installation](#installation)) # Installation @@ -272,16 +276,20 @@ Bibtex: 1. Chen, Huigang, Totte Harinen, Jeong-Yoon Lee, Mike Yung, and Zhenyu Zhao. "Causalml: Python package for causal machine learning." arXiv preprint arXiv:2002.11631 (2020). 2. Radcliffe, Nicholas J., and Patrick D. Surry. "Real-world uplift modelling with significance-based uplift trees." 
White Paper TR-2011-1, Stochastic Solutions (2011): 1-33. 3. Zhao, Yan, Xiao Fang, and David Simchi-Levi. "Uplift modeling with multiple treatments and general response types." Proceedings of the 2017 SIAM International Conference on Data Mining. Society for Industrial and Applied Mathematics, 2017. -4. Athey, Susan, and Guido Imbens. "Recursive partitioning for heterogeneous causal effects." Proceedings of the National Academy of Sciences 113.27 (2016): 7353-7360. -5. Künzel, Sören R., et al. "Metalearners for estimating heterogeneous treatment effects using machine learning." Proceedings of the national academy of sciences 116.10 (2019): 4156-4165. -6. Nie, Xinkun, and Stefan Wager. "Quasi-oracle estimation of heterogeneous treatment effects." arXiv preprint arXiv:1712.04912 (2017). -7. Bang, Heejung, and James M. Robins. "Doubly robust estimation in missing data and causal inference models." Biometrics 61.4 (2005): 962-973. -8. Van Der Laan, Mark J., and Daniel Rubin. "Targeted maximum likelihood learning." The international journal of biostatistics 2.1 (2006). -9. Kennedy, Edward H. "Optimal doubly robust estimation of heterogeneous causal effects." arXiv preprint arXiv:2004.14497 (2020). -10. Louizos, Christos, et al. "Causal effect inference with deep latent-variable models." arXiv preprint arXiv:1705.08821 (2017). -11. Shi, Claudia, David M. Blei, and Victor Veitch. "Adapting neural networks for the estimation of treatment effects." 33rd Conference on Neural Information Processing Systems (NeurIPS 2019), 2019. -12. Zhao, Zhenyu, Yumin Zhang, Totte Harinen, and Mike Yung. "Feature Selection Methods for Uplift Modeling." arXiv preprint arXiv:2005.03447 (2020). -13. Zhao, Zhenyu, and Totte Harinen. "Uplift modeling for multiple treatments with cost optimization." In 2019 IEEE International Conference on Data Science and Advanced Analytics (DSAA), pp. 422-431. IEEE, 2019. +4. Hansotia, Behram, and Brad Rukstales. "Incremental value modeling." 
Journal of Interactive Marketing 16.3 (2002): 35-46. +5. Jannik Rößler, Richard Guse, and Detlef Schoder. "The Best of Two Worlds: Using Recent Advances from Uplift Modeling and Heterogeneous Treatment Effects to Optimize Targeting Policies". International Conference on Information Systems (2022) +6. Su, Xiaogang, et al. "Subgroup analysis via recursive partitioning." Journal of Machine Learning Research 10.2 (2009). +7. Su, Xiaogang, et al. "Facilitating score and causal inference trees for large observational studies." Journal of Machine Learning Research 13 (2012): 2955. +8. Athey, Susan, and Guido Imbens. "Recursive partitioning for heterogeneous causal effects." Proceedings of the National Academy of Sciences 113.27 (2016): 7353-7360. +9. Künzel, Sören R., et al. "Metalearners for estimating heterogeneous treatment effects using machine learning." Proceedings of the national academy of sciences 116.10 (2019): 4156-4165. +10. Nie, Xinkun, and Stefan Wager. "Quasi-oracle estimation of heterogeneous treatment effects." arXiv preprint arXiv:1712.04912 (2017). +11. Bang, Heejung, and James M. Robins. "Doubly robust estimation in missing data and causal inference models." Biometrics 61.4 (2005): 962-973. +12. Van Der Laan, Mark J., and Daniel Rubin. "Targeted maximum likelihood learning." The international journal of biostatistics 2.1 (2006). +13. Kennedy, Edward H. "Optimal doubly robust estimation of heterogeneous causal effects." arXiv preprint arXiv:2004.14497 (2020). +14. Louizos, Christos, et al. "Causal effect inference with deep latent-variable models." arXiv preprint arXiv:1705.08821 (2017). +15. Shi, Claudia, David M. Blei, and Victor Veitch. "Adapting neural networks for the estimation of treatment effects." 33rd Conference on Neural Information Processing Systems (NeurIPS 2019), 2019. +16. Zhao, Zhenyu, Yumin Zhang, Totte Harinen, and Mike Yung. "Feature Selection Methods for Uplift Modeling." arXiv preprint arXiv:2005.03447 (2020). +17. 
Zhao, Zhenyu, and Totte Harinen. "Uplift modeling for multiple treatments with cost optimization." In 2019 IEEE International Conference on Data Science and Advanced Analytics (DSAA), pp. 422-431. IEEE, 2019. ## Related projects diff --git a/causalml/inference/tree/uplift.pyx b/causalml/inference/tree/uplift.pyx index d0e29f23..baf8fa9d 100644 --- a/causalml/inference/tree/uplift.pyx +++ b/causalml/inference/tree/uplift.pyx @@ -17,18 +17,20 @@ The module structure is the following: # Authors: Zhenyu Zhao # Totte Harinen +import multiprocessing as mp from collections import defaultdict + +import logging import cython -from joblib import Parallel, delayed -import multiprocessing as mp -cimport numpy as np import numpy as np -from packaging import version import pandas as pd import scipy.stats as stats import sklearn -from sklearn.utils import check_array, check_random_state, check_X_y -from typing import List +from joblib import Parallel, delayed +from packaging import version +from sklearn.model_selection import train_test_split +from sklearn.utils import check_X_y, check_array, check_random_state + if version.parse(sklearn.__version__) >= version.parse('0.22.0'): from sklearn.utils._testing import ignore_warnings else: @@ -37,6 +39,7 @@ else: MAX_INT = np.iinfo(np.int32).max +logger = logging.getLogger("causalml") cdef extern from "math.h": double log(double x) nogil @@ -199,7 +202,7 @@ class UpliftTreeClassifier: ---------- evaluationFunction : string - Choose from one of the models: 'KL', 'ED', 'Chi', 'CTS', 'DDP'. + Choose from one of the models: 'KL', 'ED', 'Chi', 'CTS', 'DDP', 'IT', 'CIT', 'IDDP'. max_features: int, optional (default=None) The number of features to consider when looking for the best split. @@ -224,18 +227,29 @@ class UpliftTreeClassifier: The normalization factor defined in Rzepakowski et al. 2012, correcting for tests with large number of splits and imbalanced treatment and control splits. 
+ honesty: bool (default=False) + True if the honest approach based on "Athey, S., & Imbens, G. (2016). Recursive partitioning for heterogeneous causal effects." + shall be used. If 'IDDP' is used as evaluation function, this parameter is automatically set to true. + + estimation_sample_size: float (default=0.5) + Sample size for estimating the CATE score in the leaves if honesty == True. + random_state: int, RandomState instance or None (default=None) A random seed or `np.random.RandomState` to control randomness in building a tree. """ def __init__(self, control_name, max_features=None, max_depth=3, min_samples_leaf=100, min_samples_treatment=10, n_reg=100, evaluationFunction='KL', - normalization=True, random_state=None): + normalization=True, honesty=False, estimation_sample_size=0.5, random_state=None): self.max_depth = max_depth self.min_samples_leaf = min_samples_leaf self.min_samples_treatment = min_samples_treatment self.n_reg = n_reg self.max_features = max_features + + assert evaluationFunction in ['KL', 'ED', 'Chi', 'CTS', 'DDP', 'IT', 'CIT', 'IDDP'], \ + f"evaluationFunction should be either 'KL', 'ED', 'Chi', 'CTS', 'DDP', 'IT', 'CIT', or 'IDDP' but {evaluationFunction} is passed" + if evaluationFunction == 'KL': self.evaluationFunction = self.evaluate_KL elif evaluationFunction == 'ED': @@ -244,7 +258,13 @@ class UpliftTreeClassifier: self.evaluationFunction = self.evaluate_Chi elif evaluationFunction == 'DDP': self.evaluationFunction = self.evaluate_DDP - else: + elif evaluationFunction == 'IT': + self.evaluationFunction = self.evaluate_IT + elif evaluationFunction == 'CIT': + self.evaluationFunction = self.evaluate_CIT + elif evaluationFunction == 'IDDP': + self.evaluationFunction = self.evaluate_IDDP + elif evaluationFunction == 'CTS': self.evaluationFunction = self.evaluate_CTS self.fitted_uplift_tree = None @@ -253,9 +273,13 @@ class UpliftTreeClassifier: self.control_name = control_name self.classes_ = [self.control_name] self.n_class = 1 - 
self.normalization = normalization + self.honesty = honesty + self.estimation_sample_size = estimation_sample_size self.random_state = random_state + if evaluationFunction == 'IDDP' and self.honesty is False: + self.honesty = True + def fit(self, X, treatment, y): """ Fit the uplift model. @@ -293,11 +317,20 @@ class UpliftTreeClassifier: self.feature_imp_dict = defaultdict(float) - if self.evaluationFunction == self.evaluate_DDP and self.n_class > 2: - raise ValueError("The DDP approach can only cope with two class problems, that is two different treatment " + if (self.n_class > 2) and (self.evaluationFunction in [self.evaluate_DDP, self.evaluate_IDDP, self.evaluate_IT, self.evaluate_CIT]): + raise ValueError("The DDP, IDDP, IT, and CIT approach can only cope with two class problems, that is two different treatment " "options (e.g., control vs treatment). Please select another approach or only use a " "dataset which employs two treatment options.") + if self.honesty: + try: + X, X_est, treatment_idx, treatment_idx_est, y, y_est = train_test_split(X, treatment_idx, y, stratify=[treatment_idx, y], test_size=self.estimation_sample_size, + shuffle=True, random_state=self.random_state) + except ValueError: + logger.warning(f"Stratified sampling failed. 
Falling back to random sampling.") + X, X_est, treatment_idx, treatment_idx_est, y, y_est = train_test_split(X, treatment_idx, y, test_size=self.estimation_sample_size, shuffle=True, + random_state=self.random_state) + self.fitted_uplift_tree = self.growDecisionTreeFrom( X, treatment_idx, y, max_depth=self.max_depth, min_samples_leaf=self.min_samples_leaf, @@ -305,6 +338,9 @@ class UpliftTreeClassifier: n_reg=self.n_reg, parentNodeSummary=None ) + if self.honesty: + self.honestApproach(X_est, treatment_idx_est, y_est) + self.feature_importances_ = np.zeros(X.shape[1]) for col, imp in self.feature_imp_dict.items(): self.feature_importances_[col] = imp @@ -351,6 +387,49 @@ class UpliftTreeClassifier: parentNodeSummary=None) return self + def honestApproach(self, X_est, T_est, Y_est): + """ Apply the honest approach based on "Athey, S., & Imbens, G. (2016). Recursive partitioning for heterogeneous causal effects." + Args + ---- + X_est : ndarray, shape = [num_samples, num_features] + An ndarray of the covariates used to calculate the unbiased estimates in the leafs of the decision tree. + T_est : array-like, shape = [num_samples] + An array containing the treatment group for each unit. + Y_est : array-like, shape = [num_samples] + An array containing the outcome of interest for each unit. + """ + + self.modifyEstimation(X_est, T_est, Y_est, self.fitted_uplift_tree) + + def modifyEstimation(self, X_est, t_est, y_est, tree): + """ Modifies the leafs of the current decision tree to only contain unbiased estimates. + Applies the honest approach based on "Athey, S., & Imbens, G. (2016). Recursive partitioning for heterogeneous causal effects." + Args + ---- + X_est : ndarray, shape = [num_samples, num_features] + An ndarray of the covariates used to calculate the unbiased estimates in the leafs of the decision tree. + T_est : array-like, shape = [num_samples] + An array containing the treatment group for each unit. 
+ Y_est : array-like, shape = [num_samples] + An array containing the outcome of interest for each unit. + tree : object + object of DecisionTree class - the current decision tree that shall be modified + """ + + # Divide sets for child nodes + if tree.trueBranch or tree.falseBranch: + X_l, X_r, w_l, w_r, y_l, y_r = self.divideSet(X_est, t_est, y_est, tree.col, tree.value) + + # recursive call for each branch + if tree.trueBranch is not None: + self.modifyEstimation(X_l, w_l, y_l, tree.trueBranch) + if tree.falseBranch is not None: + self.modifyEstimation(X_r, w_r, y_r, tree.falseBranch) + + # classProb + if tree.results is not None: + tree.results = self.uplift_classification_results(t_est, y_est) + def pruneTree(self, X, treatment_idx, y, tree, rule='maxAbsDiff', minGain=0., n_reg=0, parentNodeSummary=None): @@ -751,6 +830,157 @@ class UpliftTreeClassifier: d_res += treatment_group[0] - pc return d_res + @staticmethod + def evaluate_IT(leftNodeSummary, rightNodeSummary, w_l, w_r): + ''' + Calculate Squared T-Statistic as split evaluation criterion for a given node + + Args + ---- + leftNodeSummary : list of list + The left node summary statistics. + rightNodeSummary : list of list + The right node summary statistics. 
+ w_l: array-like, shape = [num_samples] + An array containing the treatment for each unit in the left node + w_r: array-like, shape = [num_samples] + An array containing the treatment for each unit in the right node + + Returns + ------- + g_s : Squared T-Statistic + ''' + g_s = 0 + + ## Control Group + # Sample mean in left & right child node + y_l_0 = leftNodeSummary[0][0] + y_r_0 = rightNodeSummary[0][0] + # Sample size left & right child node + n_3 = leftNodeSummary[0][1] + n_4 = rightNodeSummary[0][1] + # Sample variance in left & right child node (p*(p-1) for bernoulli) + s_3 = y_l_0*(1-y_l_0) + s_4 = y_r_0*(1-y_r_0) + + for treatment_left, treatment_right in zip(leftNodeSummary[1:], rightNodeSummary[1:]): + ## Treatment Group + # Sample mean in left & right child node + y_l_1 = treatment_left[0] + y_r_1 = treatment_right[0] + # Sample size left & right child node + n_1 = treatment_left[1] + n_2 = treatment_right[1] + # Sample variance in left & right child node + s_1 = y_l_1*(1-y_l_1) + s_2 = y_r_1*(1-y_r_1) + + sum_n = np.sum([n_1 - 1, n_2 - 1, n_3 - 1, n_4 - 1]) + w_1 = (n_1 - 1) / sum_n + w_2 = (n_2 - 1) / sum_n + w_3 = (n_3 - 1) / sum_n + w_4 = (n_4 - 1) / sum_n + + # Pooled estimator of the constant variance + sigma = np.sqrt(np.sum([w_1 * s_1, w_2 * s_2, w_3 * s_3, w_4 * s_4])) + + # Squared t-statistic + g_s = np.power(((y_l_1 - y_l_0) - (y_r_1 - y_r_0)) / (sigma * np.sqrt(np.sum([1 / n_1, 1 / n_2, 1 / n_3, 1 / n_4]))), 2) + + return g_s + + @staticmethod + def evaluate_CIT(currentNodeSummary, leftNodeSummary, rightNodeSummary, y_l, y_r, w_l, w_r, y, w): + ''' + Calculate likelihood ratio test statistic as split evaluation criterion for a given node + Args + ---- + currentNodeSummary: list of lists + The parent node summary statistics + leftNodeSummary : list of lists + The left node summary statistics. + rightNodeSummary : list of lists + The right node summary statistics. 
+ y_l: array-like, shape = [num_samples] + An array containing the outcome of interest for each unit in the left node + y_r: array-like, shape = [num_samples] + An array containing the outcome of interest for each unit in the right node + w_l: array-like, shape = [num_samples] + An array containing the treatment for each unit in the left node + w_r: array-like, shape = [num_samples] + An array containing the treatment for each unit in the right node + y: array-like, shape = [num_samples] + An array containing the outcome of interest for each unit + w: array-like, shape = [num_samples] + An array containing the treatment for each unit + Returns + ------- + lrt : Likelihood ratio test statistic + ''' + lrt = 0 + + # Control sample size left & right child node + n_l_t_0 = leftNodeSummary[0][1] + n_r_t_0 = rightNodeSummary[0][1] + + for treatment_left, treatment_right in zip(leftNodeSummary[1:], rightNodeSummary[1:]): + # Treatment sample size left & right child node + n_l_t_1 = treatment_left[1] + n_r_t_1 = treatment_right[1] + + # Total size of left & right node + n_l_t = n_l_t_1 + n_l_t_0 + n_r_t = n_r_t_1 + n_r_t_0 + + # Total size of parent node + n_t = n_l_t + n_r_t + + # Total treatment & control size in parent node + n_t_1 = n_l_t_1 + n_r_t_1 + n_t_0 = n_l_t_0 + n_r_t_0 + + # Standard squared error of left child node + sse_tau_l = np.sum(np.power(y_l[w_l == 1] - treatment_left[0], 2)) + np.sum( + np.power(y_l[w_l == 0] - treatment_left[0], 2)) + + # Standard squared error of right child node + sse_tau_r = np.sum(np.power(y_r[w_r == 1] - treatment_right[0], 2)) + np.sum( + np.power(y_r[w_r == 0] - treatment_right[0], 2)) + + # Standard squared error of parent child node + sse_tau = np.sum(np.power(y[w == 1] - currentNodeSummary[1][0], 2)) + np.sum( + np.power(y[w == 0] - currentNodeSummary[0][0], 2)) + + # Maximized log-likelihood function + i_tau_l = - (n_l_t / 2) * np.log(n_l_t * sse_tau_l) + n_l_t_1 * np.log(n_l_t_1) + n_l_t_0 * np.log(n_l_t_0) + i_tau_r = - 
(n_r_t / 2) * np.log(n_r_t * sse_tau_r) + n_r_t_1 * np.log(n_r_t_1) + n_r_t_0 * np.log(n_r_t_0) + i_tau = - (n_t / 2) * np.log(n_t * sse_tau) + n_t_1 * np.log(n_t_1) + n_t_0 * np.log(n_t_0) + + # Likelihood ration test statistic + lrt = 2 * (i_tau_l + i_tau_r - i_tau) + + return lrt + + @staticmethod + def evaluate_IDDP(nodeSummary): + ''' + Calculate Delta P as split evaluation criterion for a given node. + Args + ---- + nodeSummary : dictionary + The tree node summary statistics, produced by tree_node_summary() method. + control_name : string + The control group name. + Returns + ------- + d_res : Delta P + ''' + pc = nodeSummary[0][0] + d_res = 0 + for treatment_group in nodeSummary[1:]: + d_res += treatment_group[0] - pc + return d_res + @staticmethod def evaluate_CTS(nodeSummary): ''' @@ -767,8 +997,7 @@ class UpliftTreeClassifier: ''' return -max([stat[0] for stat in nodeSummary]) - def normI(self, n_c: cython.int, n_c_left: cython.int, n_t: list, n_t_left: list, - alpha: cython.float=0.9) -> cython.float: + def normI(self, n_c: cython.int, n_c_left: cython.int, n_t: list, n_t_left: list, alpha: cython.float = 0.9, currentDivergence: cython.float = 0.0) -> cython.float: ''' Normalization factor. @@ -796,21 +1025,21 @@ class UpliftTreeClassifier: pt_a = 1. * np.sum(n_t_left) / (np.sum(n_t) + 0.1) pc_a = 1. * n_c_left / (n_c + 0.1) - # Normalization Part 1 - norm_res += ( - alpha * entropyH(1. * np.sum(n_t) / (np.sum(n_t) + n_c), 1. * n_c / (np.sum(n_t) + n_c)) - * kl_divergence(pt_a, pc_a) - ) - # Normalization Part 2 & 3 - for i in range(len(n_t)): - pt_a_i = 1. * n_t_left[i] / (n_t[i] + 0.1) - norm_res += ( - (1 - alpha) * entropyH(1. * n_t[i] / (n_t[i] + n_c), 1. * n_c / (n_t[i] + n_c)) - * kl_divergence(1. * pt_a_i, pc_a) - ) - norm_res += (1. * n_t[i] / (np.sum(n_t) + n_c) * entropyH(pt_a_i)) + if self.evaluationFunction == self.evaluate_IDDP: + # Normalization Part 1 + norm_res += (entropyH(1. * np.sum(n_t) / (np.sum(n_t) + n_c), 1. 
* n_c / (np.sum(n_t) + n_c)) * currentDivergence) + norm_res += (1. * np.sum(n_t) / (np.sum(n_t) + n_c) * entropyH(pt_a)) + + else: + # Normalization Part 1 + norm_res += (alpha * entropyH(1. * np.sum(n_t) / (np.sum(n_t) + n_c), 1. * n_c / (np.sum(n_t) + n_c)) * kl_divergence(pt_a, pc_a)) + # Normalization Part 2 & 3 + for i in range(len(n_t)): + pt_a_i = 1. * n_t_left[i] / (n_t[i] + 0.1) + norm_res += ((1 - alpha) * entropyH(1. * n_t[i] / (n_t[i] + n_c), 1. * n_c / (n_t[i] + n_c)) * kl_divergence(1. * pt_a_i, pc_a)) + norm_res += (1. * n_t[i] / (np.sum(n_t) + n_c) * entropyH(pt_a_i)) # Normalization Part 4 - norm_res += 1. * n_c/(np.sum(n_t) + n_c) * entropyH(pc_a) + norm_res += 1. * n_c / (np.sum(n_t) + n_c) * entropyH(pc_a) # Normalization Part 5 norm_res += 0.5 @@ -931,7 +1160,11 @@ class UpliftTreeClassifier: min_samples_treatment=min_samples_treatment, n_reg=n_reg, parentNodeSummary=parentNodeSummary) - currentScore = self.evaluationFunction(currentNodeSummary) + + if self.evaluationFunction == self.evaluate_IT or self.evaluationFunction == self.evaluate_CIT: + currentScore = 0 + else: + currentScore = self.evaluationFunction(currentNodeSummary) # Prune Stats maxAbsDiff = 0 @@ -1023,6 +1256,28 @@ class UpliftTreeClassifier: rightScore2 = self.evaluationFunction(rightNodeSummary) gain = np.abs(leftScore1 - rightScore2) gain_for_imp = np.abs(len(X_l) * leftScore1 - len(X_r) * rightScore2) + elif self.evaluationFunction == self.evaluate_IT: + gain = self.evaluationFunction(leftNodeSummary, rightNodeSummary, w_l, w_r) + gain_for_imp = gain * len(X) + elif self.evaluationFunction == self.evaluate_CIT: + gain = self.evaluationFunction(currentNodeSummary, leftNodeSummary, rightNodeSummary, y_l, y_r, w_l, w_r, y, treatment_idx) + gain_for_imp = gain * len(X) + elif self.evaluationFunction == self.evaluate_IDDP: + leftScore1 = self.evaluationFunction(leftNodeSummary) + rightScore2 = self.evaluationFunction(rightNodeSummary) + gain = np.abs(leftScore1 - rightScore2) - 
np.abs(currentScore) + gain_for_imp = (len(X_l) * leftScore1 + len(X_r) * rightScore2 - len(X) * np.abs(currentScore)) + if self.normalization: + # Normalize used divergence + currentDivergence = 2 * (gain + 1) / 3 + n_c = currentNodeSummary[0][1] + n_c_left = leftNodeSummary[0][1] + n_t = [tr[1] for tr in currentNodeSummary[1:]] + n_t_left = [tr[1] for tr in leftNodeSummary[1:]] + norm_factor = self.normI(n_c, n_c_left, n_t, n_t_left, alpha=0.9, currentDivergence=currentDivergence) + else: + norm_factor = 1 + gain = gain / norm_factor else: leftScore1 = self.evaluationFunction(leftNodeSummary) rightScore2 = self.evaluationFunction(rightNodeSummary) @@ -1206,7 +1461,7 @@ class UpliftRandomForestClassifier: The number of trees in the uplift random forest. evaluationFunction : string - Choose from one of the models: 'KL', 'ED', 'Chi', 'CTS', 'DDP'. + Choose from one of the models: 'KL', 'ED', 'Chi', 'CTS', 'DDP', 'IT', 'CIT', 'IDDP'. max_features: int, optional (default=10) The number of features to consider when looking for the best split. @@ -1236,6 +1491,13 @@ class UpliftRandomForestClassifier: correcting for tests with large number of splits and imbalanced treatment and control splits + honesty: bool (default=False) + True if the honest approach based on "Athey, S., & Imbens, G. (2016). Recursive partitioning for + heterogeneous causal effects." shall be used. + + estimation_sample_size: float (default=0.5) + Sample size for estimating the CATE score in the leaves if honesty == True. + n_jobs: int, optional (default=-1) The parallelization parameter to define how many parallel jobs need to be created. This is passed on to joblib library for parallelizing uplift-tree creation and prediction. 
@@ -1260,6 +1522,8 @@ class UpliftRandomForestClassifier: n_reg=10, evaluationFunction='KL', normalization=True, + honesty=False, + estimation_sample_size=0.5, n_jobs=-1, joblib_prefer: str = "threads"): @@ -1276,6 +1540,7 @@ class UpliftRandomForestClassifier: self.evaluationFunction = evaluationFunction self.control_name = control_name self.normalization = normalization + self.honesty = honesty self.n_jobs = n_jobs self.joblib_prefer = joblib_prefer @@ -1315,6 +1580,7 @@ class UpliftRandomForestClassifier: evaluationFunction=self.evaluationFunction, control_name=self.control_name, normalization=self.normalization, + honesty=self.honesty, random_state=random_state.randint(MAX_INT)) for _ in range(self.n_estimators) ] diff --git a/docs/about.rst b/docs/about.rst index 9c65bf78..f7f86909 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -15,7 +15,10 @@ The package currently supports the following methods: - Tree-based algorithms - :ref:`Uplift Random Forests ` on KL divergence, Euclidean Distance, and Chi-Square - :ref:`Uplift Random Forests ` on Contextual Treatment Selection - - :ref:`Uplift Random Forests ` on delta-delta-p (:math:`\Delta\Delta P`) criterion (only for binary trees and two-class problems) + - :ref:`Uplift Random Forests ` on delta-delta-p (:math:`\Delta\Delta P`) criterion (only for binary trees and two-class problems) + - :ref:`Uplift Random Forests ` on IDDP (only for binary trees and two-class problems) + - :ref:`Interaction Tree ` (only for binary trees and two-class problems) + - :ref:`Causal Inference Tree ` (only for binary trees and two-class problems) - Meta-learner algorithms - :ref:`S-learner` - :ref:`T-learner` diff --git a/docs/methodology.rst b/docs/methodology.rst index e815c0ae..12fc0b7f 100755 --- a/docs/methodology.rst +++ b/docs/methodology.rst @@ -205,7 +205,70 @@ Another Uplift Tree algorithm that is implemented is the delta-delta-p (:math:`\ where :math:`a_0` and :math:`a_1` are the outcomes of a Split A, :math:`y` is the 
selected class, and :math:`P^T` and :math:`P^C` are the response rates of treatment and control group, respectively. In other words, we first calculate the difference in the response rate in each branch (:math:`\Delta P_{left}` and :math:`\Delta P_{right}`), and subsequently, calculate their differences (:math:`\Delta\Delta P = |\Delta P_{left} - \Delta P_{right}|`). +IDDP +~~~~ + +Building upon the :math:`\Delta\Delta P` approach, the IDDP approach by :cite:`rossler2022the` is implemented, where the sample splitting +criterion is defined as follows: + +.. math:: + IDDP = \frac{\Delta\Delta P^*}{I(\phi, \phi_l, \phi_r)} + +where :math:`\Delta\Delta P^*` is defined as :math:`\Delta\Delta P - |E[Y(1) - Y(0) | X \in \phi]|` and +:math:`I(\phi, \phi_l, \phi_r)` is defined as: + +.. math:: + I(\phi, \phi_l, \phi_r) = H(\frac{n_t(\phi)} {n(\phi)}, \frac{n_c(\phi)}{n(\phi)}) * 2 \frac{1+\Delta\Delta P^*}{3} + \frac{n_t(\phi)}{n(\phi)} H(\frac{n_t(\phi_l)}{n(\phi)}, \frac{n_t(\phi_r)}{n(\phi)}) \\ + + \frac{n_c(\phi)}{n(\phi)} * H(\frac{n_c(\phi_l)}{n(\phi)}, \frac{n_c(\phi_r)}{n(\phi)}) + \frac{1}{2} +where the entropy H is defined as :math:`H(p,q)=(-p*log_2(p)) + (-q*log_2(q))` and where :math:`\phi` is a subset of the feature space +associated with the current decision node, and :math:`\phi_l` and :math:`\phi_r` are the left and right child nodes, respectively. +:math:`n_t(\phi)` is the number of treatment samples, :math:`n_c(\phi)` the number of control samples, and :math:`n(\phi)` the number +of all samples in the current (parent) node. + +IT +~~ + +Further, the package implements the Interaction Tree (IT) proposed by :cite:`su2009subgroup`, where the sample splitting criterion +maximizes the G statistic among all permissible splits: + +.. math:: + G(s^*) = \max_{s} G(s) + +where :math:`G(s)=t^2(s)` and :math:`t(s)` is defined as: + +.. 
math:: + t(s) = \frac{(y^L_1 - y^L_0) - (y^R_1 - y^R_0)}{\sigma \sqrt{1/n_1 + 1/n_2 + 1/n_3 + 1/n_4}} + +where :math:`\sigma^2=\sum_{i=1}^4 w_i s_i^2` is a pooled estimator of the constant variance, and :math:`w_i=(n_i-1)/\sum_{j=1}^4(n_j-1)`. +Further, :math:`y^L_1`, :math:`s^2_1`, and :math:`n_1` are the sample mean, the sample variance, and the sample size +for the treatment group in the left child node, respectively. Similar notation applies to the other quantities. + +Note that this implementation deviates from the original implementation in that (1) the pruning techniques and (2) the validation method +for determining the best tree size are different. + +CIT +~~~ + +Also, the package implements the Causal Inference Tree (CIT) by :cite:`su2012facilitating`, where the sample splitting +criterion calculates the likelihood ratio test statistic: + +.. math:: + LRT(s) = -n_{\tau L}/2 * \ln(n_{\tau L} SSE_{\tau L}) -n_{\tau R}/2 * \ln(n_{\tau R} SSE_{\tau R}) + \\ + n_{\tau L1} \ln n_{\tau L1} + n_{\tau L0} \ln n_{\tau L0} + n_{\tau R1} \ln n_{\tau R1} + n_{\tau R0} \ln n_{\tau R0} + +where :math:`n_{\tau}`, :math:`n_{\tau 0}`, and :math:`n_{\tau 1}` are the total number of observations in node :math:`\tau`, +the number of observations in node :math:`\tau` that are assigned to the control group, and the number of observations in node :math:`\tau` +that are assigned to the treatment group, respectively. :math:`SSE_{\tau}` is defined as: + +.. math:: + SSE_{\tau} = \sum_{i \in \tau: t_i=1}(y_i - \hat{y_{t1}})^2 + \sum_{i \in \tau: t_i=0}(y_i - \hat{y_{t0}})^2 + +and :math:`\hat{y_{t0}}` and :math:`\hat{y_{t1}}` are the sample average responses of the control and treatment groups in node +:math:`\tau`, respectively. + +Note that this implementation deviates from the original implementation in that (1) the pruning techniques and (2) the validation method +for determining the best tree size are different. 
CTS ~~~ @@ -217,6 +280,7 @@ The final Uplift Tree algorithm that is implemented is the Contextual Treatment where :math:`\phi_l` and :math:`\phi_r` refer to the feature subspaces in the left leaf and the right leaves respectively, :math:`\hat{p}(\phi_j \mid \phi)` denotes the estimated conditional probability of a subject's being in :math:`\phi_j` given :math:`\phi`, and :math:`\hat{y}_t(\phi_j)` is the conditional expected response under treatment :math:`t`. + Value optimization methods -------------------------- diff --git a/docs/refs.bib b/docs/refs.bib index a07ef1f1..8b0c13d1 100755 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -359,6 +359,32 @@ @article{hansotia2002ddp year={2002}, } +@article{su2009subgroup, + title={Subgroup analysis via recursive partitioning.}, + author={Su, Xiaogang and Tsai, Chih-Ling and Wang, Hansheng and Nickerson, David M and Li, Bogong}, + journal={Journal of Machine Learning Research}, + volume={10}, + number={2}, + year={2009} +} + +@article{su2012facilitating, + title={Facilitating score and causal inference trees for large observational studies}, + author={Su, Xiaogang and Kang, Joseph and Fan, Juanjuan and Levine, Richard A and Yan, Xin}, + journal={Journal of Machine Learning Research}, + volume={13}, + pages={2955}, + year={2012} +} + +@article{rossler2022the, + title={The Best of Two Worlds: Using Recent Advances from Uplift Modeling and Heterogeneous Treatment Effects to Optimize Targeting + Policies}, + author={R{\"o}{\ss}ler, Jannik and Guse, Richard and Schoder, Detlef}, + journal={International Conference on Information Systems}, + year={2022} +} + @article{https://doi.org/10.1111/1468-0262.00442, author = {Hirano, Keisuke and Imbens, Guido W. 
and Ridder, Geert}, title = {Efficient Estimation of Average Treatment Effects Using the Estimated Propensity Score}, diff --git a/tests/conftest.py b/tests/conftest.py index b0cd932a..d9532fac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,14 @@ from causalml.dataset import synthetic_data from causalml.dataset import make_uplift_classification -from .const import RANDOM_SEED, N_SAMPLE, TREATMENT_NAMES, CONVERSION +from .const import ( + RANDOM_SEED, + N_SAMPLE, + TREATMENT_NAMES, + CONVERSION, + DELTA_UPLIFT_INCREASE_DICT, + N_UPLIFT_INCREASE_DICT, +) @pytest.fixture(scope="module") @@ -52,6 +59,8 @@ def _generate_data(): treatment_name=TREATMENT_NAMES[0:2], y_name=CONVERSION, random_seed=RANDOM_SEED, + delta_uplift_increase_dict=DELTA_UPLIFT_INCREASE_DICT, + n_uplift_increase_dict=N_UPLIFT_INCREASE_DICT, ) return data diff --git a/tests/const.py b/tests/const.py index 1ea843de..e94df526 100644 --- a/tests/const.py +++ b/tests/const.py @@ -11,3 +11,7 @@ CONTROL_NAME = "control" TREATMENT_NAMES = [CONTROL_NAME, "treatment1", "treatment2", "treatment3"] CONVERSION = "conversion" +DELTA_UPLIFT_INCREASE_DICT = { + "treatment1": 0.1, +} +N_UPLIFT_INCREASE_DICT = {"treatment1": 5} diff --git a/tests/test_uplift_trees.py b/tests/test_uplift_trees.py index 652562ee..b3bb2346 100644 --- a/tests/test_uplift_trees.py +++ b/tests/test_uplift_trees.py @@ -92,15 +92,35 @@ def test_UpliftRandomForestClassifier( assert cumgain["uplift_tree"].sum() > cumgain["Random"].sum() -def test_UpliftTreeClassifier(generate_classification_data): +@pytest.mark.parametrize("evaluation_function", ["DDP", "IT", "CIT", "IDDP"]) +def test_UpliftTreeClassifierTwoTreatments( + generate_classification_data_two_treatments, evaluation_function +): + df, x_names = generate_classification_data_two_treatments() + UpliftTreeClassifierTesting(df, x_names, evaluation_function) + + +@pytest.mark.parametrize("evaluation_function", ["KL", "Chi", "ED", "CTS"]) +def 
test_UpliftTreeClassifierMultipleTreatments( + generate_classification_data, evaluation_function +): df, x_names = generate_classification_data() + UpliftTreeClassifierTesting(df, x_names, evaluation_function) + + +def UpliftTreeClassifierTesting(df, x_names, evaluation_function): df_train, df_test = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED) # Train the UpLift Random Forest classifier uplift_model = UpliftTreeClassifier( - control_name=TREATMENT_NAMES[0], random_state=RANDOM_SEED + control_name=TREATMENT_NAMES[0], + random_state=RANDOM_SEED, + evaluationFunction=evaluation_function, ) + if evaluation_function == "IDDP": + assert uplift_model.honesty is True + pr = cProfile.Profile(subcalls=True, builtins=True, timeunit=0.001) pr.enable() uplift_model.fit( @@ -143,7 +163,7 @@ def test_UpliftTreeClassifier(generate_classification_data): ) # Check if the cumulative gain of UpLift Random Forest is higher than - # random + # random (sometimes IT and IDDP are not better than random) assert cumgain["uplift_tree"].sum() > cumgain["Random"].sum() # Check if the total count is split correctly, at least for control group in the first level