Skip to content

Commit

Permalink
Merge pull request #68 from mayer79/pd_ice
Browse files Browse the repository at this point in the history
Color scales
  • Loading branch information
mayer79 authored Oct 7, 2023
2 parents 34148e1 + 63ef433 commit 61e2b48
Show file tree
Hide file tree
Showing 41 changed files with 2,243 additions and 2,143 deletions.
10 changes: 6 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@ This release mainly changes the *output*. The numeric results are unchanged.

## Major changes

- Revised plots: The colors have changed and can (also) be controlled via global options. For instance, to change the fill color of all bars, set `options(stats.fill = new value)`. Value labels are more clear, and there are more options. Stacked barplots use viridis.
- "hstats_matrix" object: All statistics functions, e.g., `h2_pairwise()` or `perm_importance()`, return a "hstats_matrix" object. The values are stored in `$M` and can be plotted via `plot()`.
- Revised plots: The colors and color palettes have changed and can (also) be controlled via global options. For instance, to change the fill color of all bars, set `options(hstats.fill = new value)`. Value labels are more clear, and there are more options. Varying color/fill scales now use viridis (inferno). This can be modified on the fly or via `options(hstats.viridis_args = list(...))`.
- "hstats_matrix" object: All statistics functions, e.g., `h2_pairwise()` or `perm_importance()`, now return a "hstats_matrix". The values are stored in `$M` and can be plotted via `plot()`.
- `perm_importance()`: The `perms` argument has been changed to `m_rep`.
- All `print()` and `summary()` methods have been revised.
- `print()` and `summary()` methods have been revised.

## Minor changes

