Skip to content

Commit 97ce171

Browse files
authored
Deprecate tune_***_forest (#790)
See issue #758.
1 parent 5a4f07c commit 97ce171

19 files changed

+171
-968
lines changed

experiments/local_linear_examples/wages.R

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,8 @@ mse.sample.sizes = data.frame(t(sapply(sample.sizes, function(size){
6868

6969
forest = regression_forest(as.matrix(X), Y, honesty = TRUE, tune.parameters = "all")
7070

71-
ll.lambda = tune_ll_regression_forest(forest, linear.correction.variables = continuous.covariates,
72-
ll.weight.penalty = T)$lambda.min
7371
llf.preds = predict(forest, as.matrix(X.test),
7472
linear.correction.variables = continuous.covariates,
75-
ll.lambda = ll.lambda,
7673
ll.weight.penalty = T)$predictions
7774
llf.mse = mean((llf.preds - truth)**2)
7875

r-package/grf/NAMESPACE

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ export(survival_forest)
4545
export(test_calibration)
4646
export(tune_causal_forest)
4747
export(tune_instrumental_forest)
48-
export(tune_ll_causal_forest)
49-
export(tune_ll_regression_forest)
5048
export(tune_regression_forest)
5149
export(variable_importance)
5250
importFrom(Matrix,Matrix)

r-package/grf/R/causal_forest.R

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,13 @@ causal_forest <- function(X, Y, W,
172172

173173
all.tunable.params <- c("sample.fraction", "mtry", "min.node.size", "honesty.fraction",
174174
"honesty.prune.leaves", "alpha", "imbalance.penalty")
175+
default.parameters <- list(sample.fraction = 0.5,
176+
mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
177+
min.node.size = 5,
178+
honesty.fraction = 0.5,
179+
honesty.prune.leaves = TRUE,
180+
alpha = 0.05,
181+
imbalance.penalty = 0)
175182

176183
args.orthog <- list(X = X,
177184
num.trees = max(50, num.trees / 4),
@@ -239,26 +246,26 @@ causal_forest <- function(X, Y, W,
239246

240247
tuning.output <- NULL
241248
if (!identical(tune.parameters, "none")){
242-
tuning.output <- tune_causal_forest(X, Y, W, Y.hat, W.hat,
243-
sample.weights = sample.weights,
244-
clusters = clusters,
245-
equalize.cluster.weights = equalize.cluster.weights,
246-
sample.fraction = sample.fraction,
247-
mtry = mtry,
248-
min.node.size = min.node.size,
249-
honesty = honesty,
250-
honesty.fraction = honesty.fraction,
251-
honesty.prune.leaves = honesty.prune.leaves,
252-
alpha = alpha,
253-
imbalance.penalty = imbalance.penalty,
254-
stabilize.splits = stabilize.splits,
255-
ci.group.size = ci.group.size,
256-
tune.parameters = tune.parameters,
257-
tune.num.trees = tune.num.trees,
258-
tune.num.reps = tune.num.reps,
259-
tune.num.draws = tune.num.draws,
260-
num.threads = num.threads,
261-
seed = seed)
249+
if (identical(tune.parameters, "all")) {
250+
tune.parameters <- all.tunable.params
251+
} else {
252+
tune.parameters <- unique(match.arg(tune.parameters, all.tunable.params, several.ok = TRUE))
253+
}
254+
if (!honesty) {
255+
tune.parameters <- tune.parameters[!grepl("honesty", tune.parameters)]
256+
}
257+
tune.parameters.defaults <- default.parameters[tune.parameters]
258+
tuning.output <- tune_forest(data = data,
259+
nrow.X = nrow(X),
260+
ncol.X = ncol(X),
261+
args = args,
262+
tune.parameters = tune.parameters,
263+
tune.parameters.defaults = tune.parameters.defaults,
264+
num.fit.trees = tune.num.trees,
265+
num.fit.reps = tune.num.reps,
266+
num.optimize.reps = tune.num.draws,
267+
train = causal_train)
268+
262269
args <- modifyList(args, as.list(tuning.output[["params"]]))
263270
}
264271

r-package/grf/R/deprecated.R

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,56 @@ average_late <- function(forest, ...) {
2525
average_partial_effect <- function(forest, ...) {
2626
stop("This function has been deprecated after version 1.2.0. See the function `average_treatment_effect` instead.")
2727
}
28+
29+
#' Regression forest tuning (deprecated)
30+
#'
31+
#' To tune a regression forest, see the function `regression_forest`
32+
#'
33+
#' @param X X
34+
#' @param Y Y
35+
#' @param ... Additional arguments (currently ignored).
36+
#'
37+
#' @return output
38+
#'
39+
#' @export
40+
tune_regression_forest <- function(X, Y, ...) {
41+
stop("This function has been deprecated after version 1.2.0.")
42+
}
43+
44+
#' Causal forest tuning (deprecated)
45+
#'
46+
#' To tune a causal forest, see the function `causal_forest`
47+
#'
48+
#' @param X X
49+
#' @param Y Y
50+
#' @param W W
51+
#' @param Y.hat Y.hat
52+
#' @param W.hat W.hat
53+
#' @param ... Additional arguments (currently ignored).
54+
#'
55+
#' @return output
56+
#'
57+
#' @export
58+
tune_causal_forest <- function(X, Y, W, Y.hat, W.hat, ...) {
59+
stop("This function has been deprecated after version 1.2.0.")
60+
}
61+
62+
#' Instrumental forest tuning (deprecated)
63+
#'
64+
#' To tune a instrumental forest, see the function `instrumental_forest`
65+
#'
66+
#' @param X X
67+
#' @param Y Y
68+
#' @param W W
69+
#' @param Z Z
70+
#' @param Y.hat Y.hat
71+
#' @param W.hat W.hat
72+
#' @param Z.hat Z.hat
73+
#' @param ... Additional arguments (currently ignored).
74+
#'
75+
#' @return output
76+
#'
77+
#' @export
78+
tune_instrumental_forest <- function(X, Y, W, Z, Y.hat, W.hat, Z.hat, ...) {
79+
stop("This function has been deprecated after version 1.2.0.")
80+
}

r-package/grf/R/instrumental_forest.R

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ instrumental_forest <- function(X, Y, W, Z,
139139

140140
all.tunable.params <- c("sample.fraction", "mtry", "min.node.size", "honesty.fraction",
141141
"honesty.prune.leaves", "alpha", "imbalance.penalty")
142+
default.parameters <- list(sample.fraction = 0.5,
143+
mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
144+
min.node.size = 5,
145+
honesty.fraction = 0.5,
146+
honesty.prune.leaves = TRUE,
147+
alpha = 0.05,
148+
imbalance.penalty = 0)
142149

143150
args.orthog = list(X = X,
144151
num.trees = min(500, num.trees),
@@ -207,27 +214,26 @@ instrumental_forest <- function(X, Y, W, Z,
207214

208215
tuning.output <- NULL
209216
if (!identical(tune.parameters, "none")){
210-
tuning.output <- tune_instrumental_forest(X, Y, W, Z, Y.hat, W.hat, Z.hat,
211-
sample.weights = sample.weights,
212-
clusters = clusters,
213-
equalize.cluster.weights = equalize.cluster.weights,
214-
sample.fraction = sample.fraction,
215-
mtry = mtry,
216-
min.node.size = min.node.size,
217-
honesty = honesty,
218-
honesty.fraction = honesty.fraction,
219-
honesty.prune.leaves = honesty.prune.leaves,
220-
alpha = alpha,
221-
imbalance.penalty = imbalance.penalty,
222-
stabilize.splits = stabilize.splits,
223-
ci.group.size = ci.group.size,
224-
reduced.form.weight = reduced.form.weight,
225-
tune.parameters = tune.parameters,
226-
tune.num.trees = tune.num.trees,
227-
tune.num.reps = tune.num.reps,
228-
tune.num.draws = tune.num.draws,
229-
num.threads = num.threads,
230-
seed = seed)
217+
if (identical(tune.parameters, "all")) {
218+
tune.parameters <- all.tunable.params
219+
} else {
220+
tune.parameters <- unique(match.arg(tune.parameters, all.tunable.params, several.ok = TRUE))
221+
}
222+
if (!honesty) {
223+
tune.parameters <- tune.parameters[!grepl("honesty", tune.parameters)]
224+
}
225+
tune.parameters.defaults <- default.parameters[tune.parameters]
226+
tuning.output <- tune_forest(data = data,
227+
nrow.X = nrow(X),
228+
ncol.X = ncol(X),
229+
args = args,
230+
tune.parameters = tune.parameters,
231+
tune.parameters.defaults = tune.parameters.defaults,
232+
num.fit.trees = tune.num.trees,
233+
num.fit.reps = tune.num.reps,
234+
num.optimize.reps = tune.num.draws,
235+
train = instrumental_train)
236+
231237
args <- modifyList(args, as.list(tuning.output[["params"]]))
232238
}
233239

r-package/grf/R/regression_forest.R

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ regression_forest <- function(X, Y,
117117

118118
all.tunable.params <- c("sample.fraction", "mtry", "min.node.size", "honesty.fraction",
119119
"honesty.prune.leaves", "alpha", "imbalance.penalty")
120+
default.parameters <- list(sample.fraction = 0.5,
121+
mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
122+
min.node.size = 5,
123+
honesty.fraction = 0.5,
124+
honesty.prune.leaves = TRUE,
125+
alpha = 0.05,
126+
imbalance.penalty = 0)
120127

121128
data <- create_train_matrices(X, outcome = Y, sample.weights = sample.weights)
122129
args <- list(num.trees = num.trees,
@@ -137,25 +144,26 @@ regression_forest <- function(X, Y,
137144

138145
tuning.output <- NULL
139146
if (!identical(tune.parameters, "none")){
140-
tuning.output <- tune_regression_forest(X, Y,
141-
sample.weights = sample.weights,
142-
clusters = clusters,
143-
equalize.cluster.weights = equalize.cluster.weights,
144-
sample.fraction = sample.fraction,
145-
mtry = mtry,
146-
min.node.size = min.node.size,
147-
honesty = honesty,
148-
honesty.fraction = honesty.fraction,
149-
honesty.prune.leaves = honesty.prune.leaves,
150-
alpha = alpha,
151-
imbalance.penalty = imbalance.penalty,
152-
ci.group.size = ci.group.size,
153-
tune.parameters = tune.parameters,
154-
tune.num.trees = tune.num.trees,
155-
tune.num.reps = tune.num.reps,
156-
tune.num.draws = tune.num.draws,
157-
num.threads = num.threads,
158-
seed = seed)
147+
if (identical(tune.parameters, "all")) {
148+
tune.parameters <- all.tunable.params
149+
} else {
150+
tune.parameters <- unique(match.arg(tune.parameters, all.tunable.params, several.ok = TRUE))
151+
}
152+
if (!honesty) {
153+
tune.parameters <- tune.parameters[!grepl("honesty", tune.parameters)]
154+
}
155+
tune.parameters.defaults <- default.parameters[tune.parameters]
156+
tuning.output <- tune_forest(data = data,
157+
nrow.X = nrow(X),
158+
ncol.X = ncol(X),
159+
args = args,
160+
tune.parameters = tune.parameters,
161+
tune.parameters.defaults = tune.parameters.defaults,
162+
num.fit.trees = tune.num.trees,
163+
num.fit.reps = tune.num.reps,
164+
num.optimize.reps = tune.num.draws,
165+
train = regression_train)
166+
159167
args <- modifyList(args, as.list(tuning.output[["params"]]))
160168
}
161169

0 commit comments

Comments
 (0)