Skip to content

Commit

Permalink
Fix problems with extra argument "survival"
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Aug 3, 2024
1 parent 724d4c0 commit 4e8328a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 26 deletions.
6 changes: 2 additions & 4 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,9 @@ hstats.ranger <- function(
survival = c("chf", "prob"),
...
) {
survival <- match.arg(survival)


if (is.null(pred_fun)) {
pred_fun <- pred_ranger
pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival))
}

hstats.default(
Expand All @@ -323,7 +322,6 @@ hstats.ranger <- function(
eps = eps,
w = w,
verbose = verbose,
survival = survival,
...
)
}
Expand Down
8 changes: 3 additions & 5 deletions R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,11 @@ ice.ranger <- function(
survival = c("chf", "prob"),
...
) {
survival <- match.arg(survival)


if (is.null(pred_fun)) {
pred_fun <- pred_ranger
pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival))
}

ice.default(
object = object,
v = v,
Expand All @@ -192,7 +191,6 @@ ice.ranger <- function(
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
survival = survival,
...
)
}
Expand Down
8 changes: 3 additions & 5 deletions R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,11 @@ partial_dep.ranger <- function(
survival = c("chf", "prob"),
...
) {
survival <- match.arg(survival)


if (is.null(pred_fun)) {
pred_fun <- pred_ranger
pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival))
}

partial_dep.default(
object = object,
v = v,
Expand All @@ -237,7 +236,6 @@ partial_dep.ranger <- function(
na.rm = na.rm,
n_max = n_max,
w = w,
survival = survival,
...
)
}
Expand Down
31 changes: 19 additions & 12 deletions R/utils_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,27 +127,34 @@ prepare_y <- function(y, X) {

#' Predict Function for Ranger
#'
#' Internal function that prepares the predictions of different types of ranger models.
#' Returns prediction function for different modes of ranger.
#'
#' @noRd
#' @keywords internal
#' @param model Fitted ranger model.
#' @param newdata Data to predict on.
#' @param treetype The value of `fit$treetype` in a fitted ranger model.
#' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time.
#' @param ... Additional arguments passed to ranger's predict function.
#'
#' @returns A vector or matrix with predictions.
pred_ranger <- function(model, newdata, survival = c("chf", "prob"), ...) {
#' @returns A function with signature f(model, newdata, ...).
create_ranger_pred_fun <- function(treetype, survival = c("chf", "prob")) {
survival <- match.arg(survival)

pred <- stats::predict(model, newdata, ...)
if (treetype != "Survival") {
pred_fun <- function(model, newdata, ...) {
stats::predict(model, newdata, ...)$predictions
}
return(pred_fun)
}

if (survival == "prob") {
survival <- "survival"
}

if (model$treetype == "Survival") {
out <- if (survival == "chf") pred$chf else pred$survival
pred_fun <- function(model, newdata, ...) {
pred <- stats::predict(model, newdata, ...)
out <- pred[[survival]]
colnames(out) <- paste0("t", pred$unique.death.times)
} else {
out <- pred$predictions
return(out)
}
return(out)
return(pred_fun)
}

0 comments on commit 4e8328a

Please sign in to comment.