@@ -139,6 +139,13 @@ instrumental_forest <- function(X, Y, W, Z,
139139
140140 all.tunable.params <- c(" sample.fraction" , " mtry" , " min.node.size" , " honesty.fraction" ,
141141 " honesty.prune.leaves" , " alpha" , " imbalance.penalty" )
142+ default.parameters <- list (sample.fraction = 0.5 ,
143+ mtry = min(ceiling(sqrt(ncol(X )) + 20 ), ncol(X )),
144+ min.node.size = 5 ,
145+ honesty.fraction = 0.5 ,
146+ honesty.prune.leaves = TRUE ,
147+ alpha = 0.05 ,
148+ imbalance.penalty = 0 )
142149
143150 args.orthog = list (X = X ,
144151 num.trees = min(500 , num.trees ),
@@ -207,27 +214,26 @@ instrumental_forest <- function(X, Y, W, Z,
207214
208215 tuning.output <- NULL
209216 if (! identical(tune.parameters , " none" )){
210- tuning.output <- tune_instrumental_forest(X , Y , W , Z , Y.hat , W.hat , Z.hat ,
211- sample.weights = sample.weights ,
212- clusters = clusters ,
213- equalize.cluster.weights = equalize.cluster.weights ,
214- sample.fraction = sample.fraction ,
215- mtry = mtry ,
216- min.node.size = min.node.size ,
217- honesty = honesty ,
218- honesty.fraction = honesty.fraction ,
219- honesty.prune.leaves = honesty.prune.leaves ,
220- alpha = alpha ,
221- imbalance.penalty = imbalance.penalty ,
222- stabilize.splits = stabilize.splits ,
223- ci.group.size = ci.group.size ,
224- reduced.form.weight = reduced.form.weight ,
225- tune.parameters = tune.parameters ,
226- tune.num.trees = tune.num.trees ,
227- tune.num.reps = tune.num.reps ,
228- tune.num.draws = tune.num.draws ,
229- num.threads = num.threads ,
230- seed = seed )
217+ if (identical(tune.parameters , " all" )) {
218+ tune.parameters <- all.tunable.params
219+ } else {
220+ tune.parameters <- unique(match.arg(tune.parameters , all.tunable.params , several.ok = TRUE ))
221+ }
222+ if (! honesty ) {
223+ tune.parameters <- tune.parameters [! grepl(" honesty" , tune.parameters )]
224+ }
225+ tune.parameters.defaults <- default.parameters [tune.parameters ]
226+ tuning.output <- tune_forest(data = data ,
227+ nrow.X = nrow(X ),
228+ ncol.X = ncol(X ),
229+ args = args ,
230+ tune.parameters = tune.parameters ,
231+ tune.parameters.defaults = tune.parameters.defaults ,
232+ num.fit.trees = tune.num.trees ,
233+ num.fit.reps = tune.num.reps ,
234+ num.optimize.reps = tune.num.draws ,
235+ train = instrumental_train )
236+
231237 args <- modifyList(args , as.list(tuning.output [[" params" ]]))
232238 }
233239
0 commit comments