diff --git a/R/hstats.R b/R/hstats.R index 1ee1bd8..280a310 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -304,10 +304,9 @@ hstats.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } hstats.default( @@ -323,7 +322,6 @@ hstats.ranger <- function( eps = eps, w = w, verbose = verbose, - survival = survival, ... ) } diff --git a/R/ice.R b/R/ice.R index 919a25e..4597d7a 100644 --- a/R/ice.R +++ b/R/ice.R @@ -174,12 +174,11 @@ ice.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } - + ice.default( object = object, v = v, @@ -192,7 +191,6 @@ ice.ranger <- function( strategy = strategy, na.rm = na.rm, n_max = n_max, - survival = survival, ... ) } diff --git a/R/partial_dep.R b/R/partial_dep.R index 04fa534..8d41f55 100644 --- a/R/partial_dep.R +++ b/R/partial_dep.R @@ -217,12 +217,11 @@ partial_dep.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } - + partial_dep.default( object = object, v = v, @@ -237,7 +236,6 @@ partial_dep.ranger <- function( na.rm = na.rm, n_max = n_max, w = w, - survival = survival, ... ) } diff --git a/R/utils_input.R b/R/utils_input.R index be5f15d..f44abf5 100644 --- a/R/utils_input.R +++ b/R/utils_input.R @@ -127,27 +127,34 @@ prepare_y <- function(y, X) { #' Predict Function for Ranger #' -#' Internal function that prepares the predictions of different types of ranger models. +#' Returns prediction function for different modes of ranger. #' #' @noRd #' @keywords internal -#' @param model Fitted ranger model. -#' @param newdata Data to predict on. +#' @param treetype The value of `fit$treetype` in a fitted ranger model. #' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time. -#' @param ... Additional arguments passed to ranger's predict function. #' -#' @returns A vector or matrix with predictions. -pred_ranger <- function(model, newdata, survival = c("chf", "prob"), ...) { +#' @returns A function with signature f(model, newdata, ...). +create_ranger_pred_fun <- function(treetype, survival = c("chf", "prob")) { survival <- match.arg(survival) - pred <- stats::predict(model, newdata, ...) + if (treetype != "Survival") { + pred_fun <- function(model, newdata, ...) { + stats::predict(model, newdata, ...)$predictions + } + return(pred_fun) + } + + if (survival == "prob") { + survival <- "survival" + } - if (model$treetype == "Survival") { - out <- if (survival == "chf") pred$chf else pred$survival + pred_fun <- function(model, newdata, ...) { + pred <- stats::predict(model, newdata, ...) + out <- pred[[survival]] colnames(out) <- paste0("t", pred$unique.death.times) - } else { - out <- pred$predictions + return(out) } - return(out) + return(pred_fun) }