Skip to content

Commit

Permalink
Merge pull request #74 from mayer79/defaults
Browse files Browse the repository at this point in the history
Change hstats() defaults
  • Loading branch information
mayer79 authored Oct 11, 2023
2 parents c2cac59 + db60d77 commit 6fb1884
Show file tree
Hide file tree
Showing 20 changed files with 961 additions and 1,347 deletions.
5 changes: 2 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# hstats 0.4.0

This release mainly changes the *output*. The numeric results are unchanged.

## Major changes

- `hstats()`: `n_max` has been increased from 300 to 500 rows. This will make estimates of H statistics more stable at the price of longer run time. Reduce to 300 for the old behaviour.
- `hstats()`: By default, three-way interactions are not calculated anymore. Set `threeway_m` to 5 for the old behaviour.
- Revised plots: The colors and color palettes have changed and can (also) be controlled via global options. For instance, to change the fill color of all bars, set `options(hstats.fill = new value)`. Value labels are more clear, and there are more options. Varying color/fill scales now use viridis (inferno). This can be modified on the fly or via `options(hstats.viridis_args = list(...))`.
- "hstats_matrix" object: All statistics functions, e.g., `h2_pairwise()` or `perm_importance()`, now return a "hstats_matrix". The values are stored in `$M` and can be plotted via `plot()`.
- `perm_importance()`: The `perms` argument has been changed to `m_rep`.
Expand All @@ -29,7 +29,6 @@ This is intended to be the last version before 1.0.0.

- Grid of `ice()` and `partial_dep()`: So far, the default grid strategy "uniform" used `pretty()` to generate the evaluation points. To provide more predictable grid sizes, and to be more in line with other implementations of partial dependence and ICE, we now use `seq()` to create the uniform grid.
- `h2_pairwise()` and `h2_threeway()` will now also include 0 values. Use `zero = FALSE` to drop them, see below. The padding with 0 is done at no computational cost, and will affect only up to `pairwise_m` and `threeway_m` features.
- `hstats()`: The default number of features considered for *three-way interactions* has been changed from `threeway_m = pairwise_m` to the more cautious `threeway_m = min(pairwise_m, 5L)`. Furthermore, `threeway_m` is capped at `pairwise_m`.
- The `print()` method of `summary.hstats()` is less verbose.

## Improvements
Expand Down
8 changes: 5 additions & 3 deletions R/H2_threeway.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#' Three-way Interaction Strength
#'
#' Friedman and Popescu's statistic of three-way interaction strength, see Details.
#' Use `plot()` to get a barplot.
#' Use `plot()` to get a barplot. In `hstats()`, set `threeway_m` to a value above 2
#' to calculate this statistic for all feature triples of the `threeway_m`
#' features with strongest overall interaction.
#'
#' Friedman and Popescu (2008) describe a test statistic to measure three-way
#' interactions: in case there are no three-way interactions between features
Expand Down Expand Up @@ -42,12 +44,12 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(uptake ~ Type * Treatment * conc, data = CO2)
#' s <- hstats(fit, X = CO2[2:4], verbose = FALSE)
#' s <- hstats(fit, X = CO2[2:4], threeway_m = 5)
#' h2_threeway(s)
#'
#' #' MODEL 2: Multivariate output (taking just twice the same response as example)
#' fit <- lm(cbind(up = uptake, up2 = 2 * uptake) ~ Type * Treatment * conc, data = CO2)
#' s <- hstats(fit, X = CO2[2:4], verbose = FALSE)
#' s <- hstats(fit, X = CO2[2:4], threeway_m = 5)
#' h2_threeway(s)
#' h2_threeway(s, normalize = FALSE, squared = FALSE) # Unnormalized H
#' plot(h2_threeway(s))
Expand Down
38 changes: 18 additions & 20 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#' - Friedman and Popescu's statistic \eqn{H^2_{jk}} of pairwise interaction strength,
#' see [h2_pairwise()] for details.
#' - Friedman and Popescu's statistic \eqn{H^2_{jkl}} of three-way interaction strength,
#' see [h2_threeway()] for details.
#' see [h2_threeway()] for details. To save time, this statistic is not calculated
#' by default. Set `threeway_m` to a value above 2 to get three-way statistics of the
#' `threeway_m` variables with strongest overall interaction.
#'
#' Furthermore, it allows to calculate an experimental partial dependence based
#' measure of feature importance, \eqn{\textrm{PDI}_j^2}. It equals the proportion of
Expand Down Expand Up @@ -41,7 +43,7 @@
#' strongest variable names is taken. This can lead to very long run-times.
#' @param threeway_m Like `pairwise_m`, but controls the feature count for
#' three-way interactions. Cannot be larger than `pairwise_m`.
#' The default is `min(pairwise_m, 5)`. Set to 0 to avoid three-way calculations.
#' To save computation time, the default is 0.
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param verbose Should a progress bar be shown? The default is `TRUE`.
#' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
Expand Down Expand Up @@ -115,22 +117,22 @@
#' s <- hstats(fit, X = iris[-1], verbose = FALSE)
#' summary(s)
#'
#' # On original scale, we have interactions everywhere...
#' s <- hstats(fit, X = iris[-1], type = "response", verbose = FALSE)
#' plot(s, which = 1:3, ncol = 1) # All three types use different denominators
#' # On original scale, we have interactions everywhere.
#' # To see three-way interactions, we set threeway_m to a value above 2.
#' s <- hstats(fit, X = iris[-1], type = "response", threeway_m = 5)
#' plot(s, ncol = 1) # All three types use different denominators
#'
#' # All statistics on same scale (of predictions)
#' plot(s, which = 1:3, squared = FALSE, normalize = FALSE, facet_scale = "free_y")
#' plot(s, squared = FALSE, normalize = FALSE, facet_scale = "free_y")
hstats <- function(object, ...) {
UseMethod("hstats")
}

