From 2f17fc5620e0a508d3fb4eec52f9e9e320a115a1 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sat, 23 Sep 2023 21:50:02 +0200 Subject: [PATCH 01/11] Pad zero statistics --- NEWS.md | 12 +++---- R/H2_pairwise.R | 26 ++++++++------- R/H2_threeway.R | 35 +++++++++++--------- R/hstats.R | 87 ++++++++++++++++++++++++++++++------------------- 4 files changed, 95 insertions(+), 65 deletions(-) diff --git a/NEWS.md b/NEWS.md index a29b186c..5c40cf5a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,12 +1,8 @@ # hstats 0.3.0 -## Major user visible changes +## Major visible changes -- Grid calculation: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. This affects `ice()`, `partial_dep()` and the exported helper functions `univariate_grid()` and `multivariate_grid()`. - -## Internal major changes - -- All available H-statistics are now calculated within `hstats()` and attached to the resulting object. Each statistic is stored as list with numerator and denominator matrices/vectors. The functions `h2()`, `h2_overall()`, `h2_pairwise()`, and `h2_threeway()`, `print.hstats()`, `summary().hstats()`, `plot.hstats()` will use these without having to recalculate the required numerators and denominators. The results, however, are unchanged. +- `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. ## Minor improvements @@ -17,6 +13,10 @@ - All progress bars were initialized 1 step too late. - `perm_importance()` and `average_loss()` would fail for "mlogloss" in case the response `y` was univariate *and* non-factor/non-character. 
+## Other changes + +- All available H-statistics are now calculated within `hstats()` and attached to the resulting object. Each statistic is stored as a list with numerator and denominator matrices/vectors. The functions `h2()`, `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `print.hstats()`, `summary.hstats()`, and `plot.hstats()` will use these without having to recalculate the required numerators and denominators. The results, however, are unchanged. + # hstats 0.2.0 ## New major features diff --git a/R/H2_pairwise.R b/R/H2_pairwise.R index b31bd1bf..4c73c384 100644 --- a/R/H2_pairwise.R +++ b/R/H2_pairwise.R @@ -107,21 +107,25 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = #' #' @noRd #' @keywords internal -#' @param x A list containing the elements "combs2", "K", "pred_names", +#' @param x A list containing the elements "combs2", "v_pairwise_0", "K", "pred_names", #' "F_jk", "F_j", and "w". #' @returns A list with the numerator and denominator statistics. 
h2_pairwise_raw <- function(x) { - combs <- x[["combs2"]] - - # Note that F_jk are in the same order as combs - num <- denom <- with( - x, matrix(nrow = length(combs), ncol = K, dimnames = list(names(combs), pred_names)) + # Initialize matrices + cn0 <- combn(x[["v_pairwise_0"]], 2L, FUN = paste, collapse = ":") + num <- with( + x, matrix(0, nrow = length(cn0), ncol = K, dimnames = list(cn0, pred_names)) ) - - for (i in seq_along(combs)) { - z <- combs[[i]] - num[i, ] <- with(x, wcolMeans((F_jk[[i]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w)) - denom[i, ] <- with(x, wcolMeans(F_jk[[i]]^2, w = w)) + denom <- num + 1 + + # Note that F_jk are in the same order as x[["combs2"]] + combs <- x[["combs2"]] + if (!is.null(combs)) { + for (nm in names(combs)) { + z <- combs[[nm]] + num[nm, ] <- with(x, wcolMeans((F_jk[[nm]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w)) + denom[nm, ] <- with(x, wcolMeans(F_jk[[nm]]^2, w = w)) + } } list(num = num, denom = denom) diff --git a/R/H2_threeway.R b/R/H2_threeway.R index 5ab87492..d9939dcf 100644 --- a/R/H2_threeway.R +++ b/R/H2_threeway.R @@ -94,26 +94,31 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = #' #' @noRd #' @keywords internal -#' @param x A list containing the elements "combs3", "K", "pred_names", +#' @param x A list containing the elements "combs3", "v_threeway_0", "K", "pred_names", #' "F_jkl", "F_jk", "F_j", and "w". #' @returns A list with the numerator and denominator statistics. 
h2_threeway_raw <- function(x) { - combs <- x[["combs3"]] - - # Note that the F_jkl are in the same order as combs - num <- denom <- with( - x, matrix(nrow = length(combs), ncol = K, dimnames = list(names(combs), pred_names)) + # Initialize matrices + cn0 <- utils::combn(x[["v_threeway_0"]], 3L, FUN = paste, collapse = ":") + num <- with( + x, matrix(0, nrow = length(cn0), ncol = K, dimnames = list(cn0, pred_names)) ) + denom <- num + 1 - for (i in seq_along(combs)) { - z <- combs[[i]] - zz <- sapply(utils::combn(z, 2L, simplify = FALSE), paste, collapse = ":") - - num[i, ] <- with( - x, wcolMeans((F_jkl[[i]] - Reduce("+", F_jk[zz]) + Reduce("+", F_j[z]))^2, w = w) - ) - denom[i, ] <- with(x, wcolMeans(F_jkl[[i]]^2, w = w)) + # Note that the F_jkl are in the same order as x[["combs3"]] + combs <- x[["combs3"]] + if (!is.null(combs)) { + for (nm in names(combs)) { + z <- combs[[nm]] + zz <- utils::combn(z, 2L, paste, collapse = ":") + + num[nm, ] <- with( + x, + wcolMeans((F_jkl[[nm]] - Reduce("+", F_jk[zz]) + Reduce("+", F_j[z]))^2, w = w) + ) + denom[nm, ] <- with(x, wcolMeans(F_jkl[[nm]]^2, w = w)) + } } - + list(num = num, denom = denom) } diff --git a/R/hstats.R b/R/hstats.R index 9a49e0dc..0248bdca 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -61,6 +61,7 @@ #' - `h2_overall`: List with numerator and denominator of \eqn{H^2_j}. #' - `v_pairwise`: Subset of `v` with largest `h2_overall()` used for pairwise #' calculations. +#' - `v_pairwise_0`: Like `v_pairwise`, but padded to length `pairwise_m`. #' - `combs2`: Named list of variable pairs for which pairwise partial #' dependence functions are available. Only if pairwise calculations have been done. #' - `F_jk`: List of matrices, each representing (centered) bivariate @@ -70,6 +71,7 @@ #' Only if pairwise calculations have been done. #' - `v_threeway`: Subset of `v` with largest `h2_overall()` used for three-way #' calculations. +#' - `v_threeway_0`: Like `v_threeway`, but padded to length `threeway_m`. 
#' - `combs3`: Named list of variable triples for which three-way partial #' dependence functions are available. Only if threeway calculations have been done. #' - `F_jkl`: List of matrices, each representing (centered) three-way @@ -126,7 +128,13 @@ hstats.default <- function(object, X, v = colnames(X), w = NULL, pairwise_m = 5L, threeway_m = pairwise_m, verbose = TRUE, ...) { basic_check(X = X, v = v, pred_fun = pred_fun, w = w) - stopifnot(threeway_m <= pairwise_m) + p <- length(v) + stopifnot( + p >= 2L, + threeway_m <= pairwise_m + ) + pairwise_m <- min(pairwise_m, p) + threeway_m <- min(threeway_m, p) # Reduce size of X (and w) if (nrow(X) > n_max) { @@ -143,8 +151,7 @@ hstats.default <- function(object, X, v = colnames(X), # Initialize first progress bar p <- length(v) - show_bar <- verbose && p >= 2L - if (show_bar) { + if (verbose) { cat("1-way calculations...\n") pb <- utils::txtProgressBar(max = p, style = 3) } @@ -175,11 +182,11 @@ hstats.default <- function(object, X, v = colnames(X), w = w ) - if (show_bar) { + if (verbose) { utils::setTxtProgressBar(pb, j) } } - if (show_bar) { + if (verbose) { cat("\n") } @@ -199,25 +206,29 @@ hstats.default <- function(object, X, v = colnames(X), # 0-way and 1-way stats out[["h2"]] <- h2_raw(out) out[["h2_overall"]] <- h2_overall_raw(out) - - # 2+way stats are calculated only for features with largest overall interactions h2_ov <- .zap_small(out$h2_overall$num, eps = 1e-8) # Does eps need to be passed? - out[["v_pairwise"]] <- v2 <- get_v(h2_ov, m = pairwise_m) - if (min(pairwise_m, length(v2)) >= 2L) { - out[c("combs2", "F_jk")] <- mway( - object, v = v2, X = X, pred_fun = pred_fun, w = w, way = 2L, verb = verbose, ... - ) - out[["h2_pairwise"]] <- h2_pairwise_raw(out) - } - out[["v_threeway"]] <- v3 <- get_v(h2_ov, m = threeway_m) - if (min(threeway_m, length(v3)) >= 3L) { - out[c("combs3", "F_jkl")] <- mway( - object, v = v3, X = X, pred_fun = pred_fun, w = w, way = 3L, verb = verbose, ... 
- ) - out[["h2_threeway"]] <- h2_threeway_raw(out) + if (pairwise_m >= 2L) { + out[c("v_pairwise", "v_pairwise_0")] <- get_v(h2_ov, m = pairwise_m) + v2 <- out[["v_pairwise"]] + if (length(v2) >= 2L) { + out[c("combs2", "F_jk")] <- mway( + object, v = v2, X = X, pred_fun = pred_fun, w = w, way = 2L, verb = verbose, ... + ) + out[["h2_pairwise"]] <- h2_pairwise_raw(out) + } } - + if (threeway_m >= 3L) { + out[c("v_threeway", "v_threeway_0")] <- get_v(h2_ov, m = threeway_m) + v3 <- out[["v_threeway"]] + if (length(v3) >= 3L) { + out[c("combs3", "F_jkl")] <- mway( + object, v = v3, X = X, pred_fun = pred_fun, w = w, way = 3L, verb = verbose, ... + ) + out[["h2_threeway"]] <- h2_threeway_raw(out) + } + } + structure(out, class = "hstats") } @@ -344,12 +355,11 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE #' @export #' @seealso See [hstats()] for examples. print.summary_hstats <- function(x, ...) { - addon <- "(for features with strong overall interactions)" txt <- c( h2 = "Proportion of prediction variability unexplained by main effects of v", h2_overall = "Strongest overall interactions", - h2_pairwise = paste0("Strongest relative pairwise interactions\n", addon), - h2_threeway = paste0("Strongest relative three-way interactions\n", addon) + h2_pairwise = "Strongest pairwise interactions", + h2_threeway = "Strongest three-way interactions" ) for (nm in names(Filter(Negate(is.null), x))) { @@ -434,8 +444,7 @@ mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, F_way <- vector("list", length = n_combs) names(F_way) <- names(combs) <- sapply(combs, paste, collapse = ":") - show_bar <- verb && (n_combs >= way) - if (show_bar) { + if (verb) { cat(way, "way calculations...\n", sep = "-") pb <- utils::txtProgressBar(max = n_combs, style = 3) } @@ -446,11 +455,11 @@ mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, pd_raw(object, v = z, X = X, grid = X[, z], pred_fun = pred_fun, w = w, ...), w = w ) - if 
(show_bar) { + if (verb) { utils::setTxtProgressBar(pb, i) } } - if (show_bar) { + if (verb) { cat("\n") } list(combs, F_way) @@ -468,17 +477,29 @@ mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, #' @param H Unnormalized, unsorted H2_j values. #' @param m Number of features to pick per column. #' -#' @returns A vector of feature names. +#' @returns A list with two vectors of feature names. The first contains +#' only the m most important features (union over columns), while the second +#' is a padded version of length m. It is necessary in cases where fewer than m +#' features show interactions. get_v <- function(H, m) { - # Get largest m positive values per column + v <- rownames(H) + + # Get m strongest features per column selector <- function(vv) names(utils::head(sort(-vv[vv > 0]), m)) if (NCOL(H) == 1L) { v_cand <- selector(drop(H)) } else { v_cand <- Reduce(union, lapply(asplit(H, MARGIN = 2L), FUN = selector)) } - # Same order as in v - v <- rownames(H) - v[v %in% v_cand] + + # Do we need to add some features without interactions? + m_miss <- m - length(v_cand) + if (m_miss > 0L) { + v_cand_0 <- c(v_cand, utils::head(setdiff(v, v_cand), m_miss)) + } else { + v_cand_0 <- v_cand + } + # Bring vectors into same order as v + return(list(v[v %in% v_cand], v[v %in% v_cand_0])) } From 2313380b462b99ff62f7299431587d43471528a3 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sat, 23 Sep 2023 22:05:28 +0200 Subject: [PATCH 02/11] Better print.summary_hstats() --- R/hstats.R | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/R/hstats.R b/R/hstats.R index 0248bdca..8b0ca4ee 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -321,7 +321,8 @@ print.hstats <- function(x, ...) { #' @inheritParams h2_overall #' @param ... Currently not used. #' @returns -#' An object of class "summary_hstats" representing a named list with statistics. 
+#' An object of class "summary_hstats" representing a named list with statistics +#' and `normalize`. #' @export #' @seealso See [hstats()] for examples. summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, @@ -339,7 +340,8 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE h2 = h2(object, normalize = normalize, squared = squared, eps = eps), h2_overall = do.call(h2_overall, args), h2_pairwise = do.call(h2_pairwise, args), - h2_threeway = do.call(h2_threeway, args) + h2_threeway = do.call(h2_threeway, args), + normalize = normalize ) class(out) <- "summary_hstats" out @@ -355,14 +357,15 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE #' @export #' @seealso See [hstats()] for examples. print.summary_hstats <- function(x, ...) { + flag <- if (x[["normalize"]]) "relative" else "absolute" txt <- c( - h2 = "Proportion of prediction variability unexplained by main effects of v", - h2_overall = "Strongest overall interactions", - h2_pairwise = "Strongest pairwise interactions", - h2_threeway = "Strongest three-way interactions" + h2 = sprintf("Prediction variability unexplained by main effects of v (%s)", flag), + h2_overall = sprintf("Strongest %s overall interactions", flag), + h2_pairwise = sprintf("Strongest %s pairwise interactions", flag), + h2_threeway = sprintf("Strongest %s three-way interactions", flag) ) - for (nm in names(Filter(Negate(is.null), x))) { + for (nm in setdiff(names(Filter(Negate(is.null), x)), "normalize")) { cat(txt[[nm]]) cat("\n") print(utils::head(drop(x[[nm]]))) From ecbddb2e5070946a9be77c9b10f751cfab4476ba Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sun, 24 Sep 2023 09:13:44 +0200 Subject: [PATCH 03/11] Adding drop_zero to statistics --- NEWS.md | 9 ++++++--- R/H2.R | 7 +++---- R/H2_overall.R | 8 +++++--- R/H2_pairwise.R | 13 ++++++++++--- R/H2_threeway.R | 7 ++++--- R/hstats.R | 26 +++++++++++++++++--------- R/pd_importance.R | 6 
++++-- 7 files changed, 49 insertions(+), 27 deletions(-) diff --git a/NEWS.md b/NEWS.md index 5c40cf5a..566c529a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,11 +1,14 @@ # hstats 0.3.0 -## Major visible changes +## Visible changes -- `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. +- `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. +- `h2()` will now always show the normalized $H^2$. The options `normalize` and `squared` have been removed. +- `h2_overall()` will only show features with positive interaction. -## Minor improvements +## Improvements +- `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `plot.hstats()`, and `summary.hstats()` have received an argument `drop_zero = TRUE`. Set to `FALSE` to also show statistics having value 0. - `perm_importance()` and `average_loss()` will now recycle a univariate response when combined with multivariate predictions. This is useful, e.g., when the prediction function represents the predictions of multiple models that should be evaluated against a common response. ## Bug fixes diff --git a/R/H2.R b/R/H2.R index c4bc96cd..a0e53495 100644 --- a/R/H2.R +++ b/R/H2.R @@ -65,13 +65,12 @@ h2.default <- function(object, ...) { #' @describeIn h2 Total interaction strength from "interact" object. #' @export -h2.hstats <- function(object, normalize = TRUE, squared = TRUE, eps = 1e-8, ...) { +h2.hstats <- function(object, eps = 1e-8, ...) 
{ postprocess( num = object$h2$num, denom = object$h2$denom, - normalize = normalize, - squared = squared, - sort = FALSE, + sort = FALSE, + drop_zero = FALSE, eps = eps ) } diff --git a/R/H2_overall.R b/R/H2_overall.R index a4469ef3..458eca7c 100644 --- a/R/H2_overall.R +++ b/R/H2_overall.R @@ -44,6 +44,7 @@ #' @param sort Should results be sorted? Default is `TRUE`. #' (Multioutput is sorted by row means.) #' @param top_m How many rows should be shown? (`Inf` to show all.) +#' @param drop_zero Should rows with all 0 be dropped? Default is `TRUE`. #' @param eps Threshold below which numerator values are set to 0. #' @param plot Should results be plotted as barplot? Default is `FALSE`. #' @param fill Color of bar (only for univariate statistics). @@ -78,8 +79,8 @@ h2_overall.default <- function(object, ...) { #' @describeIn h2_overall Overall interaction strength from "hstats" object. #' @export h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, plot = FALSE, fill = "#2b51a1", - ...) { + top_m = 15L, drop_zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_overall out <- postprocess( num = s$num, @@ -87,7 +88,8 @@ h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = T normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + drop_zero = drop_zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) 
else out diff --git a/R/H2_pairwise.R b/R/H2_pairwise.R index 4c73c384..59d5b51c 100644 --- a/R/H2_pairwise.R +++ b/R/H2_pairwise.R @@ -62,6 +62,9 @@ #' # (for features with strongest overall interactions) #' h2_pairwise(s) #' +#' # Do not drop zeros +#' h2_pairwise(s, drop_zero = FALSE) +#' #' # Absolute measure as alternative #' h2_pairwise(s, normalize = FALSE, squared = FALSE) #' @@ -69,6 +72,9 @@ #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) #' s <- hstats(fit, X = iris[3:5], verbose = FALSE) #' h2_pairwise(s, plot = TRUE) +#' +#' # Do not drop zeros +#' h2_pairwise(s, drop_zero = FALSE, plot = TRUE) h2_pairwise <- function(object, ...) { UseMethod("h2_pairwise") } @@ -82,8 +88,8 @@ h2_pairwise.default <- function(object, ...) { #' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object. #' @export h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, plot = FALSE, - fill = "#2b51a1", ...) { + top_m = 15L, drop_zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_pairwise if (is.null(s)) { return(NULL) @@ -94,7 +100,8 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + drop_zero = drop_zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out diff --git a/R/H2_threeway.R b/R/H2_threeway.R index d9939dcf..97580781 100644 --- a/R/H2_threeway.R +++ b/R/H2_threeway.R @@ -69,8 +69,8 @@ h2_threeway.default <- function(object, ...) { #' @describeIn h2_threeway Pairwise interaction strength from "hstats" object. #' @export h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, plot = FALSE, - fill = "#2b51a1", ...) { + top_m = 15L, drop_zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) 
{ s <- object$h2_threeway if (is.null(s)) { return(NULL) @@ -81,7 +81,8 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + drop_zero = drop_zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out diff --git a/R/hstats.R b/R/hstats.R index 8b0ca4ee..1dd8f167 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -91,6 +91,7 @@ #' s <- hstats(fit, X = iris[-1]) #' s #' plot(s) +#' plot(s, drop_zero = FALSE) #' summary(s) #' #' # Absolute pairwise interaction strengths @@ -150,7 +151,6 @@ hstats.default <- function(object, X, v = colnames(X), mean_f2 <- wcolMeans(f^2, w = w) # A vector # Initialize first progress bar - p <- length(v) if (verbose) { cat("1-way calculations...\n") pb <- utils::txtProgressBar(max = p, style = 3) @@ -215,8 +215,8 @@ hstats.default <- function(object, X, v = colnames(X), out[c("combs2", "F_jk")] <- mway( object, v = v2, X = X, pred_fun = pred_fun, w = w, way = 2L, verb = verbose, ... ) - out[["h2_pairwise"]] <- h2_pairwise_raw(out) } + out[["h2_pairwise"]] <- h2_pairwise_raw(out) } if (threeway_m >= 3L) { out[c("v_threeway", "v_threeway_0")] <- get_v(h2_ov, m = threeway_m) @@ -225,8 +225,8 @@ hstats.default <- function(object, X, v = colnames(X), out[c("combs3", "F_jkl")] <- mway( object, v = v3, X = X, pred_fun = pred_fun, w = w, way = 3L, verb = verbose, ... ) - out[["h2_threeway"]] <- h2_threeway_raw(out) } + out[["h2_threeway"]] <- h2_threeway_raw(out) } structure(out, class = "hstats") @@ -316,7 +316,8 @@ print.hstats <- function(x, ...) { #' Summary Method #' -#' Summary method for "hstats" object. +#' Summary method for "hstats" object. Note that \eqn{H^2} is not affected by +#' the arguments `normalize` and `squared`. #' #' @inheritParams h2_overall #' @param ... Currently not used. @@ -326,18 +327,19 @@ print.hstats <- function(x, ...) { #' @export #' @seealso See [hstats()] for examples. 
summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = Inf, eps = 1e-8, ...) { + top_m = Inf, drop_zero = TRUE, eps = 1e-8, ...) { args <- list( object = object, normalize = normalize, squared = squared, sort = sort, top_m = top_m, + drop_zero = drop_zero, eps = eps, plot = FALSE ) out <- list( - h2 = h2(object, normalize = normalize, squared = squared, eps = eps), + h2 = h2(object, eps = eps), h2_overall = do.call(h2_overall, args), h2_pairwise = do.call(h2_pairwise, args), h2_threeway = do.call(h2_threeway, args), @@ -359,7 +361,7 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE print.summary_hstats <- function(x, ...) { flag <- if (x[["normalize"]]) "relative" else "absolute" txt <- c( - h2 = sprintf("Prediction variability unexplained by main effects of v (%s)", flag), + h2 = "Proportion of prediction variability unexplained by main effects of v", h2_overall = sprintf("Strongest %s overall interactions", flag), h2_pairwise = sprintf("Strongest %s pairwise interactions", flag), h2_threeway = sprintf("Strongest %s three-way interactions", flag) @@ -391,10 +393,16 @@ print.summary_hstats <- function(x, ...) { #' @export #' @seealso See [hstats()] for examples. plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, fill = "#2b51a1", + top_m = 15L, drop_zero = TRUE, eps = 1e-8, fill = "#2b51a1", facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) 
{ su <- summary( - x, normalize = normalize, squared = squared, sort = sort, top_m = top_m, eps = eps + x, + normalize = normalize, + squared = squared, + sort = sort, + top_m = top_m, + drop_zero = drop_zero, + eps = eps ) nms <- c("h2_overall", "h2_pairwise", "h2_threeway") ids <- c("Overall", "Pairwise", "Threeway") diff --git a/R/pd_importance.R b/R/pd_importance.R index dbb74f90..df013cef 100644 --- a/R/pd_importance.R +++ b/R/pd_importance.R @@ -57,7 +57,8 @@ pd_importance.default <- function(object, ...) { #' @describeIn pd_importance PD based feature importance from "hstats" object. #' @export pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, top_m = 15L, eps = 1e-8, plot = FALSE, + sort = TRUE, top_m = 15L, drop_zero = FALSE, + eps = 1e-8, plot = FALSE, fill = "#2b51a1", ...) { num <- with( object, matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names)) @@ -71,7 +72,8 @@ pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + drop_zero = drop_zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out From 78dc7c920375074f036fd3333a391e0096b666a4 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sun, 24 Sep 2023 10:50:19 +0200 Subject: [PATCH 04/11] switch default --- NEWS.md | 6 +++--- R/H2.R | 1 - R/H2_overall.R | 8 ++++---- R/H2_pairwise.R | 12 +++++------- R/H2_threeway.R | 4 ++-- R/hstats.R | 16 +++++++--------- R/pd_importance.R | 4 ++-- 7 files changed, 23 insertions(+), 28 deletions(-) diff --git a/NEWS.md b/NEWS.md index 566c529a..5e8dddbf 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,12 +3,12 @@ ## Visible changes - `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. 
To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. -- `h2()` will now always show the normalized $H^2$. The options `normalize` and `squared` have been removed. -- `h2_overall()` will only show features with positive interaction. +- `h2()` will now always show the normalized $H^2$. Its options `normalize` and `squared` have been removed. +- `h2_pairwise()` and `h2_threeway()` will now also include (some) combinations with value 0. Use `zero = FALSE` to drop them, see below. ## Improvements -- `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `plot.hstats()`, and `summary.hstats()` have received an argument `drop_zero = TRUE`. Set to `FALSE` to also show statistics having value 0. +- `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `plot.hstats()`, and `summary.hstats()` have received an argument `zero = TRUE`. Set to `FALSE` to drop statistics having value 0. - `perm_importance()` and `average_loss()` will now recycle a univariate response when combined with multivariate predictions. This is useful, e.g., when the prediction function represents the predictions of multiple models that should be evaluated against a common response. ## Bug fixes diff --git a/R/H2.R b/R/H2.R index a0e53495..ff9e0344 100644 --- a/R/H2.R +++ b/R/H2.R @@ -70,7 +70,6 @@ h2.hstats <- function(object, eps = 1e-8, ...) { num = object$h2$num, denom = object$h2$denom, sort = FALSE, - drop_zero = FALSE, eps = eps ) } diff --git a/R/H2_overall.R b/R/H2_overall.R index 458eca7c..8ff25c12 100644 --- a/R/H2_overall.R +++ b/R/H2_overall.R @@ -44,7 +44,7 @@ #' @param sort Should results be sorted? Default is `TRUE`. #' (Multioutput is sorted by row means.) #' @param top_m How many rows should be shown? (`Inf` to show all.) -#' @param drop_zero Should rows with all 0 be dropped? Default is `TRUE`. +#' @param zero Should rows with all 0 be shown? Default is `TRUE`. 
#' @param eps Threshold below which numerator values are set to 0. #' @param plot Should results be plotted as barplot? Default is `FALSE`. #' @param fill Color of bar (only for univariate statistics). @@ -65,7 +65,7 @@ #' # MODEL 2: Multi-response linear regression #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) #' s <- hstats(fit, X = iris[3:5], verbose = FALSE) -#' h2_overall(s, plot = TRUE) +#' h2_overall(s, plot = TRUE, zero = FALSE) h2_overall <- function(object, ...) { UseMethod("h2_overall") } @@ -79,7 +79,7 @@ h2_overall.default <- function(object, ...) { #' @describeIn h2_overall Overall interaction strength from "hstats" object. #' @export h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, drop_zero = TRUE, eps = 1e-8, + top_m = 15L, zero = TRUE, eps = 1e-8, plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_overall out <- postprocess( @@ -89,7 +89,7 @@ h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = T squared = squared, sort = sort, top_m = top_m, - drop_zero = drop_zero, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out diff --git a/R/H2_pairwise.R b/R/H2_pairwise.R index 59d5b51c..bcbe81e6 100644 --- a/R/H2_pairwise.R +++ b/R/H2_pairwise.R @@ -62,8 +62,8 @@ #' # (for features with strongest overall interactions) #' h2_pairwise(s) #' -#' # Do not drop zeros -#' h2_pairwise(s, drop_zero = FALSE) +#' # Drop 0 +#' h2_pairwise(s, zero = FALSE) #' #' # Absolute measure as alternative #' h2_pairwise(s, normalize = FALSE, squared = FALSE) @@ -72,9 +72,7 @@ #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) #' s <- hstats(fit, X = iris[3:5], verbose = FALSE) #' h2_pairwise(s, plot = TRUE) -#' -#' # Do not drop zeros -#' h2_pairwise(s, drop_zero = FALSE, plot = TRUE) +#' h2_pairwise(s, zero = FALSE, plot = TRUE) h2_pairwise <- function(object, ...) 
{ UseMethod("h2_pairwise") } @@ -88,7 +86,7 @@ h2_pairwise.default <- function(object, ...) { #' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object. #' @export h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, drop_zero = TRUE, eps = 1e-8, + top_m = 15L, zero = TRUE, eps = 1e-8, plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_pairwise if (is.null(s)) { @@ -101,7 +99,7 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = squared = squared, sort = sort, top_m = top_m, - drop_zero = drop_zero, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out diff --git a/R/H2_threeway.R b/R/H2_threeway.R index 97580781..f5ec7f99 100644 --- a/R/H2_threeway.R +++ b/R/H2_threeway.R @@ -69,7 +69,7 @@ h2_threeway.default <- function(object, ...) { #' @describeIn h2_threeway Pairwise interaction strength from "hstats" object. #' @export h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, drop_zero = TRUE, eps = 1e-8, + top_m = 15L, zero = TRUE, eps = 1e-8, plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_threeway if (is.null(s)) { @@ -82,7 +82,7 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = squared = squared, sort = sort, top_m = top_m, - drop_zero = drop_zero, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out diff --git a/R/hstats.R b/R/hstats.R index 1dd8f167..7d32e5a5 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -91,7 +91,7 @@ #' s <- hstats(fit, X = iris[-1]) #' s #' plot(s) -#' plot(s, drop_zero = FALSE) +#' plot(s, zero = FALSE) #' summary(s) #' #' # Absolute pairwise interaction strengths @@ -112,9 +112,7 @@ #' #' # On original scale, we have interactions everywhere... 
#' s <- hstats(fit, X = iris[-1], type = "response", verbose = FALSE) -#' -#' # All three types use different denominators -#' plot(s, which = 1:3, ncol = 1) +#' plot(s, which = 1:3, ncol = 1) # All three types use different denominators #' #' # All statistics on same scale (of predictions) #' plot(s, which = 1:3, squared = FALSE, normalize = FALSE, facet_scale = "free_y") @@ -327,14 +325,14 @@ print.hstats <- function(x, ...) { #' @export #' @seealso See [hstats()] for examples. summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = Inf, drop_zero = TRUE, eps = 1e-8, ...) { + top_m = Inf, zero = TRUE, eps = 1e-8, ...) { args <- list( object = object, normalize = normalize, squared = squared, sort = sort, top_m = top_m, - drop_zero = drop_zero, + zero = zero, eps = eps, plot = FALSE ) @@ -370,7 +368,7 @@ print.summary_hstats <- function(x, ...) { for (nm in setdiff(names(Filter(Negate(is.null), x)), "normalize")) { cat(txt[[nm]]) cat("\n") - print(utils::head(drop(x[[nm]]))) + print(utils::head(x[[nm]])) cat("\n") } invisible(x) @@ -393,7 +391,7 @@ print.summary_hstats <- function(x, ...) { #' @export #' @seealso See [hstats()] for examples. plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, drop_zero = TRUE, eps = 1e-8, fill = "#2b51a1", + top_m = 15L, zero = TRUE, eps = 1e-8, fill = "#2b51a1", facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) { su <- summary( x, @@ -401,7 +399,7 @@ plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = squared = squared, sort = sort, top_m = top_m, - drop_zero = drop_zero, + zero = zero, eps = eps ) nms <- c("h2_overall", "h2_pairwise", "h2_threeway") diff --git a/R/pd_importance.R b/R/pd_importance.R index df013cef..a0a1489b 100644 --- a/R/pd_importance.R +++ b/R/pd_importance.R @@ -57,7 +57,7 @@ pd_importance.default <- function(object, ...) 
{ #' @describeIn pd_importance PD based feature importance from "hstats" object. #' @export pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, top_m = 15L, drop_zero = FALSE, + sort = TRUE, top_m = 15L, zero = TRUE, eps = 1e-8, plot = FALSE, fill = "#2b51a1", ...) { num <- with( @@ -73,7 +73,7 @@ pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, squared = squared, sort = sort, top_m = top_m, - drop_zero = drop_zero, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out From 3a08f2088d08231a3f28557d8186dddc4ed09fab Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sun, 24 Sep 2023 17:48:08 +0200 Subject: [PATCH 05/11] Finished draft --- R/hstats.R | 24 ++++++++++++++++-------- R/utils.R | 19 +++++++++++++++---- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/R/hstats.R b/R/hstats.R index 7d32e5a5..7ed7f2c5 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -321,7 +321,8 @@ print.hstats <- function(x, ...) { #' @param ... Currently not used. #' @returns #' An object of class "summary_hstats" representing a named list with statistics -#' and `normalize`. +#' "h2", "h2_overall", "h2_pairwise", "h2_threeway", and the input flag "normalize". +#' Statistics that equal `NULL` are omitted from the list. #' @export #' @seealso See [hstats()] for examples. summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, @@ -336,15 +337,15 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE eps = eps, plot = FALSE ) - out <- list( + out <- list( h2 = h2(object, eps = eps), h2_overall = do.call(h2_overall, args), h2_pairwise = do.call(h2_pairwise, args), h2_threeway = do.call(h2_threeway, args), normalize = normalize ) - class(out) <- "summary_hstats" - out + out <- out[!sapply(out, is.null)] + structure(out, class = "summary_hstats") } #' Print Method @@ -365,7 +366,7 @@ print.summary_hstats <- function(x, ...) 
{ h2_threeway = sprintf("Strongest %s three-way interactions", flag) ) - for (nm in setdiff(names(Filter(Negate(is.null), x)), "normalize")) { + for (nm in setdiff(names(x), "normalize")) { cat(txt[[nm]]) cat("\n") print(utils::head(x[[nm]])) @@ -402,10 +403,17 @@ plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = zero = zero, eps = eps ) - nms <- c("h2_overall", "h2_pairwise", "h2_threeway") - ids <- c("Overall", "Pairwise", "Threeway") - dat <- lapply(which, FUN = function(j) mat2df(su[[nms[j]]], id = ids[j])) + + nms <- c(Overall = "h2_overall", Pairwise = "h2_pairwise", Threeway = "h2_threeway") + su <- su[nms[which]] + + if (length(su) == 0L) { + return(NULL) + } + + dat <- lapply(names(su), FUN = function(nm) mat2df(su[[nm]], id = nm)) dat <- do.call(rbind, dat) + p <- ggplot2::ggplot(dat, ggplot2::aes(x = value_, y = variable_)) + ggplot2::ylab(ggplot2::element_blank()) + ggplot2::xlab("Value") diff --git a/R/utils.R b/R/utils.R index 9f9a0fc2..481e536e 100644 --- a/R/utils.R +++ b/R/utils.R @@ -216,9 +216,9 @@ basic_check <- function(X, v, pred_fun, w = NULL) { #' @inheritParams H2_overall #' @param num Matrix or vector of statistic. #' @param denom Denominator of statistic (a matrix, number, or vector compatible with `num`). -#' @returns Matrix or vector of statistics. +#' @returns Matrix or vector of statistics. If length of output is 0, then `NULL`. 
postprocess <- function(num, denom = 1, normalize = TRUE, squared = TRUE, - sort = TRUE, top_m = Inf, eps = 1e-8) { + sort = TRUE, top_m = Inf, zero = TRUE, eps = 1e-8) { out <- .zap_small(num, eps = eps) if (normalize) { if (length(denom) == 1L || length(num) == length(denom)) { @@ -239,7 +239,15 @@ postprocess <- function(num, denom = 1, normalize = TRUE, squared = TRUE, out <- sort(out, decreasing = TRUE) } } - utils::head(out, n = top_m) + if (!zero) { + if (is.matrix(out)) { + out <- out[rowSums(out) > 0, , drop = FALSE] + } else { + out <- out[out > 0] + } + } + out <- utils::head(out, n = top_m) + if (length(out) == 0L) NULL else out } #' Zap Small Values @@ -339,8 +347,11 @@ qcut <- function(x, m) { #' @param x A matrix of statistics with rownames. #' @param fill Color of bar (only for univariate statistics). #' @param ... Arguments passed to `geom_bar()`. -#' @returns An object of class "ggplot". +#' @returns An object of class "ggplot", or `NULL`. plot_stat <- function(x, fill = "#2b51a1", ...) 
{ + if (is.null(x)) { + return(NULL) + } p <- ggplot2::ggplot(mat2df(x), ggplot2::aes(x = value_, y = variable_)) + ggplot2::ylab(ggplot2::element_blank()) + ggplot2::xlab("Value") From 28bd297729a0b7446e6405f3f42cacaea80bc7bf Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sun, 24 Sep 2023 19:32:00 +0200 Subject: [PATCH 06/11] Update docu, more cautious default for three-way, less verbose print.summary --- NEWS.md | 2 ++ R/hstats.R | 56 +++++++++++++++++++++++-------------------- R/pd_importance.R | 5 ++-- man/H2.Rd | 6 +---- man/H2_overall.Rd | 5 +++- man/H2_pairwise.Rd | 7 ++++++ man/H2_threeway.Rd | 3 +++ man/hstats.Rd | 41 ++++++++++++++++--------------- man/pd_importance.Rd | 3 +++ man/plot.hstats.Rd | 3 +++ man/summary.hstats.Rd | 10 ++++++-- 11 files changed, 83 insertions(+), 58 deletions(-) diff --git a/NEWS.md b/NEWS.md index 5e8dddbf..0e3865b7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,6 +5,8 @@ - `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. - `h2()` will now always show the normalized $H^2$. Its options `normalize` and `squared` have been removed. - `h2_pairwise()` and `h2_threeway()` will now also include (some) combinations with value 0. Use `zero = FALSE` to drop them, see below. +- `hstats()`: The default number of features considered for *three-way interactions* has been changed from `threeway_m = pairwise_m` to the more cautious `threeway_m = min(pairwise_m, 5L)`. +- The `print()` method of `summary.hstats()` is less verbose. ## Improvements diff --git a/R/hstats.R b/R/hstats.R index 7ed7f2c5..7f414b34 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -15,7 +15,7 @@ #' Furthermore, it allows to calculate an experimental partial dependence based #' measure of feature importance, \eqn{\textrm{PDI}_j^2}. 
It equals the proportion of #' prediction variability unexplained by other features, see [pd_importance()] -#' for details. (This statistic is not shown by `summary()` or `plot()`.) +#' for details. This statistic is not shown by `summary()` or `plot()`. #' #' Instead of using `summary()`, interaction statistics can also be obtained via the #' more flexible functions [h2()], [h2_overall()], [h2_pairwise()], and @@ -36,11 +36,11 @@ #' @param pairwise_m Number of features for which pairwise statistics are to be #' calculated. The features are selected based on Friedman and Popescu's overall #' interaction strength \eqn{H^2_j}. Set to to 0 to avoid pairwise calculations. -#' For multivariate predictions, the union of the column-wise strongest variable -#' names is taken. This can lead to very long run-times. -#' @param threeway_m Same as `pairwise_m`, but controlling the number of features for -#' which threeway interactions should be calculated. Not larger than `pairwise_m`. -#' Set to 0 to avoid threeway calculations. +#' For multivariate predictions, the union of the `pairwise_m` column-wise +#' strongest variable names is taken. This can lead to very long run-times. +#' @param threeway_m Like `pairwise_m`, but controls the feature count for +#' three-way interactions. Cannot be larger than `pairwise_m`. +#' The default is `min(pairwise_m, 5)`. Set to 0 to avoid three-way calculations. #' @param verbose Should a progress bar be shown? The default is `TRUE`. #' @param ... Additional arguments passed to `pred_fun(object, X, ...)`, #' for instance `type = "response"` in a [glm()] model. @@ -50,7 +50,8 @@ #' - `w`: Input `w` (sampled to `n_max` values, or `NULL`). #' - `v`: Same as input `v`. #' - `f`: Matrix with (centered) predictions \eqn{F}. -#' - `mean_f2`: (Weighted) column means of `f`. Used to normalize most statistics. +#' - `mean_f2`: (Weighted) column means of `f`. Used to normalize \eqn{H^2} and +#' \eqn{H^2_j}. 
#' - `F_j`: List of matrices, each representing (centered) #' partial dependence functions \eqn{F_j}. #' - `F_not_j`: List of matrices with (centered) partial dependence @@ -59,26 +60,23 @@ #' - `pred_names`: Column names of prediction matrix. #' - `h2`: List with numerator and denominator of \eqn{H^2}. #' - `h2_overall`: List with numerator and denominator of \eqn{H^2_j}. -#' - `v_pairwise`: Subset of `v` with largest `h2_overall()` used for pairwise +#' - `v_pairwise`: Subset of `v` with largest \eqn{H^2_j} used for pairwise #' calculations. #' - `v_pairwise_0`: Like `v_pairwise`, but padded to length `pairwise_m`. #' - `combs2`: Named list of variable pairs for which pairwise partial -#' dependence functions are available. Only if pairwise calculations have been done. +#' dependence functions are available. #' - `F_jk`: List of matrices, each representing (centered) bivariate -#' partial dependence functions \eqn{F_{jk}}. -#' Only if pairwise calculations have been done. +#' partial dependence functions \eqn{F_{jk}}. #' - `h2_pairwise`: List with numerator and denominator of \eqn{H^2_{jk}}. #' Only if pairwise calculations have been done. #' - `v_threeway`: Subset of `v` with largest `h2_overall()` used for three-way #' calculations. #' - `v_threeway_0`: Like `v_threeway`, but padded to length `threeway_m`. #' - `combs3`: Named list of variable triples for which three-way partial -#' dependence functions are available. Only if threeway calculations have been done. +#' dependence functions are available. #' - `F_jkl`: List of matrices, each representing (centered) three-way -#' partial dependence functions \eqn{F_{jkl}}. -#' Only if threeway calculations have been done. +#' partial dependence functions \eqn{F_{jkl}}. #' - `h2_threeway`: List with numerator and denominator of \eqn{H^2_{jkl}}. -#' Only if threeway calculations have been done. #' @references #' Friedman, Jerome H., and Bogdan E. Popescu. 
*"Predictive Learning via Rule Ensembles."* #' The Annals of Applied Statistics 2, no. 3 (2008): 916-54. @@ -91,11 +89,11 @@ #' s <- hstats(fit, X = iris[-1]) #' s #' plot(s) -#' plot(s, zero = FALSE) +#' plot(s, zero = FALSE) # Drop 0 interaction rows #' summary(s) #' #' # Absolute pairwise interaction strengths -#' h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE) +#' h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE, zero = FALSE) #' #' # MODEL 2: Multi-response linear regression #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) @@ -124,7 +122,8 @@ hstats <- function(object, ...) { #' @export hstats.default <- function(object, X, v = colnames(X), pred_fun = stats::predict, n_max = 300L, - w = NULL, pairwise_m = 5L, threeway_m = pairwise_m, + w = NULL, pairwise_m = 5L, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ...) { basic_check(X = X, v = v, pred_fun = pred_fun, w = w) p <- length(v) @@ -235,7 +234,8 @@ hstats.default <- function(object, X, v = colnames(X), hstats.ranger <- function(object, X, v = colnames(X), pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, verbose = TRUE, ...) { + threeway_m = min(pairwise_m, 5L), + verbose = TRUE, ...) { hstats.default( object = object, X = X, @@ -255,7 +255,8 @@ hstats.ranger <- function(object, X, v = colnames(X), hstats.Learner <- function(object, X, v = colnames(X), pred_fun = NULL, n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, verbose = TRUE, ...) { + threeway_m = min(pairwise_m, 5L), + verbose = TRUE, ...) 
{ if (is.null(pred_fun)) { pred_fun <- mlr3_pred_fun(object, X = X) } @@ -279,7 +280,8 @@ hstats.explainer <- function(object, X = object[["data"]], v = colnames(X), pred_fun = object[["predict_function"]], n_max = 300L, w = object[["weights"]], - pairwise_m = 5L, threeway_m = pairwise_m, + pairwise_m = 5L, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ...) { hstats.default( object = object[["model"]], @@ -305,7 +307,7 @@ hstats.explainer <- function(object, X = object[["data"]], #' @export #' @seealso See [hstats()] for examples. print.hstats <- function(x, ...) { - cat("'hstats' object. Run plot() or summary() for details.\n\n") + cat("'hstats' object. Use plot() or summary() for details.\n\n") cat("Proportion of prediction variability unexplained by main effects of v:\n") print(h2(x)) cat("\n") @@ -359,17 +361,19 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE #' @seealso See [hstats()] for examples. print.summary_hstats <- function(x, ...) { flag <- if (x[["normalize"]]) "relative" else "absolute" + txt <- c( - h2 = "Proportion of prediction variability unexplained by main effects of v", + h2 = "Relative prediction variability unexplained by main effects", h2_overall = sprintf("Strongest %s overall interactions", flag), h2_pairwise = sprintf("Strongest %s pairwise interactions", flag), - h2_threeway = sprintf("Strongest %s three-way interactions", flag) + h2_threeway = sprintf("Strongest %s three-way interaction", flag) ) + top_n <- c(h2 = 1L, h2_overall = 4L, h2_pairwise = 3L, h2_threeway = 1L) for (nm in setdiff(names(x), "normalize")) { - cat(txt[[nm]]) + cat(txt[nm]) cat("\n") - print(utils::head(x[[nm]])) + print(utils::head(x[[nm]], top_n[nm])) cat("\n") } invisible(x) diff --git a/R/pd_importance.R b/R/pd_importance.R index a0a1489b..b60a38dd 100644 --- a/R/pd_importance.R +++ b/R/pd_importance.R @@ -57,9 +57,8 @@ pd_importance.default <- function(object, ...) 
{ #' @describeIn pd_importance PD based feature importance from "hstats" object. #' @export pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, top_m = 15L, zero = TRUE, - eps = 1e-8, plot = FALSE, - fill = "#2b51a1", ...) { + sort = TRUE, top_m = 15L, zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) { num <- with( object, matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names)) ) diff --git a/man/H2.Rd b/man/H2.Rd index b90443ea..61435b32 100644 --- a/man/H2.Rd +++ b/man/H2.Rd @@ -10,17 +10,13 @@ h2(object, ...) \method{h2}{default}(object, ...) -\method{h2}{hstats}(object, normalize = TRUE, squared = TRUE, eps = 1e-08, ...) +\method{h2}{hstats}(object, eps = 1e-08, ...) } \arguments{ \item{object}{Object of class "hstats".} \item{...}{Currently unused.} -\item{normalize}{Should statistics be normalized? Default is \code{TRUE}.} - -\item{squared}{Should \emph{squared} statistics be returned? Default is \code{TRUE}.} - \item{eps}{Threshold below which numerator values are set to 0.} } \value{ diff --git a/man/H2_overall.Rd b/man/H2_overall.Rd index 67127e99..5b2a1164 100644 --- a/man/H2_overall.Rd +++ b/man/H2_overall.Rd @@ -16,6 +16,7 @@ h2_overall(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ h2_overall(object, ...) \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? 
Default is \code{FALSE}.} @@ -103,7 +106,7 @@ h2_overall(s, plot = TRUE) # MODEL 2: Multi-response linear regression fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) s <- hstats(fit, X = iris[3:5], verbose = FALSE) -h2_overall(s, plot = TRUE) +h2_overall(s, plot = TRUE, zero = FALSE) } \references{ Friedman, Jerome H., and Bogdan E. Popescu. \emph{"Predictive Learning via Rule Ensembles."} diff --git a/man/H2_pairwise.Rd b/man/H2_pairwise.Rd index bdc26676..ba3c915f 100644 --- a/man/H2_pairwise.Rd +++ b/man/H2_pairwise.Rd @@ -16,6 +16,7 @@ h2_pairwise(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ h2_pairwise(object, ...) \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? Default is \code{FALSE}.} @@ -110,6 +113,9 @@ s <- hstats(fit, X = iris[-1]) # (for features with strongest overall interactions) h2_pairwise(s) +# Drop 0 +h2_pairwise(s, zero = FALSE) + # Absolute measure as alternative h2_pairwise(s, normalize = FALSE, squared = FALSE) @@ -117,6 +123,7 @@ h2_pairwise(s, normalize = FALSE, squared = FALSE) fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) s <- hstats(fit, X = iris[3:5], verbose = FALSE) h2_pairwise(s, plot = TRUE) +h2_pairwise(s, zero = FALSE, plot = TRUE) } \references{ Friedman, Jerome H., and Bogdan E. Popescu. \emph{"Predictive Learning via Rule Ensembles."} diff --git a/man/H2_threeway.Rd b/man/H2_threeway.Rd index 4a4eddb1..dcaf3cf5 100644 --- a/man/H2_threeway.Rd +++ b/man/H2_threeway.Rd @@ -16,6 +16,7 @@ h2_threeway(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ h2_threeway(object, ...) 
\item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? Default is \code{FALSE}.} diff --git a/man/hstats.Rd b/man/hstats.Rd index 79c00ce5..a31e41e8 100644 --- a/man/hstats.Rd +++ b/man/hstats.Rd @@ -18,7 +18,7 @@ hstats(object, ...) n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -31,7 +31,7 @@ hstats(object, ...) n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -44,7 +44,7 @@ hstats(object, ...) n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -57,7 +57,7 @@ hstats(object, ...) n_max = 300L, w = object[["weights"]], pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -87,12 +87,12 @@ selected from \code{X}. In this case, set a random seed for reproducibility.} \item{pairwise_m}{Number of features for which pairwise statistics are to be calculated. The features are selected based on Friedman and Popescu's overall interaction strength \eqn{H^2_j}. Set to to 0 to avoid pairwise calculations. -For multivariate predictions, the union of the column-wise strongest variable -names is taken. This can lead to very long run-times.} +For multivariate predictions, the union of the \code{pairwise_m} column-wise +strongest variable names is taken. This can lead to very long run-times.} -\item{threeway_m}{Same as \code{pairwise_m}, but controlling the number of features for -which threeway interactions should be calculated. Not larger than \code{pairwise_m}. 
-Set to 0 to avoid threeway calculations.} +\item{threeway_m}{Like \code{pairwise_m}, but controls the feature count for +three-way interactions. Cannot be larger than \code{pairwise_m}. +The default is \code{min(pairwise_m, 5)}. Set to 0 to avoid three-way calculations.} \item{verbose}{Should a progress bar be shown? The default is \code{TRUE}.} } @@ -103,7 +103,8 @@ An object of class "hstats" containing these elements: \item \code{w}: Input \code{w} (sampled to \code{n_max} values, or \code{NULL}). \item \code{v}: Same as input \code{v}. \item \code{f}: Matrix with (centered) predictions \eqn{F}. -\item \code{mean_f2}: (Weighted) column means of \code{f}. Used to normalize most statistics. +\item \code{mean_f2}: (Weighted) column means of \code{f}. Used to normalize \eqn{H^2} and +\eqn{H^2_j}. \item \code{F_j}: List of matrices, each representing (centered) partial dependence functions \eqn{F_j}. \item \code{F_not_j}: List of matrices with (centered) partial dependence @@ -112,24 +113,23 @@ functions \eqn{F_{\setminus j}} of other features. \item \code{pred_names}: Column names of prediction matrix. \item \code{h2}: List with numerator and denominator of \eqn{H^2}. \item \code{h2_overall}: List with numerator and denominator of \eqn{H^2_j}. -\item \code{v_pairwise}: Subset of \code{v} with largest \code{h2_overall()} used for pairwise +\item \code{v_pairwise}: Subset of \code{v} with largest \eqn{H^2_j} used for pairwise calculations. +\item \code{v_pairwise_0}: Like \code{v_pairwise}, but padded to length \code{pairwise_m}. \item \code{combs2}: Named list of variable pairs for which pairwise partial -dependence functions are available. Only if pairwise calculations have been done. +dependence functions are available. \item \code{F_jk}: List of matrices, each representing (centered) bivariate partial dependence functions \eqn{F_{jk}}. -Only if pairwise calculations have been done. \item \code{h2_pairwise}: List with numerator and denominator of \eqn{H^2_{jk}}. 
Only if pairwise calculations have been done. \item \code{v_threeway}: Subset of \code{v} with largest \code{h2_overall()} used for three-way calculations. +\item \code{v_threeway_0}: Like \code{v_threeway}, but padded to length \code{threeway_m}. \item \code{combs3}: Named list of variable triples for which three-way partial -dependence functions are available. Only if threeway calculations have been done. +dependence functions are available. \item \code{F_jkl}: List of matrices, each representing (centered) three-way partial dependence functions \eqn{F_{jkl}}. -Only if threeway calculations have been done. \item \code{h2_threeway}: List with numerator and denominator of \eqn{H^2_{jkl}}. -Only if threeway calculations have been done. } } \description{ @@ -149,7 +149,7 @@ see \code{\link[=h2_threeway]{h2_threeway()}} for details. Furthermore, it allows to calculate an experimental partial dependence based measure of feature importance, \eqn{\textrm{PDI}_j^2}. It equals the proportion of prediction variability unexplained by other features, see \code{\link[=pd_importance]{pd_importance()}} -for details. (This statistic is not shown by \code{summary()} or \code{plot()}.) +for details. This statistic is not shown by \code{summary()} or \code{plot()}. Instead of using \code{summary()}, interaction statistics can also be obtained via the more flexible functions \code{\link[=h2]{h2()}}, \code{\link[=h2_overall]{h2_overall()}}, \code{\link[=h2_pairwise]{h2_pairwise()}}, and @@ -172,10 +172,11 @@ fit <- lm(Sepal.Length ~ . 
+ Petal.Width:Species, data = iris) s <- hstats(fit, X = iris[-1]) s plot(s) +plot(s, zero = FALSE) # Drop 0 interaction rows summary(s) # Absolute pairwise interaction strengths -h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE) +h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE, zero = FALSE) # MODEL 2: Multi-response linear regression fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) @@ -192,9 +193,7 @@ summary(s) # On original scale, we have interactions everywhere... s <- hstats(fit, X = iris[-1], type = "response", verbose = FALSE) - -# All three types use different denominators -plot(s, which = 1:3, ncol = 1) +plot(s, which = 1:3, ncol = 1) # All three types use different denominators # All statistics on same scale (of predictions) plot(s, which = 1:3, squared = FALSE, normalize = FALSE, facet_scale = "free_y") diff --git a/man/pd_importance.Rd b/man/pd_importance.Rd index 2da3dc94..10f3e998 100644 --- a/man/pd_importance.Rd +++ b/man/pd_importance.Rd @@ -16,6 +16,7 @@ pd_importance(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ pd_importance(object, ...) \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? Default is \code{FALSE}.} diff --git a/man/plot.hstats.Rd b/man/plot.hstats.Rd index 7cbc7de0..7164cb40 100644 --- a/man/plot.hstats.Rd +++ b/man/plot.hstats.Rd @@ -11,6 +11,7 @@ squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, fill = "#2b51a1", facet_scales = "free", @@ -35,6 +36,8 @@ use \code{1:3}.} \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? 
Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{fill}{Color of bar (only for univariate statistics).} diff --git a/man/summary.hstats.Rd b/man/summary.hstats.Rd index a8ea4983..3bfa42eb 100644 --- a/man/summary.hstats.Rd +++ b/man/summary.hstats.Rd @@ -10,6 +10,7 @@ squared = TRUE, sort = TRUE, top_m = Inf, + zero = TRUE, eps = 1e-08, ... ) @@ -26,15 +27,20 @@ \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{...}{Currently not used.} } \value{ -An object of class "summary_hstats" representing a named list with statistics. +An object of class "summary_hstats" representing a named list with statistics +"h2", "h2_overall", "h2_pairwise", "h2_threeway", and the input flag "normalize". +Statistics that equal \code{NULL} are omitted from the list. } \description{ -Summary method for "hstats" object. +Summary method for "hstats" object. Note that \eqn{H^2} is not affected by +the arguments \code{normalize} and \code{squared}. } \seealso{ See \code{\link[=hstats]{hstats()}} for examples. From cd352652ba17e7bfd66316913bf23479cad04b8c Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Mon, 25 Sep 2023 19:58:17 +0200 Subject: [PATCH 07/11] Introduce init_numerator() --- R/H2_overall.R | 4 +--- R/H2_pairwise.R | 6 +----- R/H2_threeway.R | 6 +----- R/hstats.R | 40 ++++++++++++---------------------------- R/pd_importance.R | 4 +--- R/utils.R | 45 +++++++++++++++++++++++++++++++++++++++++++++ man/hstats.Rd | 5 +++-- 7 files changed, 64 insertions(+), 46 deletions(-) diff --git a/R/H2_overall.R b/R/H2_overall.R index 8ff25c12..5480caea 100644 --- a/R/H2_overall.R +++ b/R/H2_overall.R @@ -108,11 +108,9 @@ h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = T #' "f", "F_not_j", "F_j", "mean_f2", and "w". 
#' @returns A list with the numerator and denominator statistics. h2_overall_raw <- function(x) { - num <- with(x, matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names))) - + num <- init_numerator(x, way = 1L) for (z in x[["v"]]) { num[z, ] <- with(x, wcolMeans((f - F_j[[z]] - F_not_j[[z]])^2, w = w)) } - list(num = num, denom = x[["mean_f2"]]) } diff --git a/R/H2_pairwise.R b/R/H2_pairwise.R index bcbe81e6..93a513d8 100644 --- a/R/H2_pairwise.R +++ b/R/H2_pairwise.R @@ -116,11 +116,7 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = #' "F_jk", "F_j", and "w". #' @returns A list with the numerator and denominator statistics. h2_pairwise_raw <- function(x) { - # Initialize matrices - cn0 <- combn(x[["v_pairwise_0"]], 2L, FUN = paste, collapse = ":") - num <- with( - x, matrix(0, nrow = length(cn0), ncol = K, dimnames = list(cn0, pred_names)) - ) + num <- init_numerator(x, way = 2L) denom <- num + 1 # Note that F_jk are in the same order as x[["combs2"]] diff --git a/R/H2_threeway.R b/R/H2_threeway.R index f5ec7f99..6e4837e5 100644 --- a/R/H2_threeway.R +++ b/R/H2_threeway.R @@ -99,11 +99,7 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = #' "F_jkl", "F_jk", "F_j", and "w". #' @returns A list with the numerator and denominator statistics. h2_threeway_raw <- function(x) { - # Initialize matrices - cn0 <- utils::combn(x[["v_threeway_0"]], 3L, FUN = paste, collapse = ":") - num <- with( - x, matrix(0, nrow = length(cn0), ncol = K, dimnames = list(cn0, pred_names)) - ) + num <- init_numerator(x, way = 3L) denom <- num + 1 # Note that the F_jkl are in the same order as x[["combs3"]] diff --git a/R/hstats.R b/R/hstats.R index 7f414b34..15ad4d8f 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -58,11 +58,13 @@ #' functions \eqn{F_{\setminus j}} of other features. #' - `K`: Number of columns of prediction matrix. #' - `pred_names`: Column names of prediction matrix. 
+#' - `pairwise_m`: Like input `pairwise_m`, but capped at `length(v)`. +#' - `threeway_m`: Like input `threeway_m`, but capped at the smaller of +#' `length(v)` and `pairwise_m`. #' - `h2`: List with numerator and denominator of \eqn{H^2}. #' - `h2_overall`: List with numerator and denominator of \eqn{H^2_j}. #' - `v_pairwise`: Subset of `v` with largest \eqn{H^2_j} used for pairwise #' calculations. -#' - `v_pairwise_0`: Like `v_pairwise`, but padded to length `pairwise_m`. #' - `combs2`: Named list of variable pairs for which pairwise partial #' dependence functions are available. #' - `F_jk`: List of matrices, each representing (centered) bivariate @@ -71,7 +73,6 @@ #' Only if pairwise calculations have been done. #' - `v_threeway`: Subset of `v` with largest `h2_overall()` used for three-way #' calculations. -#' - `v_threeway_0`: Like `v_threeway`, but padded to length `threeway_m`. #' - `combs3`: Named list of variable triples for which three-way partial #' dependence functions are available. #' - `F_jkl`: List of matrices, each representing (centered) three-way @@ -127,12 +128,9 @@ hstats.default <- function(object, X, v = colnames(X), verbose = TRUE, ...) { basic_check(X = X, v = v, pred_fun = pred_fun, w = w) p <- length(v) - stopifnot( - p >= 2L, - threeway_m <= pairwise_m - ) + stopifnot(p >= 2L) pairwise_m <- min(pairwise_m, p) - threeway_m <- min(threeway_m, p) + threeway_m <- min(threeway_m, pairwise_m, p) # Reduce size of X (and w) if (nrow(X) > n_max) { @@ -197,7 +195,9 @@ hstats.default <- function(object, X, v = colnames(X), F_j = F_j, F_not_j = F_not_j, K = ncol(f), - pred_names = colnames(f) + pred_names = colnames(f), + pairwise_m = pairwise_m, + threeway_m = threeway_m ) # 0-way and 1-way stats @@ -206,8 +206,7 @@ hstats.default <- function(object, X, v = colnames(X), h2_ov <- .zap_small(out$h2_overall$num, eps = 1e-8) # Does eps need to be passed? 
if (pairwise_m >= 2L) { - out[c("v_pairwise", "v_pairwise_0")] <- get_v(h2_ov, m = pairwise_m) - v2 <- out[["v_pairwise"]] + out[["v_pairwise"]] <- v2 <- get_v(h2_ov, m = pairwise_m) if (length(v2) >= 2L) { out[c("combs2", "F_jk")] <- mway( object, v = v2, X = X, pred_fun = pred_fun, w = w, way = 2L, verb = verbose, ... @@ -216,8 +215,7 @@ hstats.default <- function(object, X, v = colnames(X), out[["h2_pairwise"]] <- h2_pairwise_raw(out) } if (threeway_m >= 3L) { - out[c("v_threeway", "v_threeway_0")] <- get_v(h2_ov, m = threeway_m) - v3 <- out[["v_threeway"]] + out[["v_threeway"]] <- v3 <- get_v(h2_ov, m = threeway_m) if (length(v3) >= 3L) { out[c("combs3", "F_jkl")] <- mway( object, v = v3, X = X, pred_fun = pred_fun, w = w, way = 3L, verb = verbose, ... @@ -498,29 +496,15 @@ mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, #' @param H Unnormalized, unsorted H2_j values. #' @param m Number of features to pick per column. #' -#' @returns A list with two vectors of feature names. The first contains -#' only the m most important features (union over columns), while the first -#' is a padded version of length m. It is necessary in cases where less than m -#' features show interactions. +#' @returns A vector of the union of the m column-wise most important features. get_v <- function(H, m) { v <- rownames(H) - - # Get m strongest features per column selector <- function(vv) names(utils::head(sort(-vv[vv > 0]), m)) if (NCOL(H) == 1L) { v_cand <- selector(drop(H)) } else { v_cand <- Reduce(union, lapply(asplit(H, MARGIN = 2L), FUN = selector)) } - - # Do we need to add some features without interactions? 
- m_miss <- m - length(v_cand) - if (m_miss > 0L) { - v_cand_0 <- c(v_cand, utils::head(setdiff(v, v_cand), m_miss)) - } else { - v_cand_0 <- v_cand - } - # Bring vectors into same order as v - return(list(v[v %in% v_cand], v[v %in% v_cand_0])) + v[v %in% v_cand] } diff --git a/R/pd_importance.R b/R/pd_importance.R index b60a38dd..a9c60899 100644 --- a/R/pd_importance.R +++ b/R/pd_importance.R @@ -59,9 +59,7 @@ pd_importance.default <- function(object, ...) { pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, top_m = 15L, zero = TRUE, eps = 1e-8, plot = FALSE, fill = "#2b51a1", ...) { - num <- with( - object, matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names)) - ) + num <- init_numerator(object, way = 1L) for (z in object[["v"]]) { num[z, ] <- with(object, wcolMeans((f - F_not_j[[z]])^2, w = w)) } diff --git a/R/utils.R b/R/utils.R index 481e536e..03f082b8 100644 --- a/R/utils.R +++ b/R/utils.R @@ -323,6 +323,51 @@ mat2df <- function(mat, id = "Overall") { poor_man_stack(out, to_stack = pred_names) } +#' Initializor of Numerator Statistics +#' +#' Internal helper function that returns a matrix of all zeros with the right +#' column and row names for statistics of any "way". If some features have been dropped +#' from the statistics calculations, they are added as 0. +#' +#' @noRd +#' @keywords internal +#' @param x A list containing the elements "v", "K", "pred_names", "v_pairwise", +#' "v_threeway", "pairwise_m", "threeway_m". +#' @param way Integer between 1 and 3 of the order of the interaction. +#' @returns A matrix of all zeros. 
+init_numerator <- function(x, way = 1L) { + stopifnot(way %in% 1:3) + + v <- x[["v"]] + K <- x[["K"]] + pred_names <- x[["pred_names"]] + + # Simple case + if (way == 1L) { + return(matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names))) + } + + # Determine v_cand_0, which is v_cand with additional features to end up with length m + if (way == 2L) { + v_cand <- x[["v_pairwise"]] + m <- x[["pairwise_m"]] + } else { + v_cand <- x[["v_threeway"]] + m <- x[["threeway_m"]] + } + m_miss <- m - length(v_cand) + if (m_miss > 0L) { + v_cand_0 <- c(v_cand, utils::head(setdiff(v, v_cand), m_miss)) + v_cand_0 <- v[v %in% v_cand_0] # Bring into order of v + } else { + v_cand_0 <- v_cand + } + + # Get all interactions of order "way" + cn0 <- combn(v_cand_0, m = way, FUN = paste, collapse = ":") + matrix(0, nrow = length(cn0), ncol = K, dimnames = list(cn0, pred_names)) +} + #' Bin into Quantiles #' #' Internal function. Applies [cut()] to quantile breaks. diff --git a/man/hstats.Rd b/man/hstats.Rd index a31e41e8..224cc32e 100644 --- a/man/hstats.Rd +++ b/man/hstats.Rd @@ -111,11 +111,13 @@ partial dependence functions \eqn{F_j}. functions \eqn{F_{\setminus j}} of other features. \item \code{K}: Number of columns of prediction matrix. \item \code{pred_names}: Column names of prediction matrix. +\item \code{pairwise_m}: Like input \code{pairwise_m}, but capped at \code{length(v)}. +\item \code{threeway_m}: Like input \code{threeway_m}, but capped at the smaller of +\code{length(v)} and \code{pairwise_m}. \item \code{h2}: List with numerator and denominator of \eqn{H^2}. \item \code{h2_overall}: List with numerator and denominator of \eqn{H^2_j}. \item \code{v_pairwise}: Subset of \code{v} with largest \eqn{H^2_j} used for pairwise calculations. -\item \code{v_pairwise_0}: Like \code{v_pairwise}, but padded to length \code{pairwise_m}. \item \code{combs2}: Named list of variable pairs for which pairwise partial dependence functions are available. 
\item \code{F_jk}: List of matrices, each representing (centered) bivariate @@ -124,7 +126,6 @@ partial dependence functions \eqn{F_{jk}}. Only if pairwise calculations have been done. \item \code{v_threeway}: Subset of \code{v} with largest \code{h2_overall()} used for three-way calculations. -\item \code{v_threeway_0}: Like \code{v_threeway}, but padded to length \code{threeway_m}. \item \code{combs3}: Named list of variable triples for which three-way partial dependence functions are available. \item \code{F_jkl}: List of matrices, each representing (centered) three-way From 7d1c657fa44b64dd85865558f3db0ab715e882cc Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Mon, 25 Sep 2023 20:32:07 +0200 Subject: [PATCH 08/11] fix problem in plot.hstats() with all 0 values --- R/hstats.R | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/R/hstats.R b/R/hstats.R index 15ad4d8f..f64048cd 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -406,14 +406,14 @@ plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = eps = eps ) - nms <- c(Overall = "h2_overall", Pairwise = "h2_pairwise", Threeway = "h2_threeway") - su <- su[nms[which]] - - if (length(su) == 0L) { + # This part could be simplified, especially the "match()" + stat_names <- c("h2_overall", "h2_pairwise", "h2_threeway")[which] + stat_labs <- c("Overall", "Pairwise", "Three-way")[which] + ok <- stat_names[stat_names %in% names(su)] + if (length(ok) == 0L) { return(NULL) } - - dat <- lapply(names(su), FUN = function(nm) mat2df(su[[nm]], id = nm)) + dat <- lapply(ok, FUN = function(nm) mat2df(su[[nm]], id = stat_labs[match(nm, stat_names)])) dat <- do.call(rbind, dat) p <- ggplot2::ggplot(dat, ggplot2::aes(x = value_, y = variable_)) + From f6472485829d30f789e8431900d4b24047a14e2b Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Tue, 26 Sep 2023 22:19:10 +0200 Subject: [PATCH 09/11] Started with tests_hstats.R --- R/H2_pairwise.R | 4 +--- R/hstats.R | 2 +- 
man/H2_pairwise.Rd | 4 +--- man/hstats.Rd | 2 +- tests/testthat/test_hstats.R | 40 ++++++++++++++++++++++++++++-------- tests/testthat/test_utils.R | 6 ++++++ 6 files changed, 42 insertions(+), 16 deletions(-) diff --git a/R/H2_pairwise.R b/R/H2_pairwise.R index 93a513d8..977aadff 100644 --- a/R/H2_pairwise.R +++ b/R/H2_pairwise.R @@ -61,9 +61,7 @@ #' # Proportion of joint effect coming from pairwise interaction #' # (for features with strongest overall interactions) #' h2_pairwise(s) -#' -#' # Drop 0 -#' h2_pairwise(s, zero = FALSE) +#' h2_pairwise(s, zero = FALSE) # Drop 0 #' #' # Absolute measure as alternative #' h2_pairwise(s, normalize = FALSE, squared = FALSE) diff --git a/R/hstats.R b/R/hstats.R index f64048cd..5d3cdb70 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -90,7 +90,7 @@ #' s <- hstats(fit, X = iris[-1]) #' s #' plot(s) -#' plot(s, zero = FALSE) # Drop 0 interaction rows +#' plot(s, zero = FALSE) # Drop 0 #' summary(s) #' #' # Absolute pairwise interaction strengths diff --git a/man/H2_pairwise.Rd b/man/H2_pairwise.Rd index ba3c915f..ac912b04 100644 --- a/man/H2_pairwise.Rd +++ b/man/H2_pairwise.Rd @@ -112,9 +112,7 @@ s <- hstats(fit, X = iris[-1]) # Proportion of joint effect coming from pairwise interaction # (for features with strongest overall interactions) h2_pairwise(s) - -# Drop 0 -h2_pairwise(s, zero = FALSE) +h2_pairwise(s, zero = FALSE) # Drop 0 # Absolute measure as alternative h2_pairwise(s, normalize = FALSE, squared = FALSE) diff --git a/man/hstats.Rd b/man/hstats.Rd index 224cc32e..560c828e 100644 --- a/man/hstats.Rd +++ b/man/hstats.Rd @@ -173,7 +173,7 @@ fit <- lm(Sepal.Length ~ . 
+ Petal.Width:Species, data = iris) s <- hstats(fit, X = iris[-1]) s plot(s) -plot(s, zero = FALSE) # Drop 0 interaction rows +plot(s, zero = FALSE) # Drop 0 summary(s) # Absolute pairwise interaction strengths diff --git a/tests/testthat/test_hstats.R b/tests/testthat/test_hstats.R index e7d38fc4..078a2473 100644 --- a/tests/testthat/test_hstats.R +++ b/tests/testthat/test_hstats.R @@ -1,30 +1,47 @@ test_that("Additive models show 0 interactions (univariate)", { fit <- lm(Sepal.Length ~ ., data = iris) s <- hstats(fit, X = iris[-1L], verbose = FALSE) - expect_null(h2_pairwise(s)) - expect_null(h2_threeway(s)) + expect_null(h2_pairwise(s, zero = FALSE)) + expect_equal(c(h2_pairwise(s, sort = FALSE, top_m = Inf)), rep(0, choose(4, 2))) + + expect_null(h2_threeway(s, zero = FALSE)) + expect_equal(c(h2_threeway(s, sort = FALSE, top_m = Inf)), rep(0, choose(4, 3))) + expect_equal( - h2_overall(s, plot = FALSE), + h2_overall(s), matrix(c(0, 0, 0, 0), ncol = 1L, dimnames = list(colnames(iris[-1L]), NULL)) ) + expect_null(h2_overall(s, zero = FALSE)) + expect_equal(h2(s), 0) + expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") expect_s3_class(plot(s, rotate_x = TRUE), "ggplot") + expect_null(h2_overall(s, zero = FALSE, plot = TRUE)) }) test_that("Additive models show 0 interactions (multivariate)", { fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris) s <- hstats(fit, X = iris[3:5], verbose = FALSE) - expect_null(h2_pairwise(s)) - expect_null(h2_threeway(s)) + + expect_null(h2_pairwise(s, zero = FALSE)) + expect_true(all(h2_pairwise(s) == 0)) + + expect_null(h2_threeway(s, zero = FALSE)) + expect_equal(unname(h2_threeway(s)), cbind(0, 0)) + expect_equal( - h2_overall(s, plot = FALSE), + h2_overall(s, sort = FALSE), matrix( 0, ncol = 2L, nrow = 3L, dimnames = list(colnames(iris[3:5]), colnames(iris[1:2])) ) ) + expect_null(h2_overall(s, zero = FALSE)) + expect_equal(h2(s), c(Sepal.Length = 0, Sepal.Width = 0)) + 
expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") + expect_null(h2_overall(s, zero = FALSE, plot = TRUE)) expect_s3_class(plot(s), "ggplot") }) @@ -37,14 +54,19 @@ test_that("Non-additive models show interactions > 0 (one interaction)", { expect_true( all(rownames(out[out > 0, , drop = FALSE]) %in% c("Petal.Length", "Petal.Width")) ) + out <- h2_overall(s, zero = FALSE, sort = FALSE) + expect_true(all(rownames(out) %in% c("Petal.Length", "Petal.Width"))) - out <- h2_pairwise(s) + out <- h2_pairwise(s, zero = FALSE) expect_equal(rownames(out), "Petal.Length:Petal.Width") + out <- h2_pairwise(s) + expect_equal(rownames(out[out > 0, , drop = FALSE]), "Petal.Length:Petal.Width") expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") expect_s3_class(h2_pairwise(s, plot = TRUE), "ggplot") expect_s3_class(plot(s), "ggplot") - expect_null(h2_threeway(s)) + expect_null(h2_threeway(s, zero = FALSE)) + expect_equal(c(h2_threeway(s)), rep(0, times = choose(4, 3))) }) fit <- lm( @@ -52,6 +74,8 @@ fit <- lm( ) s <- hstats(fit, X = iris[-1L], verbose = FALSE) + +# CONTINUE HERE test_that("Non-additive models show interactions > 0 (two interactions)", { expect_true(h2(s) > 0) diff --git a/tests/testthat/test_utils.R b/tests/testthat/test_utils.R index 18bfd0ff..4ad8db0f 100644 --- a/tests/testthat/test_utils.R +++ b/tests/testthat/test_utils.R @@ -238,6 +238,9 @@ test_that("postprocess() works for matrix input", { postprocess(num = num, denom = 1:2, sort = FALSE), num / cbind(c(1, 1, 1), c(2, 2, 2)) ) + + expect_equal(postprocess(num = cbind(0:1, 0:1), zero = FALSE), rbind(c(1, 1))) + expect_null(postprocess(num = cbind(0, 0), zero = FALSE)) }) test_that("postprocess() works for vector input", { @@ -249,6 +252,9 @@ test_that("postprocess() works for vector input", { expect_equal(postprocess(num = num, sort = FALSE), num) expect_equal(postprocess(num = num, sort = FALSE, top_m = 2), num[1:2]) expect_equal(postprocess(num = num, squared = FALSE), sqrt(num[3:1])) + + 
expect_equal(postprocess(num = 0:1, denom = c(2, 2), zero = FALSE), 0.5) + expect_null(postprocess(num = 0, zero = FALSE)) }) test_that(".zap_small() works for vector input", { From 51fdb988abd5192a0c9813ee79a2035519b24888 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Wed, 27 Sep 2023 07:39:48 +0200 Subject: [PATCH 10/11] Undo removal of squared and normalize in h2() --- NEWS.md | 7 +++---- R/H2.R | 4 +++- R/hstats.R | 4 ++-- man/H2.Rd | 6 +++++- tests/testthat/test_hstats.R | 22 +++++++++++++++++----- 5 files changed, 30 insertions(+), 13 deletions(-) diff --git a/NEWS.md b/NEWS.md index 0e3865b7..535342fd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,10 +2,9 @@ ## Visible changes -- `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. -- `h2()` will now always show the normalized $H^2$. Its options `normalize` and `squared` have been removed. -- `h2_pairwise()` and `h2_threeway()` will now also include (some) combinations with value 0. Use `zero = FALSE` to drop them, see below. -- `hstats()`: The default number of features considered for *three-way interactions* has been changed from `threeway_m = pairwise_m` to the more cautious `threeway_m = min(pairwise_m, 5L)`. +- Grid of `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. +- `h2_pairwise()` and `h2_threeway()` will now also include 0 values. Use `zero = FALSE` to drop them, see below. The padding with 0 is done at no computational cost, and will affect only up to `pairwise_m` and `threeway_m` features. 
+- `hstats()`: The default number of features considered for *three-way interactions* has been changed from `threeway_m = pairwise_m` to the more cautious `threeway_m = min(pairwise_m, 5L)`. Furthermore, `threeway_m` is capped at `pairwise_m`. - The `print()` method of `summary.hstats()` is less verbose. ## Improvements diff --git a/R/H2.R b/R/H2.R index ff9e0344..5f27dbc2 100644 --- a/R/H2.R +++ b/R/H2.R @@ -65,10 +65,12 @@ h2.default <- function(object, ...) { #' @describeIn h2 Total interaction strength from "interact" object. #' @export -h2.hstats <- function(object, eps = 1e-8, ...) { +h2.hstats <- function(object, normalize = TRUE, squared = TRUE, eps = 1e-8, ...) { postprocess( num = object$h2$num, denom = object$h2$denom, + normalize = normalize, + squared = squared, sort = FALSE, eps = eps ) diff --git a/R/hstats.R b/R/hstats.R index 5d3cdb70..b87c9283 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -338,7 +338,7 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE plot = FALSE ) out <- list( - h2 = h2(object, eps = eps), + h2 = h2(object, normalize = normalize, squared = squared, eps = eps), h2_overall = do.call(h2_overall, args), h2_pairwise = do.call(h2_pairwise, args), h2_threeway = do.call(h2_threeway, args), @@ -361,7 +361,7 @@ print.summary_hstats <- function(x, ...) { flag <- if (x[["normalize"]]) "relative" else "absolute" txt <- c( - h2 = "Relative prediction variability unexplained by main effects", + h2 = "Prediction variability unexplained by main effects", h2_overall = sprintf("Strongest %s overall interactions", flag), h2_pairwise = sprintf("Strongest %s pairwise interactions", flag), h2_threeway = sprintf("Strongest %s three-way interaction", flag) diff --git a/man/H2.Rd b/man/H2.Rd index 61435b32..b90443ea 100644 --- a/man/H2.Rd +++ b/man/H2.Rd @@ -10,13 +10,17 @@ h2(object, ...) \method{h2}{default}(object, ...) -\method{h2}{hstats}(object, eps = 1e-08, ...) 
+\method{h2}{hstats}(object, normalize = TRUE, squared = TRUE, eps = 1e-08, ...) } \arguments{ \item{object}{Object of class "hstats".} \item{...}{Currently unused.} +\item{normalize}{Should statistics be normalized? Default is \code{TRUE}.} + +\item{squared}{Should \emph{squared} statistics be returned? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} } \value{ diff --git a/tests/testthat/test_hstats.R b/tests/testthat/test_hstats.R index 078a2473..f8926945 100644 --- a/tests/testthat/test_hstats.R +++ b/tests/testthat/test_hstats.R @@ -74,8 +74,6 @@ fit <- lm( ) s <- hstats(fit, X = iris[-1L], verbose = FALSE) - -# CONTINUE HERE test_that("Non-additive models show interactions > 0 (two interactions)", { expect_true(h2(s) > 0) @@ -84,18 +82,27 @@ test_that("Non-additive models show interactions > 0 (two interactions)", { rownames(out[out > 0, , drop = FALSE]), c("Petal.Length", "Petal.Width", "Species") ) + out <- h2_overall(s, sort = FALSE, normalize = FALSE, squared = FALSE, zero = FALSE) + expect_equal( + rownames(out), c("Petal.Length", "Petal.Width", "Species") + ) out <- h2_pairwise(s, sort = FALSE, normalize = FALSE, squared = FALSE) expect_equal( rownames(out[out > 0, , drop = FALSE]), c("Petal.Length:Petal.Width", "Petal.Length:Species") ) + out <- h2_pairwise(s, sort = FALSE, normalize = FALSE, squared = FALSE, zero = FALSE) + expect_equal( + rownames(out), c("Petal.Length:Petal.Width", "Petal.Length:Species") + ) expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") expect_s3_class(h2_pairwise(s, plot = TRUE), "ggplot") expect_s3_class(h2_threeway(s, plot = TRUE), "ggplot") + expect_null(h2_threeway(s, plot = TRUE, zero = FALSE)) - expect_equal(c(h2_threeway(s)), 0) + expect_equal(c(h2_threeway(s)), rep(0, choose(4, 3))) }) test_that("passing v works", { @@ -139,7 +146,7 @@ test_that("Stronger interactions get higher statistics", { expect_true( all(h2_overall(int2, top_m = 2L) > h2_overall(int1, top_m = 2L)) ) 
- expect_true(h2_pairwise(int2) > h2_pairwise(int1)) + expect_true(h2_pairwise(int2, zero = FALSE) > h2_pairwise(int1, zero = FALSE)) }) test_that("subsampling has an effect", { @@ -180,12 +187,17 @@ test_that("Three-way interaction behaves correctly across dimensions", { expect_equal(2 * out[, "up"], out[, "up2"]) }) -test_that("Three-way interaction can be suppressed", { +test_that("Pairwise and three-way interactions can be suppressed", { fit <- lm(uptake ~ Type * Treatment * conc, data = CO2) s <- hstats(fit, X = CO2[2:4], verbose = FALSE, threeway_m = 0L) expect_null(h2_threeway(s)) + s <- hstats(fit, X = CO2[2:4], verbose = FALSE, pairwise_m = 2L) + expect_equal(nrow(h2_pairwise(s)), 1L) + expect_null(h2_threeway(s)) + s <- hstats(fit, X = CO2[2:4], verbose = FALSE, pairwise_m = 0L) + expect_equal(nrow(h2_overall(s)), 3L) expect_null(h2_pairwise(s)) expect_null(h2_threeway(s)) }) From dba213e19a27e5fab3baff2eb00db23bb40538e3 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Wed, 27 Sep 2023 20:26:24 +0200 Subject: [PATCH 11/11] turn result of combn() into a vector --- R/utils.R | 4 ++-- tests/testthat/test_hstats.R | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/R/utils.R b/R/utils.R index 03f082b8..5b32e278 100644 --- a/R/utils.R +++ b/R/utils.R @@ -363,8 +363,8 @@ init_numerator <- function(x, way = 1L) { v_cand_0 <- v_cand } - # Get all interactions of order "way" - cn0 <- combn(v_cand_0, m = way, FUN = paste, collapse = ":") + # Get all interactions of order "way". 
c() turns the array into a vector + cn0 <- c(utils::combn(v_cand_0, m = way, FUN = paste, collapse = ":")) matrix(0, nrow = length(cn0), ncol = K, dimnames = list(cn0, pred_names)) } diff --git a/tests/testthat/test_hstats.R b/tests/testthat/test_hstats.R index f8926945..4b18deef 100644 --- a/tests/testthat/test_hstats.R +++ b/tests/testthat/test_hstats.R @@ -221,7 +221,10 @@ test_that("Statistics react on normalize, (sorting), squaring, and top m", { s$h2_overall$num[1:2, , drop = FALSE] ) - expect_identical(h2_pairwise(s, sort = FALSE), s$h2_pairwise$num / s$h2_pairwise$denom) + expect_identical( + h2_pairwise(s, sort = FALSE), + s$h2_pairwise$num / s$h2_pairwise$denom + ) expect_identical(h2_pairwise(s, sort = FALSE, normalize = FALSE), s$h2_pairwise$num) expect_identical( h2_pairwise(s, sort = FALSE, normalize = FALSE, squared = FALSE),