Skip to content

Commit

Permalink
Merge pull request #67 from mayer79/colors
Browse files Browse the repository at this point in the history
Colors
  • Loading branch information
mayer79 authored Oct 3, 2023
2 parents 9cc20fe + 313c4ff commit 34148e1
Show file tree
Hide file tree
Showing 23 changed files with 1,000 additions and 918 deletions.
14 changes: 7 additions & 7 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# hstats 0.4.0

This release comes with a cleaner output API. The numeric results are unchanged.
This release mainly changes the *output*. The numeric results are unchanged.

## Major changes

- `h2()`, `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `perm_importance()`, and `pd_importance()` now return an object of type "hstats_matrix" with a `print()` and `plot()` method. The values can be extracted via `$M`.
- Their argument `top_m` has been moved to `plot()`.
- `perm_importance()`: The `perms` argument has been renamed to `m_rep`. Since the output is now of class "hstats_matrix", the resulting importance values are stored as `$M`.
- All `print()`, `summary()`, and `plot()` methods have been revised.
- Revised plots: The colors have changed and can (also) be controlled via global options. For instance, to change the fill color of all bars, set `options(stats.fill = new value)`. Value labels are more clear, and there are more options. Stacked barplots use viridis.
- "hstats_matrix" object: All statistics functions, e.g., `h2_pairwise()` or `perm_importance()`, return a "hstats_matrix" object. The values are stored in `$M` and can be plotted via `plot()`.
- `perm_importance()`: The `perms` argument has been changed to `m_rep`.
- All `print()` and `summary()` methods have been revised.

## Minor changes

- Plotting the result of `perm_importance()` on a multi-output model now produces a stacked barplot. Set `multi_output = "facets"` for the old behaviour.
- Statistics: Their argument `top_m` has been moved to the `plot()` method.
- Statistics: The clipping threshold `eps` of squared numerator statistics has been reduced from `1e-8` to `1e-10`. It is now handled in `hstats()` instead of the statistic functions.
- `H-squared`: The $H^2$ statistic stored in a "hstats" object is now a matrix with one row (it was a vector).
- `eps`: The clipping threshold of squared numerator statistics has been reduced from 1e-8 to 1e-10. It is now handled in `hstats()` instead of the statistic functions.
- `pd_importance()`: The "hstats" object now contains pre-calculated PD-based importance values in `$pd_importance`.
- `summary.hstats()` now returns an object of class "hstats_summary" instead of "summary_hstats".

Expand Down
22 changes: 18 additions & 4 deletions R/average_loss.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Average Loss
#'
#' Calculates the average loss of a model on a given dataset,
#' optionally grouped by a discrete vector.
#' optionally grouped by a discrete vector. Use `plot()` to get a barplot.
#'
#' @section Losses:
#'
Expand Down Expand Up @@ -40,7 +40,7 @@
#' The latter case is handled via `model.matrix(~ as.factor(y) + 0)`.
#' For "classification_error", both predictions and responses can be non-numeric.
#' @param BY Optional grouping vector.
#' @returns A matrix with one row per group and one column per loss dimension.
#' @inherit h2_overall return
#' @export
#' @examples
#' # MODEL 1: Linear regression
Expand All @@ -51,7 +51,10 @@
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' average_loss(fit, X = iris, y = iris[1:2])
#' average_loss(fit, X = iris, y = iris[1:2], loss = "gamma", BY = iris$Species)
#' L <- average_loss(fit, X = iris, y = iris[1:2], loss = "gamma", BY = iris$Species)
#' L
#' plot(L)
#' plot(L, multi_output = "facets")
average_loss <- function(object, ...) {
UseMethod("average_loss")
}
Expand Down Expand Up @@ -81,7 +84,18 @@ average_loss.default <- function(object, X, y,

# Real work
L <- as.matrix(loss(y, pred_fun(object, X, ...)))
gwColMeans(L, g = BY, w = w)
M <- gwColMeans(L, g = BY, w = w)

structure(
list(
M = M,
SE = NULL,
mrep = NULL,
statistic = "average_loss",
description = "Average loss"
),
class = "hstats_matrix"
)
}

#' @describeIn average_loss Method for "ranger" models.
Expand Down
13 changes: 8 additions & 5 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,13 @@ print.hstats_summary <- function(x, ...) {
#' @export
#' @seealso See [hstats()] for examples.
plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE,
sort = TRUE, top_m = 15L, zero = TRUE, fill = "#2b51a1",
sort = TRUE, top_m = 15L, zero = TRUE,
fill = getOption("hstats.fill"),
scale_fill_d = getOption("hstats.scale_fill_d"),
facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) {
su <- summary(x, normalize = normalize, squared = squared, sort = sort, zero = zero)
su <- su[sapply(su, FUN = function(z) !is.null(z[["M"]]))]

# This part could be simplified, especially the "match()"
stat_names <- c("h2_overall", "h2_pairwise", "h2_threeway")[which]
stat_labs <- c("Overall", "Pairwise", "Three-way")[which]
Expand All @@ -410,21 +412,22 @@ plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE,
ggplot2::ylab(ggplot2::element_blank()) +
ggplot2::xlab(su$h2$description) # Generic enough?

if (length(unique(dat[["id_"]])) > 1L) {
if (length(ok) > 1L) {
p <- p +
ggplot2::facet_wrap(~ id_, ncol = ncol, scales = facet_scales)
}
if (rotate_x) {
p <- p + rotate_x_labs()
}
if (length(unique(dat[["varying_"]])) == 1L) {
if (x[["K"]] == 1L) {
p + ggplot2::geom_bar(fill = fill, stat = "identity", ...)
} else {
p +
ggplot2::geom_bar(
ggplot2::aes(fill = varying_), stat = "identity", position = "dodge", ...
) +
ggplot2::theme(legend.title = ggplot2::element_blank())
ggplot2::theme(legend.title = ggplot2::element_blank()) +
scale_fill_d
}
}

Expand Down
7 changes: 5 additions & 2 deletions R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,18 @@ print.ice <- function(x, n = 3L, ...) {
#' Plot method for objects of class "ice".
#'
#' @importFrom ggplot2 .data
#' @inheritParams plot.hstats_matrix
#' @inheritParams plot.partial_dep
#' @param x An object of class "ice".
#' @param center Should curves be centered? Default is `FALSE`.
#' @param alpha Transparency passed to `ggplot2::geom_line()`.
#' @export
#' @returns An object of class "ggplot".
#' @seealso See [ice()] for examples.
plot.ice <- function(x, center = FALSE, alpha = 0.2, rotate_x = FALSE,
color = "#2b51a1", facet_scales = "fixed", ...) {
plot.ice <- function(x, center = FALSE, alpha = 0.2,
color = getOption("hstats.color"),
facet_scales = "fixed",
rotate_x = FALSE, ...) {
v <- x[["v"]]
K <- x[["K"]]
data <- x[["data"]]
Expand Down
15 changes: 15 additions & 0 deletions R/onLoad.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
.onLoad <- function(libname, pkgname) {
op <- options()
op.hstats <- list(
hstats.fill = "#fca50a",
hstats.color = "#3b528b",
hstats.scale_fill_d = ggplot2::scale_fill_viridis_d(
begin = 0.25, end = 0.85, option = "inferno"
)
)
toset <- !(names(op.hstats) %in% names(op))
if (any(toset)) {
options(op.hstats[toset])
}
invisible()
}
9 changes: 7 additions & 2 deletions R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,19 @@ print.partial_dep <- function(x, n = 3L, ...) {
#' @importFrom ggplot2 .data
#' @param x An object of class "partial_dep".
#' @param color Color of lines and points (in case there is no color/fill aesthetic).
#' The default equals the global option `hstats.color = "#3b528b"`.
#' To change the global option, use `options(stats.color = new value)`.
#' @param show_points Logical flag indicating whether to show points (default) or not.
#' @param ... Arguments passed to geometries.
#' @inheritParams plot.hstats_matrix
#' @export
#' @returns An object of class "ggplot".
#' @seealso See [partial_dep()] for examples.
plot.partial_dep <- function(x, rotate_x = FALSE, color = "#2b51a1",
facet_scales = "free_y", show_points = TRUE, ...) {
plot.partial_dep <- function(x,
color = getOption("hstats.color"),
facet_scales = "free_y",
rotate_x = FALSE,
show_points = TRUE, ...) {
v <- x[["v"]]
by_name <- x[["by_name"]]
K <- x[["K"]]
Expand Down
18 changes: 14 additions & 4 deletions R/statistics.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,16 @@ print.hstats_matrix <- function(x, top_m = Inf, ...) {
#'
#' @importFrom ggplot2 .data
#' @param x An object of class "hstats_matrix".
#' @param top_m How many rows should be plotted? (`Inf` for all.)
#' @param top_m How many rows should be plotted? `Inf` for all.
#' @param fill Fill color of ungrouped bars. The default equals the global option
#' `hstats.fill = "#fca50a"`. To change the global option, use
#' `options(stats.fill = new value)`.
#' @param multi_output How should multi-output models be represented?
#' Either as "grouped" barplot (the default) or via "facets".
#' @param fill Color of bars.
#' @param scale_fill_d Discrete fill scale for grouped bars. The default equals the
#' global option `hstats.scale_fill_d`, which equals
#' `scale_fill_viridis_d(begin = 0.25, end = 0.85, option = "inferno")`.
#' To change the global option, use `options(hstats.scale_fill_d = new value)`.
#' @param facet_scales Value passed as `scales` argument to `[ggplot2::facet_wrap()]`.
#' @param ncol Passed to `[ggplot2::facet_wrap()]`.
#' @param rotate_x Should x axis labels be rotated by 45 degrees?
Expand All @@ -146,8 +152,10 @@ print.hstats_matrix <- function(x, top_m = Inf, ...) {
#' @export
#' @returns An object of class "ggplot".
plot.hstats_matrix <- function(x, top_m = 15L,
fill = getOption("hstats.fill"),
multi_output = c("grouped", "facets"),
fill = "#2b51a1", facet_scales = "free",
scale_fill_d = getOption("hstats.scale_fill_d"),
facet_scales = "free",
ncol = 2L, rotate_x = FALSE,
err_type = c("SE", "SD", "No"), ...) {
err_type <- match.arg(err_type)
Expand Down Expand Up @@ -182,7 +190,9 @@ plot.hstats_matrix <- function(x, top_m = 15L,
} else {
p <- p + ggplot2::geom_bar(
ggplot2::aes(fill = varying_), stat = "identity", position = "dodge", ...
) + ggplot2::theme(legend.title = ggplot2::element_blank())
) +
ggplot2::theme(legend.title = ggplot2::element_blank()) +
scale_fill_d
}
if (err_type != "No") {
if (!grouped) {
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,7 @@ s
# setosa versicolor virginica
# 0.001547791 0.064550141 0.049758237

plot(s, normalize = FALSE, squared = FALSE) +
scale_fill_viridis_d(begin = 0.1, end = 0.9)
plot(s, normalize = FALSE, squared = FALSE)

ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |>
plot(center = TRUE) +
Expand Down
19 changes: 16 additions & 3 deletions man/average_loss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 34148e1

Please sign in to comment.