Skip to content

Commit

Permalink
Merge pull request #122 from ModelOriented/ranger-survival
Browse files Browse the repository at this point in the history
Direct ranger survival support
  • Loading branch information
mayer79 authored Jul 26, 2024
2 parents c990d75 + 09be450 commit ecd0536
Show file tree
Hide file tree
Showing 21 changed files with 413 additions and 159 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# hstats 1.2.1

## Usability

- `ranger()` survival models now also work out-of-the-box without passing a tailored prediction function. Use the new argument `survival` in `hstats()`, `ice()`, and `partial_dep()` to choose between cumulative hazards (`"chf"`, the default) and survival probabilities (`"prob"`) per time point.

## Other changes

- Fixed wrong ORCID of Michael.
Expand Down
6 changes: 4 additions & 2 deletions R/H2_overall.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ h2_overall.default <- function(object, ...) {

#' @describeIn h2_overall Overall interaction strength from "hstats" object.
#' @export
h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
h2_overall.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
get_hstats_matrix(
statistic = "h2_overall",
object = object,
Expand Down Expand Up @@ -113,3 +114,4 @@ h2_overall_raw <- function(x) {

list(num = num, denom = x[["mean_f2"]])
}

10 changes: 6 additions & 4 deletions R/H2_pairwise.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,14 @@ h2_pairwise.default <- function(object, ...) {

#' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object.
#' @export
h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
h2_pairwise.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
get_hstats_matrix(
statistic = "h2_pairwise",
object = object,
normalize = normalize,
squared = squared,
normalize = normalize,
squared = squared,
sort = sort,
zero = zero
)
Expand Down Expand Up @@ -122,3 +123,4 @@ h2_pairwise_raw <- function(x) {

list(num = num, denom = denom)
}

10 changes: 6 additions & 4 deletions R/H2_threeway.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,14 @@ h2_threeway.default <- function(object, ...) {

#' @describeIn h2_threeway Three-way interaction strength from "hstats" object.
#' @export
h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
h2_threeway.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
get_hstats_matrix(
statistic = "h2_threeway",
object = object,
normalize = normalize,
squared = squared,
normalize = normalize,
squared = squared,
sort = sort,
zero = zero
)
Expand Down Expand Up @@ -109,3 +110,4 @@ h2_threeway_raw <- function(x) {

list(num = num, denom = denom)
}

60 changes: 37 additions & 23 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,18 @@ average_loss <- function(object, ...) {

#' @describeIn average_loss Default method.
#' @export
average_loss.default <- function(object, X, y,
pred_fun = stats::predict,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
average_loss.default <- function(
object,
X,
y,
pred_fun = stats::predict,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL,
by_size = 4L,
w = NULL,
...
) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
Expand Down Expand Up @@ -109,13 +115,18 @@ average_loss.default <- function(object, X, y,

#' @describeIn average_loss Method for "ranger" models.
#' @export
average_loss.ranger <- function(object, X, y,
pred_fun = function(m, X, ...)
stats::predict(m, X, ...)$predictions,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
average_loss.ranger <- function(
object,
X,
y,
pred_fun = function(m, X, ...)
stats::predict(m, X, ...)$predictions,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL,
...
) {
average_loss.default(
object = object,
X = X,
Expand All @@ -132,16 +143,18 @@ average_loss.ranger <- function(object, X, y,

#' @describeIn average_loss Method for DALEX "explainer".
#' @export
average_loss.explainer <- function(object,
X = object[["data"]],
y = object[["y"]],
pred_fun = object[["predict_function"]],
loss = "squared_error",
agg_cols = FALSE,
BY = NULL,
by_size = 4L,
w = object[["weights"]],
...) {
average_loss.explainer <- function(
object,
X = object[["data"]],
y = object[["y"]],
pred_fun = object[["predict_function"]],
loss = "squared_error",
agg_cols = FALSE,
BY = NULL,
by_size = 4L,
w = object[["weights"]],
...
) {
average_loss.default(
object = object[["model"]],
X = X,
Expand All @@ -155,3 +168,4 @@ average_loss.explainer <- function(object,
...
)
}

105 changes: 77 additions & 28 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param w Optional vector of case weights. Can also be a column name of `X`.
#' @param verbose Should a progress bar be shown? The default is `TRUE`.
#' @param survival Should cumulative hazards ("chf", default) or survival
#' probabilities ("prob") per time point be predicted? Only used for `ranger()` survival models.
#' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
#' for instance `type = "response"` in a [glm()] model, or `reshape = TRUE` in a
#' multiclass XGBoost model.
Expand Down Expand Up @@ -140,12 +142,21 @@ hstats <- function(object, ...) {

#' @describeIn hstats Default hstats method.
#' @export
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict,
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
hstats.default <- function(
object,
X,
v = NULL,
pred_fun = stats::predict,
pairwise_m = 5L,
threeway_m = 0L,
approx = FALSE,
grid_size = 50L,
n_max = 500L,
eps = 1e-10,
w = NULL,
verbose = TRUE,
...
) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
Expand Down Expand Up @@ -277,12 +288,28 @@ hstats.default <- function(object, X, v = NULL,

#' @describeIn hstats Method for "ranger" models.
#' @export
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
hstats.ranger <- function(
object,
X,
v = NULL,
pred_fun = NULL,
pairwise_m = 5L,
threeway_m = 0L,
approx = FALSE,
grid_size = 50L,
n_max = 500L,
eps = 1e-10,
w = NULL,
verbose = TRUE,
survival = c("chf", "prob"),
...
) {
survival <- match.arg(survival)

if (is.null(pred_fun)) {
pred_fun <- pred_ranger
}

hstats.default(
object = object,
X = X,
Expand All @@ -296,19 +323,28 @@ hstats.ranger <- function(object, X, v = NULL,
eps = eps,
w = w,
verbose = verbose,
survival = survival,
...
)
}

#' @describeIn hstats Method for DALEX "explainer".
#' @export
hstats.explainer <- function(object, X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = object[["weights"]], verbose = TRUE, ...) {
hstats.explainer <- function(
object,
X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
pairwise_m = 5L,
threeway_m = 0L,
approx = FALSE,
grid_size = 50L,
n_max = 500L,
eps = 1e-10,
w = object[["weights"]],
verbose = TRUE,
...
) {
hstats.default(
object = object[["model"]],
X = X,
Expand Down Expand Up @@ -353,8 +389,9 @@ print.hstats <- function(x, ...) {
#' "h2", "h2_overall", "h2_pairwise", "h2_threeway", all of class "hstats_matrix".
#' @export
#' @seealso See [hstats()] for examples.
summary.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
summary.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
args <- list(
object = object,
normalize = normalize,
Expand Down Expand Up @@ -407,11 +444,21 @@ print.hstats_summary <- function(x, ...) {
#' @returns An object of class "ggplot".
#' @export
#' @seealso See [hstats()] for examples.
plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE,
sort = TRUE, top_m = 15L, zero = TRUE,
fill = getOption("hstats.fill"),
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) {
plot.hstats <- function(
x,
which = 1:3,
normalize = TRUE,
squared = TRUE,
sort = TRUE,
top_m = 15L,
zero = TRUE,
fill = getOption("hstats.fill"),
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "free",
ncol = 2L,
rotate_x = FALSE,
...
) {
if (is.null(viridis_args)) {
viridis_args <- list()
}
Expand Down Expand Up @@ -477,8 +524,9 @@ plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE,
#' @returns
#' A list with a named list of feature combinations (pairs or triples), and
#' corresponding centered partial dependencies.
mway <- function(object, v, X, pred_fun = stats::predict, w = NULL,
way = 2L, verb = TRUE, ...) {
mway <- function(
object, v, X, pred_fun = stats::predict, w = NULL, way = 2L, verb = TRUE, ...
) {
combs <- utils::combn(v, way, simplify = FALSE)
n_combs <- length(combs)
F_way <- vector("list", length = n_combs)
Expand Down Expand Up @@ -528,3 +576,4 @@ get_v <- function(H, m) {
}
v[v %in% v_cand]
}

Loading

0 comments on commit ecd0536

Please sign in to comment.