diff --git a/CRAN-SUBMISSION b/CRAN-SUBMISSION index 3024634..343d9c9 100644 --- a/CRAN-SUBMISSION +++ b/CRAN-SUBMISSION @@ -1,3 +1,3 @@ -Version: 1.2.0 -Date: 2024-07-12 12:10:00 UTC -SHA: daf4cee64500abb8d78f92d8b1e8f1e588a59884 +Version: 1.2.1 +Date: 2024-08-17 15:36:23 UTC +SHA: dd44d33e27102fa327b72bdd4893e4b483b362bc diff --git a/R/hstats.R b/R/hstats.R index 1ee1bd8..280a310 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -304,10 +304,9 @@ hstats.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } hstats.default( @@ -323,7 +322,6 @@ hstats.ranger <- function( eps = eps, w = w, verbose = verbose, - survival = survival, ... ) } diff --git a/R/ice.R b/R/ice.R index 919a25e..4597d7a 100644 --- a/R/ice.R +++ b/R/ice.R @@ -174,12 +174,11 @@ ice.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } - + ice.default( object = object, v = v, @@ -192,7 +191,6 @@ ice.ranger <- function( strategy = strategy, na.rm = na.rm, n_max = n_max, - survival = survival, ... ) } diff --git a/R/partial_dep.R b/R/partial_dep.R index 04fa534..8d41f55 100644 --- a/R/partial_dep.R +++ b/R/partial_dep.R @@ -217,12 +217,11 @@ partial_dep.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } - + partial_dep.default( object = object, v = v, @@ -237,7 +236,6 @@ partial_dep.ranger <- function( na.rm = na.rm, n_max = n_max, w = w, - survival = survival, ... ) } diff --git a/R/utils_input.R b/R/utils_input.R index be5f15d..f44abf5 100644 --- a/R/utils_input.R +++ b/R/utils_input.R @@ -127,27 +127,34 @@ prepare_y <- function(y, X) { #' Predict Function for Ranger #' -#' Internal function that prepares the predictions of different types of ranger models. +#' Returns prediction function for different modes of ranger. #' #' @noRd #' @keywords internal -#' @param model Fitted ranger model. -#' @param newdata Data to predict on. +#' @param treetype The value of `fit$treetype` in a fitted ranger model. #' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time. -#' @param ... Additional arguments passed to ranger's predict function. #' -#' @returns A vector or matrix with predictions. -pred_ranger <- function(model, newdata, survival = c("chf", "prob"), ...) { +#' @returns A function with signature f(model, newdata, ...). +create_ranger_pred_fun <- function(treetype, survival = c("chf", "prob")) { survival <- match.arg(survival) - pred <- stats::predict(model, newdata, ...) + if (treetype != "Survival") { + pred_fun <- function(model, newdata, ...) { + stats::predict(model, newdata, ...)$predictions + } + return(pred_fun) + } + + if (survival == "prob") { + survival <- "survival" + } - if (model$treetype == "Survival") { - out <- if (survival == "chf") pred$chf else pred$survival + pred_fun <- function(model, newdata, ...) { + pred <- stats::predict(model, newdata, ...) + out <- pred[[survival]] colnames(out) <- paste0("t", pred$unique.death.times) - } else { - out <- pred$predictions + return(out) } - return(out) + return(pred_fun) } diff --git a/backlog/survival_hstats.R b/backlog/survival_hstats.R new file mode 100644 index 0000000..b58e32d --- /dev/null +++ b/backlog/survival_hstats.R @@ -0,0 +1,33 @@ +library(ranger) +library(survival) +library(hstats) +library(ggplot2) + +set.seed(1) + +fit <- ranger(Surv(time, status) ~ ., data = veteran) +fit2 <- ranger(time ~ . - status, data = veteran) +fit3 <- ranger(time ~ . - status, data = veteran, quantreg = TRUE) +fit4 <- ranger(status ~ . - time, data = veteran, probability = TRUE) + +xvars <- setdiff(colnames(veteran), c("time", "status")) + +hstats(fit, X = veteran, v = xvars[1:2], survival = "prob") +hstats(fit, X = veteran, v = xvars[1:2], survival = "chf") +hstats(fit2, X = veteran, v = xvars[1:2]) +hstats(fit3, X = veteran, v = xvars[1:2], type = "quantiles") +hstats(fit4, X = veteran, v = xvars[1:2]) + +partial_dep(fit, X = veteran, v = "celltype") +partial_dep(fit, X = veteran, v = "celltype", survival = "prob") +partial_dep(fit2, X = veteran, v = "celltype") +partial_dep(fit3, X = veteran, v = "celltype", type = "quantiles") +partial_dep(fit4, X = veteran, v = "celltype") + + +ice(fit, X = veteran, v = "celltype") +ice(fit, X = veteran, v = "celltype", survival = "prob") +ice(fit2, X = veteran, v = "celltype") +ice(fit3, X = veteran, v = "celltype", type = "quantiles") +ice(fit4, X = veteran, v = "celltype") + diff --git a/cran-comments.md b/cran-comments.md index d5dfef7..877ff0d 100644 --- a/cran-comments.md +++ b/cran-comments.md @@ -1,12 +1,6 @@ -# Re-submission: hstats 1.2.0 +# Re-submission: hstats 1.2.1 -Moving the github repo has left some old links in the NEWS file. This is fixed here. - -# Original message - -Hello CRAN - -This release mainly updates the new repository ("ModelOriented" of TU Warcaw instead of my personal one), and adds Prof Biecek as co-author. +This is a small update, fixing a wrong ORCID. ## Local checks @@ -14,6 +8,4 @@ This release mainly updates the new repository ("ModelOriented" of TU Warcaw ins ## Winbuilder -Status: 1 NOTE -R Under development (unstable) (2024-07-11 r86890 ucrt) - +Status: OK