Merge pull request #63 from mayer79/pad_zeros
Pad zeros
mayer79 authored Sep 27, 2023
2 parents d0abdbd + dba213e commit 523ff21
Showing 17 changed files with 326 additions and 168 deletions.
18 changes: 11 additions & 7 deletions NEWS.md
@@ -1,22 +1,26 @@
# hstats 0.3.0

## Major user visible changes
## Visible changes

- Grid calculation: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid. This affects `ice()`, `partial_dep()` and the exported helper functions `univariate_grid()` and `multivariate_grid()`.
- Grid of `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid.
- `h2_pairwise()` and `h2_threeway()` will now also include 0 values. Use `zero = FALSE` to drop them, see below. The padding with 0 is done at no computational cost, and will affect only up to `pairwise_m` and `threeway_m` features.
- `hstats()`: The default number of features considered for *three-way interactions* has been changed from `threeway_m = pairwise_m` to the more cautious `threeway_m = min(pairwise_m, 5L)`. Furthermore, `threeway_m` is capped at `pairwise_m`.
- The `print()` method of `summary.hstats()` is less verbose.

## Internal major changes

- All available H-statistics are now calculated within `hstats()` and attached to the resulting object. Each statistic is stored as a list with numerator and denominator matrices/vectors. The functions `h2()`, `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `print.hstats()`, `summary.hstats()`, and `plot.hstats()` use these stored values instead of recalculating the numerators and denominators. The results are unchanged.

## Minor improvements
## Improvements

- `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `plot.hstats()`, and `summary.hstats()` have received an argument `zero = TRUE`. Set to `FALSE` to drop statistics having value 0.
- `perm_importance()` and `average_loss()` will now recycle a univariate response when combined with multivariate predictions. This is useful, e.g., when the prediction function represents the predictions of multiple models that should be evaluated against a common response.

## Bug fixes

- All progress bars were initialized 1 step too late.
- `perm_importance()` and `average_loss()` would fail for "mlogloss" in case the response `y` was univariate *and* non-factor/non-character.

## Other changes

- All available H-statistics are now calculated within `hstats()` and attached to the resulting object. Each statistic is stored as a list with numerator and denominator matrices/vectors. The functions `h2()`, `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `print.hstats()`, `summary.hstats()`, and `plot.hstats()` use these stored values instead of recalculating the numerators and denominators. The results are unchanged.

# hstats 0.2.0

## New major features
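
To illustrate the grid change described in the NEWS entry above: `pretty()` treats its `n` as a hint, while `seq()` with `length.out` returns exactly that many points. The following is a minimal base-R sketch. The data vector and the grid size of 5 are made up, and this is not the package's `univariate_grid()` code.

```r
# Hypothetical feature values; not taken from the package.
x <- c(0.3, 2.1, 4.8, 9.7)
grid_size <- 5L

# Old style: pretty() picks "round" break points, so the length may differ from n.
grid_pretty <- pretty(x, n = grid_size)

# New style: seq() over the observed range returns exactly grid_size points.
r <- range(x)
grid_seq <- seq(r[1L], r[2L], length.out = grid_size)

length(grid_pretty)  # may differ from grid_size
length(grid_seq)     # always grid_size
```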
4 changes: 2 additions & 2 deletions R/H2.R
@@ -70,8 +70,8 @@ h2.hstats <- function(object, normalize = TRUE, squared = TRUE, eps = 1e-8, ...)
num = object$h2$num,
denom = object$h2$denom,
normalize = normalize,
squared = squared,
sort = FALSE,
squared = squared,
sort = FALSE,
eps = eps
)
}
14 changes: 7 additions & 7 deletions R/H2_overall.R
@@ -44,6 +44,7 @@
#' @param sort Should results be sorted? Default is `TRUE`.
#' (Multioutput is sorted by row means.)
#' @param top_m How many rows should be shown? (`Inf` to show all.)
#' @param zero Should rows containing only zeros be shown? Default is `TRUE`.
#' @param eps Threshold below which numerator values are set to 0.
#' @param plot Should results be plotted as barplot? Default is `FALSE`.
#' @param fill Color of bar (only for univariate statistics).
@@ -64,7 +65,7 @@
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' h2_overall(s, plot = TRUE)
#' h2_overall(s, plot = TRUE, zero = FALSE)
h2_overall <- function(object, ...) {
UseMethod("h2_overall")
}
@@ -78,16 +79,17 @@ h2_overall.default <- function(object, ...) {
#' @describeIn h2_overall Overall interaction strength from "hstats" object.
#' @export
h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 15L, eps = 1e-8, plot = FALSE, fill = "#2b51a1",
...) {
top_m = 15L, zero = TRUE, eps = 1e-8,
plot = FALSE, fill = "#2b51a1", ...) {
s <- object$h2_overall
out <- postprocess(
num = s$num,
denom = s$denom,
normalize = normalize,
squared = squared,
sort = sort,
top_m = top_m,
top_m = top_m,
zero = zero,
eps = eps
)
if (plot) plot_stat(out, fill = fill, ...) else out
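
The `zero` argument is passed on to the internal `postprocess()`, whose body is not part of this diff. The sketch below is only a guess at the kind of filtering that `zero = FALSE` implies. `drop_zero_rows()` is a hypothetical helper, not a package function, and the statistics are made up.

```r
# Hypothetical helper: keep only rows with at least one non-zero entry.
drop_zero_rows <- function(stat) {
  keep <- rowSums(stat != 0) > 0
  stat[keep, , drop = FALSE]
}

# Made-up statistics for three features and two response columns.
stat <- matrix(
  c(0.2, 0, 0.1, 0.3, 0, 0),
  nrow = 3L,
  dimnames = list(c("x1", "x2", "x3"), c("y1", "y2"))
)
drop_zero_rows(stat)  # row "x2" is all 0 and is dropped
```

Using `drop = FALSE` keeps the result a matrix even if only one row survives.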
@@ -106,11 +108,9 @@ h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = T
#' "f", "F_not_j", "F_j", "mean_f2", and "w".
#' @returns A list with the numerator and denominator statistics.
h2_overall_raw <- function(x) {
num <- with(x, matrix(nrow = length(v), ncol = K, dimnames = list(v, pred_names)))

num <- init_numerator(x, way = 1L)
for (z in x[["v"]]) {
num[z, ] <- with(x, wcolMeans((f - F_j[[z]] - F_not_j[[z]])^2, w = w))
}

list(num = num, denom = x[["mean_f2"]])
}
31 changes: 17 additions & 14 deletions R/H2_pairwise.R
@@ -61,6 +61,7 @@
#' # Proportion of joint effect coming from pairwise interaction
#' # (for features with strongest overall interactions)
#' h2_pairwise(s)
#' h2_pairwise(s, zero = FALSE) # Drop 0
#'
#' # Absolute measure as alternative
#' h2_pairwise(s, normalize = FALSE, squared = FALSE)
@@ -69,6 +70,7 @@
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' h2_pairwise(s, plot = TRUE)
#' h2_pairwise(s, zero = FALSE, plot = TRUE)
h2_pairwise <- function(object, ...) {
UseMethod("h2_pairwise")
}
@@ -82,8 +84,8 @@ h2_pairwise.default <- function(object, ...) {
#' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object.
#' @export
h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 15L, eps = 1e-8, plot = FALSE,
fill = "#2b51a1", ...) {
top_m = 15L, zero = TRUE, eps = 1e-8,
plot = FALSE, fill = "#2b51a1", ...) {
s <- object$h2_pairwise
if (is.null(s)) {
return(NULL)
@@ -94,7 +96,8 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort =
normalize = normalize,
squared = squared,
sort = sort,
top_m = top_m,
top_m = top_m,
zero = zero,
eps = eps
)
if (plot) plot_stat(out, fill = fill, ...) else out
@@ -107,21 +110,21 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort =
#'
#' @noRd
#' @keywords internal
#' @param x A list containing the elements "combs2", "K", "pred_names",
#' @param x A list containing the elements "combs2", "v_pairwise_0", "K", "pred_names",
#' "F_jk", "F_j", and "w".
#' @returns A list with the numerator and denominator statistics.
h2_pairwise_raw <- function(x) {
num <- init_numerator(x, way = 2L)
denom <- num + 1

# Note that F_jk are in the same order as x[["combs2"]]
combs <- x[["combs2"]]

# Note that F_jk are in the same order as combs
num <- denom <- with(
x, matrix(nrow = length(combs), ncol = K, dimnames = list(names(combs), pred_names))
)

for (i in seq_along(combs)) {
z <- combs[[i]]
num[i, ] <- with(x, wcolMeans((F_jk[[i]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w))
denom[i, ] <- with(x, wcolMeans(F_jk[[i]]^2, w = w))
if (!is.null(combs)) {
for (nm in names(combs)) {
z <- combs[[nm]]
num[nm, ] <- with(x, wcolMeans((F_jk[[nm]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w))
denom[nm, ] <- with(x, wcolMeans(F_jk[[nm]]^2, w = w))
}
}

list(num = num, denom = denom)
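
The new loop fills rows by name, so rows that are not part of `combs2` keep their initial value. Assuming `init_numerator()` returns a zero matrix with one row per candidate pair (its definition is not shown in this diff), the padding can be sketched as follows. `init_numerator_sketch()` and all values are hypothetical.

```r
# Hypothetical stand-in for init_numerator(x, way = 2L): zero matrix, one row per pair.
init_numerator_sketch <- function(v, pred_names) {
  pairs <- as.vector(utils::combn(v, 2L, paste, collapse = ":"))
  matrix(
    0,
    nrow = length(pairs),
    ncol = length(pred_names),
    dimnames = list(pairs, pred_names)
  )
}

num <- init_numerator_sketch(c("a", "b", "c"), pred_names = "y")
denom <- num + 1  # padded rows become 0 / 1 = 0 instead of 0 / 0 = NaN

# Only the pairs that were actually evaluated are overwritten, by row name.
computed <- list("a:b" = c(num = 0.04, denom = 0.5))  # made-up values
for (nm in names(computed)) {
  num[nm, ] <- computed[[nm]][["num"]]
  denom[nm, ] <- computed[[nm]][["denom"]]
}

num / denom  # only the "a:b" row is non-zero; "a:c" and "b:c" stay 0
```

Initializing the denominator to 1 is what makes the padded rows come out as exact zeros after normalization.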
40 changes: 21 additions & 19 deletions R/H2_threeway.R
@@ -69,8 +69,8 @@ h2_threeway.default <- function(object, ...) {
#' @describeIn h2_threeway Three-way interaction strength from "hstats" object.
#' @export
h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 15L, eps = 1e-8, plot = FALSE,
fill = "#2b51a1", ...) {
top_m = 15L, zero = TRUE, eps = 1e-8,
plot = FALSE, fill = "#2b51a1", ...) {
s <- object$h2_threeway
if (is.null(s)) {
return(NULL)
@@ -81,7 +81,8 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort =
normalize = normalize,
squared = squared,
sort = sort,
top_m = top_m,
top_m = top_m,
zero = zero,
eps = eps
)
if (plot) plot_stat(out, fill = fill, ...) else out
@@ -94,26 +95,27 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort =
#'
#' @noRd
#' @keywords internal
#' @param x A list containing the elements "combs3", "K", "pred_names",
#' @param x A list containing the elements "combs3", "v_threeway_0", "K", "pred_names",
#' "F_jkl", "F_jk", "F_j", and "w".
#' @returns A list with the numerator and denominator statistics.
h2_threeway_raw <- function(x) {
combs <- x[["combs3"]]

# Note that the F_jkl are in the same order as combs
num <- denom <- with(
x, matrix(nrow = length(combs), ncol = K, dimnames = list(names(combs), pred_names))
)
num <- init_numerator(x, way = 3L)
denom <- num + 1

for (i in seq_along(combs)) {
z <- combs[[i]]
zz <- sapply(utils::combn(z, 2L, simplify = FALSE), paste, collapse = ":")

num[i, ] <- with(
x, wcolMeans((F_jkl[[i]] - Reduce("+", F_jk[zz]) + Reduce("+", F_j[z]))^2, w = w)
)
denom[i, ] <- with(x, wcolMeans(F_jkl[[i]]^2, w = w))
# Note that the F_jkl are in the same order as x[["combs3"]]
combs <- x[["combs3"]]
if (!is.null(combs)) {
for (nm in names(combs)) {
z <- combs[[nm]]
zz <- utils::combn(z, 2L, paste, collapse = ":")

num[nm, ] <- with(
x,
wcolMeans((F_jkl[[nm]] - Reduce("+", F_jk[zz]) + Reduce("+", F_j[z]))^2, w = w)
)
denom[nm, ] <- with(x, wcolMeans(F_jkl[[nm]]^2, w = w))
}
}

list(num = num, denom = denom)
}
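
The rewritten `zz` line relies on `utils::combn()` accepting a function that is applied to each combination, which avoids building an intermediate list. A small base-R check with made-up feature names suggests that both forms produce the same pair labels:

```r
z <- c("x1", "x2", "x3")  # hypothetical feature names

# Old form: list of pairs, then paste each pair.
zz_old <- sapply(utils::combn(z, 2L, simplify = FALSE), paste, collapse = ":")

# New form: combn() applies paste() to each pair directly.
zz_new <- utils::combn(z, 2L, paste, collapse = ":")

all(zz_old == zz_new)  # both yield the labels "x1:x2", "x1:x3", "x2:x3"
```

These labels are then used to index `F_jk` by name, which matches the switch from positional to name-based indexing elsewhere in this commit.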