- Statistics: Their argument `top_m` has been moved to the `plot()` method.
- Statistics: The argument `top_m` has been moved to the `plot()` method.
- Statistics: The clipping threshold `eps` of squared numerator statistics has been reduced from `1e-8` to `1e-10`. It is now handled in `hstats()` instead of the statistic functions.
- `H-squared`: The $H^2$ statistic stored in a "hstats" object is now a matrix with one row (it was a vector).
- `pd_importance()`: The "hstats" object now contains pre-calculated PD-based importance values in `$pd_importance`.
- `summary.hstats()` now returns an object of class "hstats_summary" instead of "summary_hstats".
- `average_loss()` is more flexible regarding the group `BY` argument. It can also be a variable *name*. Non-discrete `BY` variables are now automatically binned. Like `partial_dep()`, binning is controlled by the `by_size = 4` argument.
- `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.

# hstats 0.3.0

Expand Down
30 changes: 17 additions & 13 deletions R/average_loss.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Average Loss
#'
#' Calculates the average loss of a model on a given dataset,
#' optionally grouped by a discrete vector. Use `plot()` to get a barplot.
#' optionally grouped by a variable. Use `plot()` to visualize the results.
#'
#' @section Losses:
#'
Expand Down Expand Up @@ -39,22 +39,26 @@
#' For "mlogloss", the response `y` can either be a dummy matrix or a discrete vector.
#' The latter case is handled via `model.matrix(~ as.factor(y) + 0)`.
#' For "classification_error", both predictions and responses can be non-numeric.
#' @param BY Optional grouping vector.
#' @param BY Optional grouping vector or column name.
#' Numeric `BY` variables with more than `by_size` disjoint values will be
#' binned into `by_size` quantile groups of similar size.
#' @param by_size Numeric `BY` variables with more than `by_size` unique values will
#' be binned into quantile groups. Only relevant if `BY` is not `NULL`.
#' @inherit h2_overall return
#' @export
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' average_loss(fit, X = iris, y = iris$Sepal.Length)
#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = iris$Species)
#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = "Sepal.Width")
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' average_loss(fit, X = iris, y = iris[1:2])
#' L <- average_loss(fit, X = iris, y = iris[1:2], loss = "gamma", BY = iris$Species)
#' L <- average_loss(fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species")
#' L
#' plot(L)
#' plot(L, multi_output = "facets")
average_loss <- function(object, ...) {
UseMethod("average_loss")
}
Expand All @@ -63,7 +67,8 @@ average_loss <- function(object, ...) {
#' @export
average_loss.default <- function(object, X, y,
pred_fun = stats::predict,
BY = NULL, loss = "squared_error", w = NULL, ...) {
loss = "squared_error",
BY = NULL, by_size = 4L, w = NULL, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
nrow(X) >= 1L,
Expand All @@ -72,11 +77,7 @@ average_loss.default <- function(object, X, y,
NROW(y) == nrow(X)
)
if (!is.null(BY)) {
stopifnot(
NCOL(BY) == 1L,
is.vector(BY) || is.factor(BY),
length(BY) == nrow(X)
)
BY <- prepare_by(BY = BY, X = X, by_size = by_size)[["BY"]]
}
if (!is.function(loss)) {
loss <- get_loss_fun(loss)
Expand All @@ -102,7 +103,8 @@ average_loss.default <- function(object, X, y,
#' @export
average_loss.ranger <- function(object, X, y,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
BY = NULL, loss = "squared_error", w = NULL, ...) {
loss = "squared_error",
BY = NULL, by_size = 4L, w = NULL, ...) {
average_loss.default(
object = object,
X = X,
Expand All @@ -119,7 +121,8 @@ average_loss.ranger <- function(object, X, y,
#' @export
average_loss.Learner <- function(object, v, X, y,
pred_fun = NULL,
BY = NULL, loss = "squared_error", w = NULL, ...) {
loss = "squared_error",
BY = NULL, by_size = 4L, w = NULL, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
Expand All @@ -141,8 +144,9 @@ average_loss.explainer <- function(object,
X = object[["data"]],
y = object[["y"]],
pred_fun = object[["predict_function"]],
BY = NULL,
loss = "squared_error",
BY = NULL,
by_size = 4L,
w = object[["weights"]],
...) {
average_loss.default(
Expand Down
28 changes: 17 additions & 11 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,12 @@ print.hstats_summary <- function(x, ...) {
plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE,
sort = TRUE, top_m = 15L, zero = TRUE,
fill = getOption("hstats.fill"),
scale_fill_d = getOption("hstats.scale_fill_d"),
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) {
if (is.null(viridis_args)) {
viridis_args <- list()
}

su <- summary(x, normalize = normalize, squared = squared, sort = sort, zero = zero)
su <- su[sapply(su, FUN = function(z) !is.null(z[["M"]]))]

Expand All @@ -407,28 +411,30 @@ plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE,
mat2df(utils::head(su[[nm]]$M, top_m), id = stat_labs[match(nm, stat_names)])
)
dat <- do.call(rbind, dat)
dat <- barplot_reverter(dat)

p <- ggplot2::ggplot(dat, ggplot2::aes(x = value_, y = variable_)) +
ggplot2::ylab(ggplot2::element_blank()) +
ggplot2::xlab(su$h2$description) # Generic enough?

if (length(ok) > 1L) {
p <- p +
ggplot2::facet_wrap(~ id_, ncol = ncol, scales = facet_scales)
}
if (rotate_x) {
p <- p + rotate_x_labs()
}
if (x[["K"]] == 1L) {
p + ggplot2::geom_bar(fill = fill, stat = "identity", ...)
p <- p + ggplot2::geom_bar(fill = fill, stat = "identity", ...)
} else {
p +
p <- p +
ggplot2::geom_bar(
ggplot2::aes(fill = varying_), stat = "identity", position = "dodge", ...
) +
ggplot2::theme(legend.title = ggplot2::element_blank()) +
scale_fill_d
do.call(ggplot2::scale_fill_viridis_d, viridis_args) +
ggplot2::guides(fill = ggplot2::guide_legend(reverse = TRUE))
}
if (length(ok) > 1L) {
p <- p + ggplot2::facet_wrap(~ id_, ncol = ncol, scales = facet_scales)
}
if (rotate_x) {
p <- p + rotate_x_labs()
}
p
}

# Helper functions used only in this script
Expand Down
49 changes: 38 additions & 11 deletions R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#' ic <- ice(fit, v = "Petal.Width", X = iris, BY = iris$Species)
#' plot(ic)
#' plot(ic, center = TRUE)
#' plot(ic, swap_dim = TRUE)
#'
#' # MODEL 3: Gamma GLM -> pass options to predict() via ...
#' fit <- glm(Sepal.Length ~ ., data = iris, family = Gamma(link = log))
Expand All @@ -69,7 +70,7 @@ ice.default <- function(object, v, X, pred_fun = stats::predict,
check_grid(g = grid, v = v, X_is_matrix = is.matrix(X))
}

# Prepare BY
# Prepare BY (could be integrated into prepare_by())
if (!is.null(BY)) {
if (length(BY) <= 2L && all(BY %in% colnames(X))) {
by_names <- BY
Expand Down Expand Up @@ -224,24 +225,30 @@ print.ice <- function(x, n = 3L, ...) {
#' @param x An object of class "ice".
#' @param center Should curves be centered? Default is `FALSE`.
#' @param alpha Transparency passed to `ggplot2::geom_line()`.
#' @param swap_dim Swaps between color groups and facets. Default is `FALSE`.
#' @export
#' @returns An object of class "ggplot".
#' @seealso See [ice()] for examples.
plot.ice <- function(x, center = FALSE, alpha = 0.2,
color = getOption("hstats.color"),
swap_dim = FALSE,
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "fixed",
rotate_x = FALSE, ...) {
v <- x[["v"]]
K <- x[["K"]]
data <- x[["data"]]
pred_names <- x[["pred_names"]]
by_names <- x[["by_names"]]
if (is.null(viridis_args)) {
viridis_args <- list()
}

if (length(v) > 1L) {
stop("Maximal one feature v can be plotted.")
}
if ((K > 1L) + length(by_names) > 2L) {
stop("Two BY variables and multivariate output has no plot method yet.")
stop("Two BY variables of multivariate output is not supported yet.")
}
if (center) {
pos <- trunc((NROW(x[["grid"]]) + 1) / 2)
Expand All @@ -252,22 +259,42 @@ plot.ice <- function(x, center = FALSE, alpha = 0.2,
}
data <- poor_man_stack(data, to_stack = pred_names)

# Distinguish all possible cases
grp <- if (is.null(by_names) && K > 1L) "varying_" else by_names[1L] # can be NULL
wrp <- if (!is.null(by_names) && K > 1L) "varying_"
if (length(by_names) == 2L) {
wrp <- by_names[2L]
}
if (swap_dim) {
tmp <- grp
grp <- wrp
wrp <- tmp
}

if (!is.null(grp) && grp == "varying_") {
data <- transform(data, obs_ = interaction(obs_, varying_))
}

p <- ggplot2::ggplot(data, ggplot2::aes(x = .data[[v]], y = value_, group = obs_)) +
ggplot2::labs(x = v, y = if (center) "Centered ICE" else "ICE")

if (is.null(by_names)) {
if (is.null(grp)) {
p <- p + ggplot2::geom_line(color = color, alpha = alpha, ...)
} else {
p <- p +
ggplot2::geom_line(
ggplot2::aes(color = .data[[by_names[1L]]]), alpha = alpha, ...
) +
ggplot2::labs(color = by_names[1L]) +
ggplot2::geom_line(ggplot2::aes(color = .data[[grp]]), alpha = alpha, ...) +
ggplot2::labs(color = grp) +
do.call(get_color_scale(data[[grp]]), viridis_args) +
ggplot2::guides(color = ggplot2::guide_legend(override.aes = list(alpha = 1)))
if (grp == "varying_") {
p <- p + ggplot2::theme(legend.title = ggplot2::element_blank())
}
}
if (K > 1L || length(by_names) == 2L) { # Only one is possible
wrp <- if (K > 1L) "varying_" else by_names[2L]
if (!is.null(wrp)) {
p <- p + ggplot2::facet_wrap(wrp, scales = facet_scales)
}
if (rotate_x) p + rotate_x_labs() else p
if (rotate_x) {
p <- p + rotate_x_labs()
}
p
}
6 changes: 2 additions & 4 deletions R/onLoad.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
op <- options()
op.hstats <- list(
hstats.fill = "#fca50a",
hstats.color = "#3b528b",
hstats.scale_fill_d = ggplot2::scale_fill_viridis_d(
begin = 0.25, end = 0.85, option = "inferno"
)
hstats.color = "#420a68",
hstats.viridis_args = list(begin = 0.2, end = 0.8, option = "B")
)
toset <- !(names(op.hstats) %in% names(op))
if (any(toset)) {
Expand Down
Loading

0 comments on commit 61e2b48

Please sign in to comment.