Skip to content

Commit d912da1

Browse files
authored
Add revised CSF experiments (#1083)
1 parent 00ce7c1 commit d912da1

File tree

10 files changed

+232
-90
lines changed

10 files changed

+232
-90
lines changed

experiments/csf/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ _This folder has replication files for the paper "Estimating Heterogeneous Treat
22

33
* Figure 1: `prediction_comparison.R`
44

5-
* Table 1 and Table 2, Figure 5 and Figure 6: `simulation_mse_output.R`
5+
* MSE and classification error simulations: `simulation_mse_output.R`
66

7-
* Table 3: `simulation_coverage_output.R`
7+
* 95 % CI coverage table: `simulation_coverage_output.R`
88

9-
* Figure 2: `simulation_blp.R`
9+
* Best linear projection simulation: `simulation_blp.R`
1010

11-
* Figure 3, Figure 4, and Table 4: `hiv.R`
11+
* HIV application: `hiv.R`
1212

1313
These scripts were run using R version 3.5. In addition to `grf`, they rely on the following packages: `"ggplot2", "randomForestSRC", "speff2trial", "texreg", "xtable"`.

experiments/csf/hiv.R

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
# The following script reproduces the HIV application from
2+
# the manuscript.
13
rm(list = ls())
24
library(ggplot2)
35
library(texreg)
46
library(speff2trial) # "ACTG175" data set.
57
library(grf)
8+
library(xtable)
69
set.seed(123)
710

811
data = ACTG175[ACTG175$arms == 1 | ACTG175$arms == 3, ]
@@ -13,10 +16,10 @@ Y = data$days
1316
W = as.numeric(data$arms == 1) # W = 0 : ddI, W = 1: ZDV+ddI
1417
D = data$cens
1518

16-
# Figure 3 - histogram overlaid
19+
# Overlaid histogram of survival times, with a vertical line at the truncation time Y.max (1000 days)
1720
ggplot(data.frame(Y, Censored = factor(D, labels = c("Yes", "No"))), aes(x = Y, fill = Censored)) +
1821
geom_histogram(alpha = 0.5) +
19-
geom_vline(xintercept = 1000, linetype = 2, col = "red") +
22+
geom_vline(xintercept = 1000, linetype = 1, col = "red") +
2023
xlab("Survival time (days)") +
2124
ylab("Frequency") +
2225
theme_bw() +
@@ -27,21 +30,37 @@ ggsave("HIV_histogram.pdf", width = 6, height = 5)
2730
# Truncate Y at Y.max
2831
Y.max = 1000
2932

30-
cs.forest = causal_survival_forest(X, Y, W, D, horizon = Y.max)
33+
cs.forest = causal_survival_forest(X, Y, W, D, horizon = Y.max, num.trees = 10000, ci.group.size = 12)
34+
35+
# Estimates and SEs for a random subset of individuals
36+
idx = sample(nrow(X), 10)
37+
pp = predict(cs.forest, estimate.variance = TRUE)
38+
vimp = variable_importance(cs.forest)
39+
colnames(X)[order(vimp)[1:4]]
40+
df = data.frame(
41+
CATE = pp$predictions[idx],
42+
CATE.se = sqrt(pp$variance.estimates[idx]),
43+
hemophilia = ifelse(X[idx, "hemo"] == 1, "Yes", "No"),
44+
gender = ifelse(X[idx, "gender"] == 1, "Male", "Female"),
45+
homosexual.activity = ifelse(X[idx, "homo"] == 1, "Yes", "No"),
46+
antiretroviral.history = ifelse(X[idx, "preanti"] == 1, "Experienced", "Naive")
47+
)
48+
print(xtable(df[order(df$CATE), ]
49+
), include.rownames = FALSE)
3150

3251
# BLP
3352
full = best_linear_projection(cs.forest, X)
3453
age = best_linear_projection(cs.forest, X[, "age", drop = F])
3554

3655
# Same names as in paper
37-
varnames = c("Constant", "age", "weight", "Karnofsky score",
38-
"CD4 count", "CD8 count", "gender", "homosexual activity",
39-
"race", "symptomatic status", "intravenous drug use",
40-
"hemophilia", "antiretroviral history",
56+
varnames = c("Constant", "Age", "Weight", "Karnofsky score",
57+
"CD4 count", "CD8 count", "Gender", "Homosexual activity",
58+
"Race", "Symptomatic status", "Intravenous drug use",
59+
"Hemophilia", "Antiretroviral history",
4160
"CD4 count 20+/-5 weeks", "CD8 count 20+/-5 weeks"
4261
)
4362

44-
# Table 4
63+
# BLP Table
4564
texreg(list(full, age),
4665
custom.model.names = c("All covariates", "Age only"),
4766
table = FALSE,
@@ -51,14 +70,18 @@ texreg(list(full, age),
5170
custom.coef.names = varnames
5271
)
5372

54-
# Figure 4
73+
# CATE plot
5574
X.median <- apply(X, 2, median)
5675
age.test = seq(min(X$age), max(X$age))
5776
X.test = matrix(rep(X.median, length(age.test)), length(age.test), byrow = TRUE)
5877
X.test[, 1] = age.test
59-
cs.pred = predict(cs.forest, X.test)
78+
cs.pred = predict(cs.forest, X.test, estimate.variance = TRUE)
6079
pt = cs.pred$predictions
80+
ub = pt + sqrt(cs.pred$variance.estimates) * qnorm(0.975)
81+
lb = pt - sqrt(cs.pred$variance.estimates) * qnorm(0.975)
6182
pdf("HIV_data.pdf")
62-
plot(X.test[, 1], pt, type = 'l', xlab = "Age (years)", ylab ="CATE (days)")
83+
plot(X.test[, 1], pt, type = 'l', xlab = "Age (years)", ylab = "CATE (days)", ylim = c(min(lb), max(ub)))
84+
lines(X.test[, 1], ub, lty = 2)
85+
lines(X.test[, 1], lb, lty = 2)
6386
grid()
6487
dev.off()

experiments/csf/prediction_comparison.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ estimators = list(SRC = SRC1,
1515
CSF = CSF)
1616

1717
out = list()
18-
n = 2000
19-
p = 5
18+
n = 5000
19+
p = 15
2020
n.test = 2000
2121
dgp = "type2"
2222
# dgp = "type3"
23-
n.mc = 100000
2423

25-
data = generate_causal_survival_data(n = n, p = p, dgp = dgp, n.mc = 10)
26-
data.test = generate_causal_survival_data(n = n.test, p = p, dgp = dgp, n.mc = n.mc)
24+
data = generate_causal_survival_data(n = n, p = p, dgp = dgp, n.mc = 1)
25+
data$Y = round(data$Y, 2)
26+
data.test = generate_causal_survival_data(n = n.test, p = p, dgp = dgp, n.mc = 100000)
2727
true.cate = data.test$cate
2828
for (j in 1:length(estimators)) {
2929
estimator = names(estimators)[j]
@@ -49,11 +49,11 @@ out$label = factor(out$label, levels = unique(out$label)[c(2, 1, 3)])
4949

5050
ggplot(out, aes(y = predictions, x = true.cate)) +
5151
geom_point(size = 0.1) +
52-
geom_abline(intercept = 0, slope = 1, col = "red", lty = 3) +
52+
geom_abline(intercept = 0, slope = 1, col = "red", lty = 1) +
5353
facet_wrap(. ~ label, ncol = 3) +
5454
theme_bw() +
55-
xlab("True effect") +
56-
# xlab("") +
55+
# xlab("True effect") +
56+
xlab("") +
5757
ylab("Estimated effect")
5858

5959
ggsave(paste0("prediction_comparsion_", dgp, ".pdf"), width = 6, height = 3)

experiments/csf/simulation_blp.R

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@ library(grf)
44
set.seed(123)
55

66
n = 2000
7-
p = 5
7+
p = 15
88
dgp = "type3"
9-
nreps = 200
9+
nreps = 500
1010

1111
# ground truth
12-
n.test = 50000
13-
data.test = generate_causal_survival_data(n.test, p, dgp=dgp, n.mc = 50000)
12+
data.test = generate_causal_survival_data(50000, p, dgp=dgp, n.mc = 50000)
1413
df = data.frame(cate=data.test$cate, x=data.test$X)
1514
lm1 = coeftest(lm(cate ~ x.1 + x.2, df))
1615
true = lm1[2, 1]
@@ -60,15 +59,15 @@ res = replicate(nreps, {
6059
# Histograms of the BLP and ATE estimates, with the true value marked in red:
6160
res.cov = round(rowMeans(res), 2)
6261
pdf("blp_simulation.pdf")
63-
breaks = 7
62+
breaks = 20
6463
par(mfrow = c(2, 2))
6564
hist(res["blp.cate", ], breaks = breaks, main = "BLP (CATE)", xlab = paste("coverage: ", res.cov["cov.blp.cate"]))
66-
abline(v=true, col = "red", lty = 2)
65+
abline(v=true, col = "red", lty = 1)
6766
hist(res["blp.dr", ], breaks = breaks, main = "BLP (DR)", xlab = paste("coverage: ", res.cov["cov.blp.dr"]))
68-
abline(v=true, col = "red", lty = 2)
67+
abline(v=true, col = "red", lty = 1)
6968

7069
hist(res["ate", ], breaks = breaks, main = "ATE (CATE)" , xlab = paste("coverage: ", res.cov["cov.ate"]))
71-
abline(v=true.ate, col = "red", lty = 2)
70+
abline(v=true.ate, col = "red", lty = 1)
7271
hist(res["ate.dr", ], breaks = breaks, main = "ATE (DR)", xlab = paste("coverage: ", res.cov["cov.ate.dr"]))
73-
abline(v=true.ate, col = "red", lty = 2)
72+
abline(v=true.ate, col = "red", lty = 1)
7473
dev.off()

experiments/csf/simulation_coverage.R

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@ set.seed(123)
44

55
out = list()
66
n.sim = 1000
7-
n.mc = 100000
8-
p = 5
9-
X.test = matrix(c(0.2, 0.4, 0.6, 0.8), 4, p)
107
grid = expand.grid(n = c(2000),
8+
p = 15,
9+
rho = c(0, 0.5),
1110
num.trees = c(10000),
1211
dgp = c("type1", "type2", "type3", "type4"),
1312
stringsAsFactors = FALSE)
@@ -17,16 +16,19 @@ for (i in 1:nrow(grid)) {
1716
print(paste("grid", i, "of", nrow(grid)))
1817
print(grid[i, ])
1918
n = grid$n[i]
19+
p = grid$p[i]
2020
dgp = grid$dgp[i]
21+
rho = grid$rho[i]
2122
num.trees = grid$num.trees[i]
23+
X.test = matrix(c(0.2, 0.4, 0.6, 0.8), 4, p)
2224

25+
data.test = generate_causal_survival_data(n = nrow(X.test), p = p, X = X.test, dgp = dgp, rho = rho, n.mc = 100000)
26+
cate.true = data.test$cate
27+
cate.true.prob = data.test$cate.prob
2328
for (sim in 1:n.sim) {
2429
print(paste("sim", sim))
25-
data = generate_causal_survival_data(n = n, p = p, dgp = dgp, n.mc = 10)
30+
data = generate_causal_survival_data(n = n, p = p, dgp = dgp, rho = rho, n.mc = 1)
2631
data$Y = round(data$Y, 2)
27-
data.test = generate_causal_survival_data(n = nrow(X.test), p = p, X = X.test, dgp = dgp, n.mc = n.mc)
28-
cate.true = data.test$cate
29-
cate.true.prob = data.test$cate.prob
3032
forest.W = regression_forest(data$X, data$W, num.trees = 500, ci.group.size = 1)
3133
W.hat = predict(forest.W)$predictions
3234

@@ -52,7 +54,9 @@ for (i in 1:nrow(grid)) {
5254
coverage = c(coverage, coverage.prob),
5355
width = c(width, width.prob),
5456
n = n,
57+
p = p,
5558
dgp = dgp,
59+
rho = rho,
5660
num.trees = num.trees,
5761
sim = sim,
5862
X.test = X.test[, 1]
experiments/csf/simulation_coverage_output.R

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,41 @@
11
# Run `simulation_coverage.R` to produce `coverage.csv.gz`.
2-
# Table 3 is produced below.
2+
# 95 % CI coverage table is produced below.
33

44
rm(list = ls())
55
library(xtable)
66
df = read.csv("coverage.csv.gz")
7-
apply(df[c("target", "n", "num.trees", "dgp", "X.test")], 2, unique)
7+
apply(df[c("target", "n", "p", "rho", "num.trees", "dgp", "X.test")], 2, unique)
88

9-
tab = aggregate(list(coverage = df$coverage),
9+
tab = aggregate(list(coverage = df$coverage, width = df$width),
1010
by = list(target = df$target,
1111
dgp = df$dgp,
1212
Xi = df$X.test,
13+
p = df$p,
14+
rho = df$rho,
1315
n.train = df$n,
1416
num.trees = df$num.trees),
1517
FUN = mean)
1618

17-
# Table 3
18-
options(digits = 2)
19-
xtabs(coverage ~ dgp + Xi + target, tab)
19+
# Tables of coverage and CI length
20+
# RMST
21+
print(xtable(
22+
cbind(xtabs(coverage ~ dgp + Xi, tab, subset = target == "RMST" & rho == 0),
23+
xtabs(width ~ dgp + Xi, tab, subset = target == "RMST" & rho == 0))
24+
))
25+
# SP
26+
print(xtable(
27+
cbind(xtabs(coverage ~ dgp + Xi, tab, subset = target == "survival.probability" & rho == 0),
28+
xtabs(width ~ dgp + Xi, tab, subset = target == "survival.probability" & rho == 0))
29+
))
30+
31+
# With correlated X's (rho = 0.5)
32+
# RMST
33+
print(xtable(
34+
cbind(xtabs(coverage ~ dgp + Xi, tab, subset = target == "RMST" & rho == 0.5),
35+
xtabs(width ~ dgp + Xi, tab, subset = target == "RMST" & rho == 0.5))
36+
))
37+
# SP
38+
print(xtable(
39+
cbind(xtabs(coverage ~ dgp + Xi, tab, subset = target == "survival.probability" & rho == 0.5),
40+
xtabs(width ~ dgp + Xi, tab, subset = target == "survival.probability" & rho == 0.5))
41+
))

experiments/csf/simulation_mse.R

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ estimators = list(SRC1 = SRC1,
1414
# *** Setup ***
1515
out = list()
1616
n.sim = 250
17-
n.mc = 100000
1817
grid = expand.grid(n = c(500, 1000, 2000, 5000),
19-
p = 5,
18+
p = 15,
19+
rho = c(0, 0.5),
2020
n.test = 2000,
2121
dgp = c("type1", "type2", "type3", "type4"),
2222
stringsAsFactors = FALSE)
@@ -29,15 +29,16 @@ for (i in 1:nrow(grid)) {
2929
p = grid$p[i]
3030
n.test = grid$n.test[i]
3131
dgp = grid$dgp[i]
32+
rho = grid$rho[i]
3233

34+
data.test = generate_causal_survival_data(n = n.test, p = p, dgp = dgp, rho = rho, n.mc = 100000)
35+
true.cate = data.test$cate
36+
true.cate.prob = data.test$cate.prob
37+
true.cate.sign = data.test$cate.sign
3338
for (sim in 1:n.sim) {
3439
print(paste("sim", sim))
35-
data = generate_causal_survival_data(n = n, p = p, dgp = dgp, n.mc = 10)
40+
data = generate_causal_survival_data(n = n, p = p, dgp = dgp, rho = rho, n.mc = 1)
3641
data$Y = round(data$Y, 2)
37-
data.test = generate_causal_survival_data(n = n.test, p = p, dgp = dgp, n.mc = n.mc)
38-
true.cate = data.test$cate
39-
true.cate.prob = data.test$cate.prob
40-
true.cate.sign = data.test$cate.sign
4142
estimator.output = list()
4243
for (j in 1:length(estimators)) {
4344
estimator = names(estimators)[j]
@@ -57,6 +58,7 @@ for (i in 1:nrow(grid)) {
5758
df$p = p
5859
df$n.test = n.test
5960
df$dgp = dgp
61+
df$rho = rho
6062
df$sim = sim
6163

6264
out = c(out, list(df))

0 commit comments

Comments
 (0)