diff --git a/NEWS.md b/NEWS.md index a29b186c..535342fd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,15 +1,15 @@ # hstats 0.3.0 -## Major user visible changes +## Visible changes -- Grid calculation: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. This affects `ice()`, `partial_dep()` and the exported helper functions `univariate_grid()` and `multivariate_grid()`. +- Grid of `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. +- `h2_pairwise()` and `h2_threeway()` will now also include 0 values. Use `zero = FALSE` to drop them, see below. The padding with 0 is done at no computational cost, and will affect only up to `pairwise_m` and `threeway_m` features. +- `hstats()`: The default number of features considered for *three-way interactions* has been changed from `threeway_m = pairwise_m` to the more cautious `threeway_m = min(pairwise_m, 5L)`. Furthermore, `threeway_m` is capped at `pairwise_m`. +- The `print()` method of `summary.hstats()` is less verbose. -## Internal major changes - -- All available H-statistics are now calculated within `hstats()` and attached to the resulting object. Each statistic is stored as list with numerator and denominator matrices/vectors. The functions `h2()`, `h2_overall()`, `h2_pairwise()`, and `h2_threeway()`, `print.hstats()`, `summary().hstats()`, `plot.hstats()` will use these without having to recalculate the required numerators and denominators. The results, however, are unchanged. - -## Minor improvements +## Improvements +- `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `plot.hstats()`, and `summary.hstats()` have received an argument `zero = TRUE`. Set to `FALSE` to drop statistics having value 0. - `perm_importance()` and `average_loss()` will now recycle a univariate response when combined with multivariate predictions. This is useful, e.g., when the prediction function represents the predictions of multiple models that should be evaluated against a common response. ## Bug fixes @@ -17,6 +17,10 @@ - All progress bars were initialized 1 step too late. - `perm_importance()` and `average_loss()` would fail for "mlogloss" in case the response `y` was univariate *and* non-factor/non-character. +## Other changes + +- All available H-statistics are now calculated within `hstats()` and attached to the resulting object. Each statistic is stored as list with numerator and denominator matrices/vectors. The functions `h2()`, `h2_overall()`, `h2_pairwise()`, and `h2_threeway()`, `print.hstats()`, `summary().hstats()`, `plot.hstats()` will use these without having to recalculate the required numerators and denominators. The results, however, are unchanged. + # hstats 0.2.0 ## New major features diff --git a/R/H2.R b/R/H2.R index c4bc96cd..5f27dbc2 100644 --- a/R/H2.R +++ b/R/H2.R @@ -70,8 +70,8 @@ h2.hstats <- function(object, normalize = TRUE, squared = TRUE, eps = 1e-8, ...) num = object$h2$num, denom = object$h2$denom, normalize = normalize, - squared = squared, - sort = FALSE, + squared = squared, + sort = FALSE, eps = eps ) } diff --git a/R/H2_overall.R b/R/H2_overall.R index a4469ef3..5480caea 100644 --- a/R/H2_overall.R +++ b/R/H2_overall.R @@ -44,6 +44,7 @@ #' @param sort Should results be sorted? Default is `TRUE`. #' (Multioutput is sorted by row means.) #' @param top_m How many rows should be shown? (`Inf` to show all.) +#' @param zero Should rows with all 0 be shown? Default is `TRUE`. #' @param eps Threshold below which numerator values are set to 0. #' @param plot Should results be plotted as barplot? Default is `FALSE`. #' @param fill Color of bar (only for univariate statistics). @@ -64,7 +65,7 @@ #' # MODEL 2: Multi-response linear regression #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) #' s <- hstats(fit, X = iris[3:5], verbose = FALSE) -#' h2_overall(s, plot = TRUE) +#' h2_overall(s, plot = TRUE, zero = FALSE) h2_overall <- function(object, ...) { UseMethod("h2_overall") } @@ -78,8 +79,8 @@ h2_overall.default <- function(object, ...) { #' @describeIn h2_overall Overall interaction strength from "hstats" object. #' @export h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, plot = FALSE, fill = "#2b51a1", - ...) { + top_m = 15L, zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_overall out <- postprocess( num = s$num, @@ -87,7 +88,8 @@ h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = T normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out @@ -106,11 +108,9 @@ h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = T #' "f", "F_not_j", "F_j", "mean_f2", and "w". #' @returns A list with the numerator and denominator statistics. h2_overall_raw <- function(x) { - num <- with(x, matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names))) - + num <- init_numerator(x, way = 1L) for (z in x[["v"]]) { num[z, ] <- with(x, wcolMeans((f - F_j[[z]] - F_not_j[[z]])^2, w = w)) } - list(num = num, denom = x[["mean_f2"]]) } diff --git a/R/H2_pairwise.R b/R/H2_pairwise.R index b31bd1bf..977aadff 100644 --- a/R/H2_pairwise.R +++ b/R/H2_pairwise.R @@ -61,6 +61,7 @@ #' # Proportion of joint effect coming from pairwise interaction #' # (for features with strongest overall interactions) #' h2_pairwise(s) +#' h2_pairwise(s, zero = FALSE) # Drop 0 #' #' # Absolute measure as alternative #' h2_pairwise(s, normalize = FALSE, squared = FALSE) @@ -69,6 +70,7 @@ #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) #' s <- hstats(fit, X = iris[3:5], verbose = FALSE) #' h2_pairwise(s, plot = TRUE) +#' h2_pairwise(s, zero = FALSE, plot = TRUE) h2_pairwise <- function(object, ...) { UseMethod("h2_pairwise") } @@ -82,8 +84,8 @@ h2_pairwise.default <- function(object, ...) { #' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object. #' @export h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, plot = FALSE, - fill = "#2b51a1", ...) { + top_m = 15L, zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_pairwise if (is.null(s)) { return(NULL) @@ -94,7 +96,8 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out @@ -107,21 +110,21 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = #' #' @noRd #' @keywords internal -#' @param x A list containing the elements "combs2", "K", "pred_names", +#' @param x A list containing the elements "combs2", "v_pairwise_0", "K", "pred_names", #' "F_jk", "F_j", and "w". #' @returns A list with the numerator and denominator statistics. h2_pairwise_raw <- function(x) { + num <- init_numerator(x, way = 2L) + denom <- num + 1 + + # Note that F_jk are in the same order as x[["combs2"]] combs <- x[["combs2"]] - - # Note that F_jk are in the same order as combs - num <- denom <- with( - x, matrix(nrow = length(combs), ncol = K, dimnames = list(names(combs), pred_names)) - ) - - for (i in seq_along(combs)) { - z <- combs[[i]] - num[i, ] <- with(x, wcolMeans((F_jk[[i]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w)) - denom[i, ] <- with(x, wcolMeans(F_jk[[i]]^2, w = w)) + if (!is.null(combs)) { + for (nm in names(combs)) { + z <- combs[[nm]] + num[nm, ] <- with(x, wcolMeans((F_jk[[nm]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w)) + denom[nm, ] <- with(x, wcolMeans(F_jk[[nm]]^2, w = w)) + } } list(num = num, denom = denom) diff --git a/R/H2_threeway.R b/R/H2_threeway.R index 5ab87492..6e4837e5 100644 --- a/R/H2_threeway.R +++ b/R/H2_threeway.R @@ -69,8 +69,8 @@ h2_threeway.default <- function(object, ...) { #' @describeIn h2_threeway Pairwise interaction strength from "hstats" object. #' @export h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, plot = FALSE, - fill = "#2b51a1", ...) { + top_m = 15L, zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) { s <- object$h2_threeway if (is.null(s)) { return(NULL) @@ -81,7 +81,8 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out @@ -94,26 +95,27 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = #' #' @noRd #' @keywords internal -#' @param x A list containing the elements "combs3", "K", "pred_names", +#' @param x A list containing the elements "combs3", "v_threeway_0", "K", "pred_names", #' "F_jkl", "F_jk", "F_j", and "w". #' @returns A list with the numerator and denominator statistics. h2_threeway_raw <- function(x) { - combs <- x[["combs3"]] - - # Note that the F_jkl are in the same order as combs - num <- denom <- with( - x, matrix(nrow = length(combs), ncol = K, dimnames = list(names(combs), pred_names)) - ) + num <- init_numerator(x, way = 3L) + denom <- num + 1 - for (i in seq_along(combs)) { - z <- combs[[i]] - zz <- sapply(utils::combn(z, 2L, simplify = FALSE), paste, collapse = ":") - - num[i, ] <- with( - x, wcolMeans((F_jkl[[i]] - Reduce("+", F_jk[zz]) + Reduce("+", F_j[z]))^2, w = w) - ) - denom[i, ] <- with(x, wcolMeans(F_jkl[[i]]^2, w = w)) + # Note that the F_jkl are in the same order as x[["combs3"]] + combs <- x[["combs3"]] + if (!is.null(combs)) { + for (nm in names(combs)) { + z <- combs[[nm]] + zz <- utils::combn(z, 2L, paste, collapse = ":") + + num[nm, ] <- with( + x, + wcolMeans((F_jkl[[nm]] - Reduce("+", F_jk[zz]) + Reduce("+", F_j[z]))^2, w = w) + ) + denom[nm, ] <- with(x, wcolMeans(F_jkl[[nm]]^2, w = w)) + } } - + list(num = num, denom = denom) } diff --git a/R/hstats.R b/R/hstats.R index 9a49e0dc..b87c9283 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -15,7 +15,7 @@ #' Furthermore, it allows to calculate an experimental partial dependence based #' measure of feature importance, \eqn{\textrm{PDI}_j^2}. It equals the proportion of #' prediction variability unexplained by other features, see [pd_importance()] -#' for details. (This statistic is not shown by `summary()` or `plot()`.) +#' for details. This statistic is not shown by `summary()` or `plot()`. #' #' Instead of using `summary()`, interaction statistics can also be obtained via the #' more flexible functions [h2()], [h2_overall()], [h2_pairwise()], and @@ -36,11 +36,11 @@ #' @param pairwise_m Number of features for which pairwise statistics are to be #' calculated. The features are selected based on Friedman and Popescu's overall #' interaction strength \eqn{H^2_j}. Set to to 0 to avoid pairwise calculations. -#' For multivariate predictions, the union of the column-wise strongest variable -#' names is taken. This can lead to very long run-times. -#' @param threeway_m Same as `pairwise_m`, but controlling the number of features for -#' which threeway interactions should be calculated. Not larger than `pairwise_m`. -#' Set to 0 to avoid threeway calculations. +#' For multivariate predictions, the union of the `pairwise_m` column-wise +#' strongest variable names is taken. This can lead to very long run-times. +#' @param threeway_m Like `pairwise_m`, but controls the feature count for +#' three-way interactions. Cannot be larger than `pairwise_m`. +#' The default is `min(pairwise_m, 5)`. Set to 0 to avoid three-way calculations. #' @param verbose Should a progress bar be shown? The default is `TRUE`. #' @param ... Additional arguments passed to `pred_fun(object, X, ...)`, #' for instance `type = "response"` in a [glm()] model. @@ -50,33 +50,34 @@ #' - `w`: Input `w` (sampled to `n_max` values, or `NULL`). #' - `v`: Same as input `v`. #' - `f`: Matrix with (centered) predictions \eqn{F}. -#' - `mean_f2`: (Weighted) column means of `f`. Used to normalize most statistics. +#' - `mean_f2`: (Weighted) column means of `f`. Used to normalize \eqn{H^2} and +#' \eqn{H^2_j}. #' - `F_j`: List of matrices, each representing (centered) #' partial dependence functions \eqn{F_j}. #' - `F_not_j`: List of matrices with (centered) partial dependence #' functions \eqn{F_{\setminus j}} of other features. #' - `K`: Number of columns of prediction matrix. #' - `pred_names`: Column names of prediction matrix. +#' - `pairwise_m`: Like input `pairwise_m`, but capped at `length(v)`. +#' - `threeway_m`: Like input `threeway_m`, but capped at the smaller of +#' `length(v)` and `pairwise_m`. #' - `h2`: List with numerator and denominator of \eqn{H^2}. #' - `h2_overall`: List with numerator and denominator of \eqn{H^2_j}. -#' - `v_pairwise`: Subset of `v` with largest `h2_overall()` used for pairwise +#' - `v_pairwise`: Subset of `v` with largest \eqn{H^2_j} used for pairwise #' calculations. #' - `combs2`: Named list of variable pairs for which pairwise partial -#' dependence functions are available. Only if pairwise calculations have been done. +#' dependence functions are available. #' - `F_jk`: List of matrices, each representing (centered) bivariate -#' partial dependence functions \eqn{F_{jk}}. -#' Only if pairwise calculations have been done. +#' partial dependence functions \eqn{F_{jk}}. #' - `h2_pairwise`: List with numerator and denominator of \eqn{H^2_{jk}}. #' Only if pairwise calculations have been done. #' - `v_threeway`: Subset of `v` with largest `h2_overall()` used for three-way #' calculations. #' - `combs3`: Named list of variable triples for which three-way partial -#' dependence functions are available. Only if threeway calculations have been done. +#' dependence functions are available. #' - `F_jkl`: List of matrices, each representing (centered) three-way -#' partial dependence functions \eqn{F_{jkl}}. -#' Only if threeway calculations have been done. +#' partial dependence functions \eqn{F_{jkl}}. #' - `h2_threeway`: List with numerator and denominator of \eqn{H^2_{jkl}}. -#' Only if threeway calculations have been done. #' @references #' Friedman, Jerome H., and Bogdan E. Popescu. *"Predictive Learning via Rule Ensembles."* #' The Annals of Applied Statistics 2, no. 3 (2008): 916-54. @@ -89,10 +90,11 @@ #' s <- hstats(fit, X = iris[-1]) #' s #' plot(s) +#' plot(s, zero = FALSE) # Drop 0 #' summary(s) #' #' # Absolute pairwise interaction strengths -#' h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE) +#' h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE, zero = FALSE) #' #' # MODEL 2: Multi-response linear regression #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) @@ -109,9 +111,7 @@ #' #' # On original scale, we have interactions everywhere... #' s <- hstats(fit, X = iris[-1], type = "response", verbose = FALSE) -#' -#' # All three types use different denominators -#' plot(s, which = 1:3, ncol = 1) +#' plot(s, which = 1:3, ncol = 1) # All three types use different denominators #' #' # All statistics on same scale (of predictions) #' plot(s, which = 1:3, squared = FALSE, normalize = FALSE, facet_scale = "free_y") @@ -123,10 +123,14 @@ hstats <- function(object, ...) { #' @export hstats.default <- function(object, X, v = colnames(X), pred_fun = stats::predict, n_max = 300L, - w = NULL, pairwise_m = 5L, threeway_m = pairwise_m, + w = NULL, pairwise_m = 5L, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ...) { basic_check(X = X, v = v, pred_fun = pred_fun, w = w) - stopifnot(threeway_m <= pairwise_m) + p <- length(v) + stopifnot(p >= 2L) + pairwise_m <- min(pairwise_m, p) + threeway_m <- min(threeway_m, pairwise_m, p) # Reduce size of X (and w) if (nrow(X) > n_max) { @@ -142,9 +146,7 @@ hstats.default <- function(object, X, v = colnames(X), mean_f2 <- wcolMeans(f^2, w = w) # A vector # Initialize first progress bar - p <- length(v) - show_bar <- verbose && p >= 2L - if (show_bar) { + if (verbose) { cat("1-way calculations...\n") pb <- utils::txtProgressBar(max = p, style = 3) } @@ -175,11 +177,11 @@ hstats.default <- function(object, X, v = colnames(X), w = w ) - if (show_bar) { + if (verbose) { utils::setTxtProgressBar(pb, j) } } - if (show_bar) { + if (verbose) { cat("\n") } @@ -193,31 +195,35 @@ hstats.default <- function(object, X, v = colnames(X), F_j = F_j, F_not_j = F_not_j, K = ncol(f), - pred_names = colnames(f) + pred_names = colnames(f), + pairwise_m = pairwise_m, + threeway_m = threeway_m ) # 0-way and 1-way stats out[["h2"]] <- h2_raw(out) out[["h2_overall"]] <- h2_overall_raw(out) - - # 2+way stats are calculated only for features with largest overall interactions h2_ov <- .zap_small(out$h2_overall$num, eps = 1e-8) # Does eps need to be passed? - out[["v_pairwise"]] <- v2 <- get_v(h2_ov, m = pairwise_m) - if (min(pairwise_m, length(v2)) >= 2L) { - out[c("combs2", "F_jk")] <- mway( - object, v = v2, X = X, pred_fun = pred_fun, w = w, way = 2L, verb = verbose, ... - ) + + if (pairwise_m >= 2L) { + out[["v_pairwise"]] <- v2 <- get_v(h2_ov, m = pairwise_m) + if (length(v2) >= 2L) { + out[c("combs2", "F_jk")] <- mway( + object, v = v2, X = X, pred_fun = pred_fun, w = w, way = 2L, verb = verbose, ... + ) + } out[["h2_pairwise"]] <- h2_pairwise_raw(out) } - - out[["v_threeway"]] <- v3 <- get_v(h2_ov, m = threeway_m) - if (min(threeway_m, length(v3)) >= 3L) { - out[c("combs3", "F_jkl")] <- mway( - object, v = v3, X = X, pred_fun = pred_fun, w = w, way = 3L, verb = verbose, ... - ) + if (threeway_m >= 3L) { + out[["v_threeway"]] <- v3 <- get_v(h2_ov, m = threeway_m) + if (length(v3) >= 3L) { + out[c("combs3", "F_jkl")] <- mway( + object, v = v3, X = X, pred_fun = pred_fun, w = w, way = 3L, verb = verbose, ... + ) + } out[["h2_threeway"]] <- h2_threeway_raw(out) } - + structure(out, class = "hstats") } @@ -226,7 +232,8 @@ hstats.default <- function(object, X, v = colnames(X), hstats.ranger <- function(object, X, v = colnames(X), pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, verbose = TRUE, ...) { + threeway_m = min(pairwise_m, 5L), + verbose = TRUE, ...) { hstats.default( object = object, X = X, @@ -246,7 +253,8 @@ hstats.ranger <- function(object, X, v = colnames(X), hstats.Learner <- function(object, X, v = colnames(X), pred_fun = NULL, n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, verbose = TRUE, ...) { + threeway_m = min(pairwise_m, 5L), + verbose = TRUE, ...) { if (is.null(pred_fun)) { pred_fun <- mlr3_pred_fun(object, X = X) } @@ -270,7 +278,8 @@ hstats.explainer <- function(object, X = object[["data"]], v = colnames(X), pred_fun = object[["predict_function"]], n_max = 300L, w = object[["weights"]], - pairwise_m = 5L, threeway_m = pairwise_m, + pairwise_m = 5L, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ...) { hstats.default( object = object[["model"]], @@ -296,7 +305,7 @@ hstats.explainer <- function(object, X = object[["data"]], #' @export #' @seealso See [hstats()] for examples. print.hstats <- function(x, ...) { - cat("'hstats' object. Run plot() or summary() for details.\n\n") + cat("'hstats' object. Use plot() or summary() for details.\n\n") cat("Proportion of prediction variability unexplained by main effects of v:\n") print(h2(x)) cat("\n") @@ -305,33 +314,38 @@ print.hstats <- function(x, ...) { #' Summary Method #' -#' Summary method for "hstats" object. +#' Summary method for "hstats" object. Note that \eqn{H^2} is not affected by +#' the arguments `normalize` and `squared`. #' #' @inheritParams h2_overall #' @param ... Currently not used. #' @returns -#' An object of class "summary_hstats" representing a named list with statistics. +#' An object of class "summary_hstats" representing a named list with statistics +#' "h2", "h2_overall", "h2_pairwise", "h2_threeway", and the input flag "normalize". +#' Statistics that equal `NULL` are omitted from the list. #' @export #' @seealso See [hstats()] for examples. summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = Inf, eps = 1e-8, ...) { + top_m = Inf, zero = TRUE, eps = 1e-8, ...) { args <- list( object = object, normalize = normalize, squared = squared, sort = sort, top_m = top_m, + zero = zero, eps = eps, plot = FALSE ) - out <- list( + out <- list( h2 = h2(object, normalize = normalize, squared = squared, eps = eps), h2_overall = do.call(h2_overall, args), h2_pairwise = do.call(h2_pairwise, args), - h2_threeway = do.call(h2_threeway, args) + h2_threeway = do.call(h2_threeway, args), + normalize = normalize ) - class(out) <- "summary_hstats" - out + out <- out[!sapply(out, is.null)] + structure(out, class = "summary_hstats") } #' Print Method @@ -344,18 +358,20 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE #' @export #' @seealso See [hstats()] for examples. print.summary_hstats <- function(x, ...) { - addon <- "(for features with strong overall interactions)" + flag <- if (x[["normalize"]]) "relative" else "absolute" + txt <- c( - h2 = "Proportion of prediction variability unexplained by main effects of v", - h2_overall = "Strongest overall interactions", - h2_pairwise = paste0("Strongest relative pairwise interactions\n", addon), - h2_threeway = paste0("Strongest relative three-way interactions\n", addon) + h2 = "Prediction variability unexplained by main effects", + h2_overall = sprintf("Strongest %s overall interactions", flag), + h2_pairwise = sprintf("Strongest %s pairwise interactions", flag), + h2_threeway = sprintf("Strongest %s three-way interaction", flag) ) + top_n <- c(h2 = 1L, h2_overall = 4L, h2_pairwise = 3L, h2_threeway = 1L) - for (nm in names(Filter(Negate(is.null), x))) { - cat(txt[[nm]]) + for (nm in setdiff(names(x), "normalize")) { + cat(txt[nm]) cat("\n") - print(utils::head(drop(x[[nm]]))) + print(utils::head(x[[nm]], top_n[nm])) cat("\n") } invisible(x) @@ -378,15 +394,28 @@ print.summary_hstats <- function(x, ...) { #' @export #' @seealso See [hstats()] for examples. plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 15L, eps = 1e-8, fill = "#2b51a1", + top_m = 15L, zero = TRUE, eps = 1e-8, fill = "#2b51a1", facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) { su <- summary( - x, normalize = normalize, squared = squared, sort = sort, top_m = top_m, eps = eps + x, + normalize = normalize, + squared = squared, + sort = sort, + top_m = top_m, + zero = zero, + eps = eps ) - nms <- c("h2_overall", "h2_pairwise", "h2_threeway") - ids <- c("Overall", "Pairwise", "Threeway") - dat <- lapply(which, FUN = function(j) mat2df(su[[nms[j]]], id = ids[j])) + + # This part could be simplified, especially the "match()" + stat_names <- c("h2_overall", "h2_pairwise", "h2_threeway")[which] + stat_labs <- c("Overall", "Pairwise", "Three-way")[which] + ok <- stat_names[stat_names %in% names(su)] + if (length(ok) == 0L) { + return(NULL) + } + dat <- lapply(ok, FUN = function(nm) mat2df(su[[nm]], id = stat_labs[match(nm, stat_names)])) dat <- do.call(rbind, dat) + p <- ggplot2::ggplot(dat, ggplot2::aes(x = value_, y = variable_)) + ggplot2::ylab(ggplot2::element_blank()) + ggplot2::xlab("Value") @@ -434,8 +463,7 @@ mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, F_way <- vector("list", length = n_combs) names(F_way) <- names(combs) <- sapply(combs, paste, collapse = ":") - show_bar <- verb && (n_combs >= way) - if (show_bar) { + if (verb) { cat(way, "way calculations...\n", sep = "-") pb <- utils::txtProgressBar(max = n_combs, style = 3) } @@ -446,11 +474,11 @@ mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, pd_raw(object, v = z, X = X, grid = X[, z], pred_fun = pred_fun, w = w, ...), w = w ) - if (show_bar) { + if (verb) { utils::setTxtProgressBar(pb, i) } } - if (show_bar) { + if (verb) { cat("\n") } list(combs, F_way) @@ -468,17 +496,15 @@ mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, #' @param H Unnormalized, unsorted H2_j values. #' @param m Number of features to pick per column. #' -#' @returns A vector of feature names. +#' @returns A vector of the union of the m column-wise most important features. get_v <- function(H, m) { - # Get largest m positive values per column + v <- rownames(H) selector <- function(vv) names(utils::head(sort(-vv[vv > 0]), m)) if (NCOL(H) == 1L) { v_cand <- selector(drop(H)) } else { v_cand <- Reduce(union, lapply(asplit(H, MARGIN = 2L), FUN = selector)) } - # Same order as in v - v <- rownames(H) v[v %in% v_cand] } diff --git a/R/pd_importance.R b/R/pd_importance.R index dbb74f90..a9c60899 100644 --- a/R/pd_importance.R +++ b/R/pd_importance.R @@ -57,11 +57,9 @@ pd_importance.default <- function(object, ...) { #' @describeIn pd_importance PD based feature importance from "hstats" object. #' @export pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, top_m = 15L, eps = 1e-8, plot = FALSE, - fill = "#2b51a1", ...) { - num <- with( - object, matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names)) - ) + sort = TRUE, top_m = 15L, zero = TRUE, eps = 1e-8, + plot = FALSE, fill = "#2b51a1", ...) { + num <- init_numerator(object, way = 1L) for (z in object[["v"]]) { num[z, ] <- with(object, wcolMeans((f - F_not_j[[z]])^2, w = w)) } @@ -71,7 +69,8 @@ pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, normalize = normalize, squared = squared, sort = sort, - top_m = top_m, + top_m = top_m, + zero = zero, eps = eps ) if (plot) plot_stat(out, fill = fill, ...) else out diff --git a/R/utils.R b/R/utils.R index 9f9a0fc2..5b32e278 100644 --- a/R/utils.R +++ b/R/utils.R @@ -216,9 +216,9 @@ basic_check <- function(X, v, pred_fun, w = NULL) { #' @inheritParams H2_overall #' @param num Matrix or vector of statistic. #' @param denom Denominator of statistic (a matrix, number, or vector compatible with `num`). -#' @returns Matrix or vector of statistics. +#' @returns Matrix or vector of statistics. If length of output is 0, then `NULL`. postprocess <- function(num, denom = 1, normalize = TRUE, squared = TRUE, - sort = TRUE, top_m = Inf, eps = 1e-8) { + sort = TRUE, top_m = Inf, zero = TRUE, eps = 1e-8) { out <- .zap_small(num, eps = eps) if (normalize) { if (length(denom) == 1L || length(num) == length(denom)) { @@ -239,7 +239,15 @@ postprocess <- function(num, denom = 1, normalize = TRUE, squared = TRUE, out <- sort(out, decreasing = TRUE) } } - utils::head(out, n = top_m) + if (!zero) { + if (is.matrix(out)) { + out <- out[rowSums(out) > 0, , drop = FALSE] + } else { + out <- out[out > 0] + } + } + out <- utils::head(out, n = top_m) + if (length(out) == 0L) NULL else out } #' Zap Small Values @@ -315,6 +323,51 @@ mat2df <- function(mat, id = "Overall") { poor_man_stack(out, to_stack = pred_names) } +#' Initializor of Numerator Statistics +#' +#' Internal helper function that returns a matrix of all zeros with the right +#' column and row names for statistics of any "way". If some features have been dropped +#' from the statistics calculations, they are added as 0. +#' +#' @noRd +#' @keywords internal +#' @param x A list containing the elements "v", "K", "pred_names", "v_pairwise", +#' "v_threeway", "pairwise_m", "threeway_m". +#' @param way Integer between 1 and 3 of the order of the interaction. +#' @returns A matrix of all zeros. +init_numerator <- function(x, way = 1L) { + stopifnot(way %in% 1:3) + + v <- x[["v"]] + K <- x[["K"]] + pred_names <- x[["pred_names"]] + + # Simple case + if (way == 1L) { + return(matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names))) + } + + # Determine v_cand_0, which is v_cand with additional features to end up with length m + if (way == 2L) { + v_cand <- x[["v_pairwise"]] + m <- x[["pairwise_m"]] + } else { + v_cand <- x[["v_threeway"]] + m <- x[["threeway_m"]] + } + m_miss <- m - length(v_cand) + if (m_miss > 0L) { + v_cand_0 <- c(v_cand, utils::head(setdiff(v, v_cand), m_miss)) + v_cand_0 <- v[v %in% v_cand_0] # Bring into order of v + } else { + v_cand_0 <- v_cand + } + + # Get all interactions of order "way". c() turns the array into a vector + cn0 <- c(utils::combn(v_cand_0, m = way, FUN = paste, collapse = ":")) + matrix(0, nrow = length(cn0), ncol = K, dimnames = list(cn0, pred_names)) +} + #' Bin into Quantiles #' #' Internal function. Applies [cut()] to quantile breaks. @@ -339,8 +392,11 @@ qcut <- function(x, m) { #' @param x A matrix of statistics with rownames. #' @param fill Color of bar (only for univariate statistics). #' @param ... Arguments passed to `geom_bar()`. -#' @returns An object of class "ggplot". +#' @returns An object of class "ggplot", or `NULL`. plot_stat <- function(x, fill = "#2b51a1", ...) { + if (is.null(x)) { + return(NULL) + } p <- ggplot2::ggplot(mat2df(x), ggplot2::aes(x = value_, y = variable_)) + ggplot2::ylab(ggplot2::element_blank()) + ggplot2::xlab("Value") diff --git a/man/H2_overall.Rd b/man/H2_overall.Rd index 67127e99..5b2a1164 100644 --- a/man/H2_overall.Rd +++ b/man/H2_overall.Rd @@ -16,6 +16,7 @@ h2_overall(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ h2_overall(object, ...) \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? Default is \code{FALSE}.} @@ -103,7 +106,7 @@ h2_overall(s, plot = TRUE) # MODEL 2: Multi-response linear regression fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) s <- hstats(fit, X = iris[3:5], verbose = FALSE) -h2_overall(s, plot = TRUE) +h2_overall(s, plot = TRUE, zero = FALSE) } \references{ Friedman, Jerome H., and Bogdan E. Popescu. \emph{"Predictive Learning via Rule Ensembles."} diff --git a/man/H2_pairwise.Rd b/man/H2_pairwise.Rd index bdc26676..ac912b04 100644 --- a/man/H2_pairwise.Rd +++ b/man/H2_pairwise.Rd @@ -16,6 +16,7 @@ h2_pairwise(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ h2_pairwise(object, ...) \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? Default is \code{FALSE}.} @@ -109,6 +112,7 @@ s <- hstats(fit, X = iris[-1]) # Proportion of joint effect coming from pairwise interaction # (for features with strongest overall interactions) h2_pairwise(s) +h2_pairwise(s, zero = FALSE) # Drop 0 # Absolute measure as alternative h2_pairwise(s, normalize = FALSE, squared = FALSE) @@ -117,6 +121,7 @@ h2_pairwise(s, normalize = FALSE, squared = FALSE) fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) s <- hstats(fit, X = iris[3:5], verbose = FALSE) h2_pairwise(s, plot = TRUE) +h2_pairwise(s, zero = FALSE, plot = TRUE) } \references{ Friedman, Jerome H., and Bogdan E. Popescu. \emph{"Predictive Learning via Rule Ensembles."} diff --git a/man/H2_threeway.Rd b/man/H2_threeway.Rd index 4a4eddb1..dcaf3cf5 100644 --- a/man/H2_threeway.Rd +++ b/man/H2_threeway.Rd @@ -16,6 +16,7 @@ h2_threeway(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ h2_threeway(object, ...) \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? Default is \code{FALSE}.} diff --git a/man/hstats.Rd b/man/hstats.Rd index 79c00ce5..560c828e 100644 --- a/man/hstats.Rd +++ b/man/hstats.Rd @@ -18,7 +18,7 @@ hstats(object, ...) n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -31,7 +31,7 @@ hstats(object, ...) n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -44,7 +44,7 @@ hstats(object, ...) n_max = 300L, w = NULL, pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -57,7 +57,7 @@ hstats(object, ...) n_max = 300L, w = object[["weights"]], pairwise_m = 5L, - threeway_m = pairwise_m, + threeway_m = min(pairwise_m, 5L), verbose = TRUE, ... ) @@ -87,12 +87,12 @@ selected from \code{X}. In this case, set a random seed for reproducibility.} \item{pairwise_m}{Number of features for which pairwise statistics are to be calculated. The features are selected based on Friedman and Popescu's overall interaction strength \eqn{H^2_j}. Set to to 0 to avoid pairwise calculations. -For multivariate predictions, the union of the column-wise strongest variable -names is taken. This can lead to very long run-times.} +For multivariate predictions, the union of the \code{pairwise_m} column-wise +strongest variable names is taken. This can lead to very long run-times.} -\item{threeway_m}{Same as \code{pairwise_m}, but controlling the number of features for -which threeway interactions should be calculated. Not larger than \code{pairwise_m}. -Set to 0 to avoid threeway calculations.} +\item{threeway_m}{Like \code{pairwise_m}, but controls the feature count for +three-way interactions. Cannot be larger than \code{pairwise_m}. +The default is \code{min(pairwise_m, 5)}. Set to 0 to avoid three-way calculations.} \item{verbose}{Should a progress bar be shown? The default is \code{TRUE}.} } @@ -103,33 +103,34 @@ An object of class "hstats" containing these elements: \item \code{w}: Input \code{w} (sampled to \code{n_max} values, or \code{NULL}). \item \code{v}: Same as input \code{v}. \item \code{f}: Matrix with (centered) predictions \eqn{F}. -\item \code{mean_f2}: (Weighted) column means of \code{f}. Used to normalize most statistics. +\item \code{mean_f2}: (Weighted) column means of \code{f}. Used to normalize \eqn{H^2} and +\eqn{H^2_j}. \item \code{F_j}: List of matrices, each representing (centered) partial dependence functions \eqn{F_j}. \item \code{F_not_j}: List of matrices with (centered) partial dependence functions \eqn{F_{\setminus j}} of other features. \item \code{K}: Number of columns of prediction matrix. \item \code{pred_names}: Column names of prediction matrix. +\item \code{pairwise_m}: Like input \code{pairwise_m}, but capped at \code{length(v)}. +\item \code{threeway_m}: Like input \code{threeway_m}, but capped at the smaller of +\code{length(v)} and \code{pairwise_m}. \item \code{h2}: List with numerator and denominator of \eqn{H^2}. \item \code{h2_overall}: List with numerator and denominator of \eqn{H^2_j}. -\item \code{v_pairwise}: Subset of \code{v} with largest \code{h2_overall()} used for pairwise +\item \code{v_pairwise}: Subset of \code{v} with largest \eqn{H^2_j} used for pairwise calculations. \item \code{combs2}: Named list of variable pairs for which pairwise partial -dependence functions are available. Only if pairwise calculations have been done. +dependence functions are available. \item \code{F_jk}: List of matrices, each representing (centered) bivariate partial dependence functions \eqn{F_{jk}}. -Only if pairwise calculations have been done. \item \code{h2_pairwise}: List with numerator and denominator of \eqn{H^2_{jk}}. Only if pairwise calculations have been done. \item \code{v_threeway}: Subset of \code{v} with largest \code{h2_overall()} used for three-way calculations. \item \code{combs3}: Named list of variable triples for which three-way partial -dependence functions are available. Only if threeway calculations have been done. +dependence functions are available. \item \code{F_jkl}: List of matrices, each representing (centered) three-way partial dependence functions \eqn{F_{jkl}}. -Only if threeway calculations have been done. \item \code{h2_threeway}: List with numerator and denominator of \eqn{H^2_{jkl}}. -Only if threeway calculations have been done. } } \description{ @@ -149,7 +150,7 @@ see \code{\link[=h2_threeway]{h2_threeway()}} for details. Furthermore, it allows to calculate an experimental partial dependence based measure of feature importance, \eqn{\textrm{PDI}_j^2}. It equals the proportion of prediction variability unexplained by other features, see \code{\link[=pd_importance]{pd_importance()}} -for details. (This statistic is not shown by \code{summary()} or \code{plot()}.) +for details. This statistic is not shown by \code{summary()} or \code{plot()}. Instead of using \code{summary()}, interaction statistics can also be obtained via the more flexible functions \code{\link[=h2]{h2()}}, \code{\link[=h2_overall]{h2_overall()}}, \code{\link[=h2_pairwise]{h2_pairwise()}}, and @@ -172,10 +173,11 @@ fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris) s <- hstats(fit, X = iris[-1]) s plot(s) +plot(s, zero = FALSE) # Drop 0 summary(s) # Absolute pairwise interaction strengths -h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE) +h2_pairwise(s, normalize = FALSE, squared = FALSE, plot = FALSE, zero = FALSE) # MODEL 2: Multi-response linear regression fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) @@ -192,9 +194,7 @@ summary(s) # On original scale, we have interactions everywhere... s <- hstats(fit, X = iris[-1], type = "response", verbose = FALSE) - -# All three types use different denominators -plot(s, which = 1:3, ncol = 1) +plot(s, which = 1:3, ncol = 1) # All three types use different denominators # All statistics on same scale (of predictions) plot(s, which = 1:3, squared = FALSE, normalize = FALSE, facet_scale = "free_y") diff --git a/man/pd_importance.Rd b/man/pd_importance.Rd index 2da3dc94..10f3e998 100644 --- a/man/pd_importance.Rd +++ b/man/pd_importance.Rd @@ -16,6 +16,7 @@ pd_importance(object, ...) squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, plot = FALSE, fill = "#2b51a1", @@ -36,6 +37,8 @@ pd_importance(object, ...) \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{plot}{Should results be plotted as barplot? Default is \code{FALSE}.} diff --git a/man/plot.hstats.Rd b/man/plot.hstats.Rd index 7cbc7de0..7164cb40 100644 --- a/man/plot.hstats.Rd +++ b/man/plot.hstats.Rd @@ -11,6 +11,7 @@ squared = TRUE, sort = TRUE, top_m = 15L, + zero = TRUE, eps = 1e-08, fill = "#2b51a1", facet_scales = "free", @@ -35,6 +36,8 @@ use \code{1:3}.} \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{fill}{Color of bar (only for univariate statistics).} diff --git a/man/summary.hstats.Rd b/man/summary.hstats.Rd index a8ea4983..3bfa42eb 100644 --- a/man/summary.hstats.Rd +++ b/man/summary.hstats.Rd @@ -10,6 +10,7 @@ squared = TRUE, sort = TRUE, top_m = Inf, + zero = TRUE, eps = 1e-08, ... ) @@ -26,15 +27,20 @@ \item{top_m}{How many rows should be shown? (\code{Inf} to show all.)} +\item{zero}{Should rows with all 0 be shown? Default is \code{TRUE}.} + \item{eps}{Threshold below which numerator values are set to 0.} \item{...}{Currently not used.} } \value{ -An object of class "summary_hstats" representing a named list with statistics. +An object of class "summary_hstats" representing a named list with statistics +"h2", "h2_overall", "h2_pairwise", "h2_threeway", and the input flag "normalize". +Statistics that equal \code{NULL} are omitted from the list. } \description{ -Summary method for "hstats" object. +Summary method for "hstats" object. Note that \eqn{H^2} is not affected by +the arguments \code{normalize} and \code{squared}. } \seealso{ See \code{\link[=hstats]{hstats()}} for examples. diff --git a/tests/testthat/test_hstats.R b/tests/testthat/test_hstats.R index e7d38fc4..4b18deef 100644 --- a/tests/testthat/test_hstats.R +++ b/tests/testthat/test_hstats.R @@ -1,30 +1,47 @@ test_that("Additive models show 0 interactions (univariate)", { fit <- lm(Sepal.Length ~ ., data = iris) s <- hstats(fit, X = iris[-1L], verbose = FALSE) - expect_null(h2_pairwise(s)) - expect_null(h2_threeway(s)) + expect_null(h2_pairwise(s, zero = FALSE)) + expect_equal(c(h2_pairwise(s, sort = FALSE, top_m = Inf)), rep(0, choose(4, 2))) + + expect_null(h2_threeway(s, zero = FALSE)) + expect_equal(c(h2_threeway(s, sort = FALSE, top_m = Inf)), rep(0, choose(4, 3))) + expect_equal( - h2_overall(s, plot = FALSE), + h2_overall(s), matrix(c(0, 0, 0, 0), ncol = 1L, dimnames = list(colnames(iris[-1L]), NULL)) ) + expect_null(h2_overall(s, zero = FALSE)) + expect_equal(h2(s), 0) + expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") expect_s3_class(plot(s, rotate_x = TRUE), "ggplot") + expect_null(h2_overall(s, zero = FALSE, plot = TRUE)) }) test_that("Additive models show 0 interactions (multivariate)", { fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris) s <- hstats(fit, X = iris[3:5], verbose = FALSE) - expect_null(h2_pairwise(s)) - expect_null(h2_threeway(s)) + + expect_null(h2_pairwise(s, zero = FALSE)) + expect_true(all(h2_pairwise(s) == 0)) + + expect_null(h2_threeway(s, zero = FALSE)) + expect_equal(unname(h2_threeway(s)), cbind(0, 0)) + expect_equal( - h2_overall(s, plot = FALSE), + h2_overall(s, sort = FALSE), matrix( 0, ncol = 2L, nrow = 3L, dimnames = list(colnames(iris[3:5]), colnames(iris[1:2])) ) ) + expect_null(h2_overall(s, zero = FALSE)) + expect_equal(h2(s), c(Sepal.Length = 0, Sepal.Width = 0)) + expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") + expect_null(h2_overall(s, zero = FALSE, plot = TRUE)) expect_s3_class(plot(s), "ggplot") }) @@ -37,14 +54,19 @@ test_that("Non-additive models show interactions > 0 (one interaction)", { expect_true( all(rownames(out[out > 0, , drop = FALSE]) %in% c("Petal.Length", "Petal.Width")) ) + out <- h2_overall(s, zero = FALSE, sort = FALSE) + expect_true(all(rownames(out) %in% c("Petal.Length", "Petal.Width"))) - out <- h2_pairwise(s) + out <- h2_pairwise(s, zero = FALSE) expect_equal(rownames(out), "Petal.Length:Petal.Width") + out <- h2_pairwise(s) + expect_equal(rownames(out[out > 0, , drop = FALSE]), "Petal.Length:Petal.Width") expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") expect_s3_class(h2_pairwise(s, plot = TRUE), "ggplot") expect_s3_class(plot(s), "ggplot") - expect_null(h2_threeway(s)) + expect_null(h2_threeway(s, zero = FALSE)) + expect_equal(c(h2_threeway(s)), rep(0, times = choose(4, 3))) }) fit <- lm( @@ -60,18 +82,27 @@ test_that("Non-additive models show interactions > 0 (two interactions)", { rownames(out[out > 0, , drop = FALSE]), c("Petal.Length", "Petal.Width", "Species") ) + out <- h2_overall(s, sort = FALSE, normalize = FALSE, squared = FALSE, zero = FALSE) + expect_equal( + rownames(out), c("Petal.Length", "Petal.Width", "Species") + ) out <- h2_pairwise(s, sort = FALSE, normalize = FALSE, squared = FALSE) expect_equal( rownames(out[out > 0, , drop = FALSE]), c("Petal.Length:Petal.Width", "Petal.Length:Species") ) + out <- h2_pairwise(s, sort = FALSE, normalize = FALSE, squared = FALSE, zero = FALSE) + expect_equal( + rownames(out), c("Petal.Length:Petal.Width", "Petal.Length:Species") + ) expect_s3_class(h2_overall(s, plot = TRUE), "ggplot") expect_s3_class(h2_pairwise(s, plot = TRUE), "ggplot") expect_s3_class(h2_threeway(s, plot = TRUE), "ggplot") + expect_null(h2_threeway(s, plot = TRUE, zero = FALSE)) - expect_equal(c(h2_threeway(s)), 0) + expect_equal(c(h2_threeway(s)), rep(0, choose(4, 3))) }) test_that("passing v works", { @@ -115,7 +146,7 @@ test_that("Stronger interactions get higher statistics", { expect_true( all(h2_overall(int2, top_m = 2L) > h2_overall(int1, top_m = 2L)) ) - expect_true(h2_pairwise(int2) > h2_pairwise(int1)) + expect_true(h2_pairwise(int2, zero = FALSE) > h2_pairwise(int1, zero = FALSE)) }) test_that("subsampling has an effect", { @@ -156,12 +187,17 @@ test_that("Three-way interaction behaves correctly across dimensions", { expect_equal(2 * out[, "up"], out[, "up2"]) }) -test_that("Three-way interaction can be suppressed", { +test_that("Pairwise and three-way interactions can be suppressed", { fit <- lm(uptake ~ Type * Treatment * conc, data = CO2) s <- hstats(fit, X = CO2[2:4], verbose = FALSE, threeway_m = 0L) expect_null(h2_threeway(s)) + s <- hstats(fit, X = CO2[2:4], verbose = FALSE, pairwise_m = 2L) + expect_equal(nrow(h2_pairwise(s)), 1L) + expect_null(h2_threeway(s)) + s <- hstats(fit, X = CO2[2:4], verbose = FALSE, pairwise_m = 0L) + expect_equal(nrow(h2_overall(s)), 3L) expect_null(h2_pairwise(s)) expect_null(h2_threeway(s)) }) @@ -185,7 +221,10 @@ test_that("Statistics react on normalize, (sorting), squaring, and top m", { s$h2_overall$num[1:2, , drop = FALSE] ) - expect_identical(h2_pairwise(s, sort = FALSE), s$h2_pairwise$num / s$h2_pairwise$denom) + expect_identical( + h2_pairwise(s, sort = FALSE), + s$h2_pairwise$num / s$h2_pairwise$denom + ) expect_identical(h2_pairwise(s, sort = FALSE, normalize = FALSE), s$h2_pairwise$num) expect_identical( h2_pairwise(s, sort = FALSE, normalize = FALSE, squared = FALSE), diff --git a/tests/testthat/test_utils.R b/tests/testthat/test_utils.R index 18bfd0ff..4ad8db0f 100644 --- a/tests/testthat/test_utils.R +++ b/tests/testthat/test_utils.R @@ -238,6 +238,9 @@ test_that("postprocess() works for matrix input", { postprocess(num = num, denom = 1:2, sort = FALSE), num / cbind(c(1, 1, 1), c(2, 2, 2)) ) + + expect_equal(postprocess(num = cbind(0:1, 0:1), zero = FALSE), rbind(c(1, 1))) + expect_null(postprocess(num = cbind(0, 0), zero = FALSE)) }) test_that("postprocess() works for vector input", { @@ -249,6 +252,9 @@ test_that("postprocess() works for vector input", { expect_equal(postprocess(num = num, sort = FALSE), num) expect_equal(postprocess(num = num, sort = FALSE, top_m = 2), num[1:2]) expect_equal(postprocess(num = num, squared = FALSE), sqrt(num[3:1])) + + expect_equal(postprocess(num = 0:1, denom = c(2, 2), zero = FALSE), 0.5) + expect_null(postprocess(num = 0, zero = FALSE)) }) test_that(".zap_small() works for vector input", {