#' @describeIn hstats Default hstats method.
#' @export
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict, n_max = 300L,
w = NULL, pairwise_m = 5L,
threeway_m = min(pairwise_m, 5L),
pred_fun = stats::predict, n_max = 500L,
w = NULL, pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
Expand Down Expand Up @@ -260,8 +262,7 @@ hstats.default <- function(object, X, v = NULL,
#' @export
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
n_max = 300L, w = NULL, pairwise_m = 5L,
threeway_m = min(pairwise_m, 5L),
n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
hstats.default(
object = object,
Expand All @@ -282,8 +283,7 @@ hstats.ranger <- function(object, X, v = NULL,
#' @export
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
n_max = 300L, w = NULL, pairwise_m = 5L,
threeway_m = min(pairwise_m, 5L),
n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
Expand All @@ -308,9 +308,8 @@ hstats.Learner <- function(object, X, v = NULL,
hstats.explainer <- function(object, X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
n_max = 300L, w = object[["weights"]],
pairwise_m = 5L,
threeway_m = min(pairwise_m, 5L),
n_max = 500L, w = object[["weights"]],
pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
hstats.default(
object = object[["model"]],
Expand Down Expand Up @@ -401,15 +400,14 @@ print.hstats_summary <- function(x, ...) {
#' Plot method for object of class "hstats".
#'
#' @param x Object of class "hstats".
#' @param which Which statistic(s) to be shown? Default is `1:2`, i.e., show both
#' \eqn{H^2_j} (1) and \eqn{H^2_{jk}} (2). To also show three-way interactions,
#' use `1:3`.
#' @param which Which statistic(s) to be shown? Default is `1:3`, i.e.,
#' show \eqn{H^2_j} (1), \eqn{H^2_{jk}} (2), and \eqn{H^2_{jkl}} (3).
#' @inheritParams plot.hstats_matrix
#' @inheritParams h2_overall
#' @returns An object of class "ggplot".
#' @export
#' @seealso See [hstats()] for examples.
plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE,
plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE,
sort = TRUE, top_m = 15L, zero = TRUE,
fill = getOption("hstats.fill"),
viridis_args = getOption("hstats.viridis_args"),
Expand Down
46 changes: 15 additions & 31 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The core functions `hstats()`, `partial_dep()`, `ice()`, `perm_importance()`, an
## Limitations

1. H-statistics are based on partial dependence estimates and are thus as good or bad as these. One of their problems is that the model is applied to unseen/impossible feature combinations. In extreme cases, H-statistics intended to be in the range between 0 and 1 can become larger than 1. Accumulated local effects (ALE) [8] mend above problem of partial dependence estimates. They, however, depend on the notion of "closeness", which is highly non-trivial in higher dimension and for discrete features.
2. Due to their computational complexity of $O(n^2)$, where $n$ is the number of rows considered, H-statistics are usually evaluated on relatively small subsets of the training (or validation/test) data. Consequently, the estimates are typically not very robust. To get more robust results, increase the default `n_max = 300` of `hstats()`.
2. Due to their computational complexity of $O(n^2)$, where $n$ is the number of rows considered, H-statistics are usually evaluated on relatively small subsets of the training (or validation/test) data. Consequently, the estimates are typically not very robust. To get more robust results, increase the default `n_max = 500` of `hstats()`.

## Landscape

Expand Down Expand Up @@ -102,7 +102,7 @@ average_loss(fit, X = X_valid, y = y_valid)
Let's calculate different H-statistics via `hstats()`:

```r
# 3 seconds on simple laptop - a random forest will take 1-2 minutes
# 4 seconds on simple laptop - a random forest will take 2-3 minutes
set.seed(782)
system.time(
s <- hstats(fit, X = X_train)
Expand All @@ -128,24 +128,18 @@ plot(s) # Or summary(s) for numeric output
**Remarks**

1. Pairwise statistics $H^2_{jk}$ are calculated only for the features with strong overall interactions $H^2_j$.
2. H-statistics need to repeatedly calculate predictions on up to $n^2$ rows. That is why {hstats} samples 300 rows by default. To get more robust results, increase this value at the price of slower run time.
2. H-statistics need to repeatedly calculate predictions on up to $n^2$ rows. That is why {hstats} samples 500 rows by default. To get more robust results, increase this value at the price of slower run time.
3. Pairwise statistics $H^2_{jk}$ measures interaction strength relative to the combined effect of the two features. This does not necessarily show which interactions are strongest in absolute numbers. To do so, we can study unnormalized statistics:

```r
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE), top_m = 5)
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))
```

![](man/figures/hstats_pairwise.svg)

Since distance to the ocean and age have high values in overall interaction strength, it is not surprising that a strong relative pairwise interaction is translated into a strong absolute one.

{hstats} crunches three-way interaction statistics $H^2_{jkl}$ as well. The following plot shows them together with the other statistics on prediction scale (`normalize = FALSE` and `squared = FALSE`). The three-way interactions are weaker than the pairwise interactions, yet not negligible:

```r
plot(s, which = 1:3, normalize = F, squared = F, facet_scales = "free_y", ncol = 1)
```

![](man/figures/hstats3.svg)
Note: {hstats} can crunch **three-way** interaction statistics $H^2_{jkl}$ as well. To calculate them for $m$ features with strongest overall interactions, set `threeway_m = m`.

### Describe interaction

Expand Down Expand Up @@ -177,16 +171,6 @@ plot(ic, center = TRUE)

![](man/figures/ice.svg)

The last figure tries to visualize the strongest three-way interaction, without much success though:

```r
BY <- data.frame(X_train[, c("age", "log_ocean")])
BY$log_ocean <- BY$log_ocean < 10
plot(ice(fit, v = "tot_lvg_area", X = X_train, BY = BY), center = TRUE)
```

![](man/figures/pdp3.svg)

### Variable importance

In the spirit of [1], and related to [4], we can extract from the "hstats" objects a partial dependence based variable importance measure. It measures not only the main effect strength (see [4]), but also all its interaction effects. It is rather experimental, so use it with care (details in the section "Background"):
Expand Down Expand Up @@ -281,7 +265,7 @@ perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")

### LightGBM

Note: Versions below 4.0.0 require to pass `reshape = TRUE` to the prediction function.
Note: Versions from 4.0.0 upwards to not anymore require passing `reshape = TRUE` to the prediction function.

```r
library(hstats)
Expand Down Expand Up @@ -310,7 +294,7 @@ fit <- lgb.train(
)

# Check that predictions require reshape = TRUE to be a matrix
predict(fit, head(X_pred, 2), reshape = TRUE)
predict(fit, head(X_train, 2), reshape = TRUE)
# [,1] [,2] [,3]
# [1,] 0.9999997 2.918695e-07 2.858720e-14
# [2,] 0.9999999 1.038470e-07 7.337221e-10
Expand All @@ -331,9 +315,9 @@ perm_importance(
# Petal.Length Petal.Width Sepal.Width Sepal.Length
# 2.61783760 1.00647382 0.08414687 0.01011645

# Interaction statistics (H-statistics)
(H <- hstats(fit, X = X_train, reshape = TRUE)) # 0.3010446 0.4167927 0.1623982
plot(H, normalize = FALSE, squared = FALSE)
# Interaction statistics, including three-way stats
(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4)) # 0.3010446 0.4167927 0.1623982
plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)
```

![](man/figures/lightgbm.svg)
Expand Down Expand Up @@ -369,7 +353,7 @@ fit <- xgb.train(
)

# We need to pass reshape = TRUE to get a beautiful matrix
predict(fit, head(X_pred, 2), reshape = TRUE)
predict(fit, head(X_train, 2), reshape = TRUE)
# [,1] [,2] [,3]
# [1,] 0.9974016 0.002130089 0.0004682819
# [2,] 0.9971375 0.002129525 0.0007328897
Expand All @@ -390,9 +374,9 @@ perm_importance(
# Petal.Length Petal.Width Sepal.Length Sepal.Width
# 1.731532873 0.276671377 0.009158659 0.005717263

# Interaction statistics (H-statistics)
(H <- hstats(fit, X = X_train, reshape = TRUE)) # 0.02714399 0.16067364 0.11606973
plot(H, normalize = FALSE, squared = FALSE)
# Interaction statistics including three-way stats
(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4)) # 0.02714399 0.16067364 0.11606973
plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)
```

![](man/figures/xgboost.svg)
Expand Down Expand Up @@ -621,7 +605,7 @@ In [5], $1 - H^2$ is called *additivity index*. A similar measure using accumula

Calculation of all $H_j^2$ requires $O(n^2 p)$ predictions, while calculating of all pairwise $H_{jk}$ requires $O(n^2 p^2$ predictions. Therefore, we suggest to reduce the workflow in two important ways:

1. Evaluate the statistics only on a subset of the data, e.g., on $n' = 300$ observations.
1. Evaluate the statistics only on a subset of the data, e.g., on $n' = 500$ observations.
2. Calculate $H_j^2$ for all features. Then, select a small number $m = O(\sqrt{p})$ of features with highest $H^2_j$ and do pairwise calculations only on this subset.

This leads to a total number of $O(n'^2 p)$ predictions. If also three-way interactions are to be studied, $m$ should be of the order $p^{1/3}$.
Expand Down
9 changes: 5 additions & 4 deletions backlog/modeltuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ predict(mm, head(iris))

average_loss(mm, X = iris, y = "Sepal.Length", BY = "Species", w = "Petal.Width") |>
plot()
partial_dep(mm, v = "Sepal.Width", X = iris, BY = "Species", w = "Petal.Width") |>
plot(show_points = FALSE)
pd <- partial_dep(mm, v = "Sepal.Width", X = iris, BY = "Species", w = "Petal.Width")
plot(pd, show_points = FALSE)
plot(pd, show_points = FALSE, swap_dim = TRUE)
ice(mm, v = "Sepal.Width", X = iris, BY = "Species") |>
plot(facet_scales = "fixed")
plot(center = TRUE)

perm_importance(mm, X = iris, y = "Sepal.Length", w = "Petal.Width") |>
plot()
Expand All @@ -23,4 +24,4 @@ H <- hstats(mm, X = iris[-1], w = "Petal.Width")
H
plot(H)
h2_pairwise(H, normalize = FALSE, squared = FALSE) |>
plot()
plot(swap_dim = TRUE)
8 changes: 5 additions & 3 deletions man/H2_threeway.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6fb1884

Please sign in to comment.