Skip to content

Commit

Permalink
Merge pull request #70 from mayer79/y_as_names
Browse files Browse the repository at this point in the history
Responses can now be passed as column names
  • Loading branch information
mayer79 authored Oct 7, 2023
2 parents 4cc732f + c81b84e commit 6aaac0a
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 44 deletions.
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ This release mainly changes the *output*. The numeric results are unchanged.
- `summary.hstats()` now returns an object of class "hstats_summary" instead of "summary_hstats".
- `average_loss()` is more flexible regarding the group `BY` argument. It can also be a variable *name*. Non-discrete `BY` variables are now automatically binned. Like `partial_dep()`, binning is controlled by the `by_size = 4` argument.
- `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.
- Case weights `w` can now also be passed as column name of `X`.
- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column name of `w` if passed as name).
- Case weights `w` can now also be passed as column name of `X` (to any function).
- `perm_importance()` and `average_loss()`: The response(s) `y` can now also be passed as column name(s) of `X`.
- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column names of `w` and `y` if passed as name).

# hstats 0.3.0

Expand Down
16 changes: 9 additions & 7 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#' vector or matrix of the same length as the input.
#'
#' @inheritParams hstats
#' @param y Vector/matrix of the response corresponding to `X`.
#' @param y Vector/matrix of the response, or the corresponding column names in `X`.
#' @param loss One of "squared_error", "logloss", "mlogloss", "poisson",
#' "gamma", "absolute_error", "classification_error". Alternatively, a loss function
#' can be provided that turns observed and predicted values into a numeric vector or
Expand All @@ -49,14 +49,16 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' average_loss(fit, X = iris, y = iris$Sepal.Length)
#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = iris$Species)
#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = "Sepal.Width")
#' average_loss(fit, X = iris, y = "Sepal.Length")
#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = iris$Sepal.Width)
#' average_loss(fit, X = iris, y = "Sepal.Length", BY = "Sepal.Width")
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' average_loss(fit, X = iris, y = iris[1:2])
#' L <- average_loss(fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species")
#' L <- average_loss(
#' fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species"
#' )
#' L
#' plot(L)
average_loss <- function(object, ...) {
Expand All @@ -72,9 +74,9 @@ average_loss.default <- function(object, X, y,
w = NULL, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun),
NROW(y) == nrow(X)
is.function(pred_fun)
)
y <- prepare_y(y = y, X = X)[["y"]]
if (!is.null(w)) {
w <- prepare_w(w = w, X = X)[["w"]]
}
Expand Down
7 changes: 7 additions & 0 deletions R/losses.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ expand_actual <- function(actual, predicted) {
pp <- NCOL(predicted)
pa <- NCOL(actual)
if (pa == pp) {
if (pa > 1L) {
nmp <- colnames(predicted)
nma <- colnames(actual)
if (!is.null(nmp) && !is.null(nma) && !identical(nmp, nma)) {
stop("Column names of multi-output response must correspond to predictions.")
}
}
return(actual)
}
if (pp > 1L && pa == 1L) {
Expand Down
17 changes: 12 additions & 5 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
#' @inheritSection average_loss Losses
#'
#' @param v Vector of feature names, or named list of feature groups.
#' The default (`NULL`) will use all column names of `X` except the column name
#' of the optional case weight `w` (if specified as name).
#' The default (`NULL`) will use all column names of `X` with the following exception:
#' If `y` or `w` are passed as column names, they are dropped.
#' @param m_rep Number of permutations (default 4).
#' @param agg_cols Should multivariate losses be summed up? Default is `FALSE`.
#' @param normalize Should importance statistics be divided by average loss?
Expand All @@ -30,7 +30,7 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- perm_importance(fit, X = iris[-1], y = iris$Sepal.Length)
#' s <- perm_importance(fit, X = iris, y = "Sepal.Length")
#' s
#' s$M
#' s$SE # Standard errors are available thanks to repeated shuffling
Expand All @@ -39,7 +39,7 @@
#'
#' # Groups of features can be passed as named list
#' v <- list(petal = c("Petal.Length", "Petal.Width"), species = "Species")
#' s <- perm_importance(fit, X = iris, y = iris$Sepal.Length, v = v)
#' s <- perm_importance(fit, X = iris, y = "Sepal.Length", v = v)
#' s
#' plot(s)
#'
Expand All @@ -64,10 +64,14 @@ perm_importance.default <- function(object, X, y, v = NULL,
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun),
NROW(y) == nrow(X),
m_rep >= 1L
)

# Are y column names or a vector/matrix?
y2 <- prepare_y(y = y, X = X)
y <- y2[["y"]]
y_names <- y2[["y_names"]]

# Is w a column name or a vector?
if (!is.null(w)) {
w2 <- prepare_w(w = w, X = X)
Expand All @@ -81,6 +85,9 @@ perm_importance.default <- function(object, X, y, v = NULL,
if (!is.null(w) && !is.null(w_name)) {
v <- setdiff(v, w_name)
}
if (!is.null(y_names)) {
v <- setdiff(v, y_names)
}
} else {
v_c <- unlist(v, use.names = FALSE, recursive = FALSE)
stopifnot(all(v_c %in% colnames(X)))
Expand Down
21 changes: 21 additions & 0 deletions R/utils_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@ prepare_w <- function(w, X) {
list(w = w, w_name = w_name)
}

#' Prepares Response y
#'
#' Internal function that prepares the response `y`.
#'
#' @noRd
#' @keywords internal
#' @param y Vector/matrix-like of the same length as `X`, or column names in `X`.
#' @param X Matrix-like.
#'
#' @returns A list.
prepare_y <- function(y, X) {
if (NROW(y) < nrow(X) && all(y %in% colnames(X))) {
y_names <- y
y <- X[, y]
} else {
stopifnot(NROW(y) == nrow(X))
y_names <- NULL
}
list(y = y, y_names = y_names)
}

#' mlr3 Helper
#'
#' Returns the prediction function of a mlr3 Learner.
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ library(ggplot2)
set.seed(1)

fit <- ranger(Species ~ ., data = iris, probability = TRUE)
average_loss(fit, X = iris, y = iris$Species, loss = "mlogloss") # 0.0521
average_loss(fit, X = iris, y = "Species", loss = "mlogloss") # 0.0521

s <- hstats(fit, X = iris[-5])
s
Expand All @@ -267,7 +267,7 @@ ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |>
plot(center = TRUE) +
ggtitle("Centered ICE plots")

perm_importance(fit, X = iris[-5], y = iris$Species, loss = "mlogloss")
perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")
# Permutation importance
# Petal.Length Petal.Width Sepal.Length Sepal.Width
# 0.50941613 0.49187688 0.05669978 0.00950009
Expand Down Expand Up @@ -306,7 +306,7 @@ s <- hstats(fit, X = iris[-1])
s # 0 -> no interactions
plot(partial_dep(fit, v = "Petal.Width", X = iris))

imp <- perm_importance(fit, X = iris[-1], y = iris$Sepal.Length)
imp <- perm_importance(fit, X = iris, y = "Sepal.Length")
imp
# Permutation importance
# Petal.Length Species Petal.Width Sepal.Width
Expand Down Expand Up @@ -334,7 +334,7 @@ fit <- train(
h2(hstats(fit, X = iris[-1])) # 0

plot(ice(fit, v = "Petal.Width", X = iris), center = TRUE)
plot(perm_importance(fit, X = iris[-1], y = iris$Sepal.Length))
plot(perm_importance(fit, X = iris, y = "Sepal.Length"))
```

### mlr3
Expand All @@ -354,7 +354,7 @@ s <- hstats(fit_rf, X = iris[-5], threeway_m = 0)
plot(s)

# Permutation importance
perm_importance(fit_rf, X = iris[-5], y = iris$Species, loss = "mlogloss") |>
perm_importance(fit_rf, X = iris, y = "Species", loss = "mlogloss") |>
plot()
```

Expand Down
16 changes: 16 additions & 0 deletions backlog/hstats_explainer.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
hstats_explainer <- function(object, X, pred_fun = stats::predict,
y = NULL, loss = "squared_error",
w = NULL, ...) {
structure(
list(
object = object,
X = X,
pred_fun = function(m, x) pred_fun(m, x, ...),
y = y,
loss = loss,
w = w
),
class = "hstats_explainer"
)
}

5 changes: 2 additions & 3 deletions backlog/modeltuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ fit_glm <- model(glm(form, iris, weights = Petal.Width, family = Gamma(link = "l
mm <- c(lm = fit_lm, glm = fit_glm)
predict(mm, head(iris))

average_loss(mm, X = iris, y = iris$Sepal.Length, BY = "Species", w = "Petal.Width") |>
average_loss(mm, X = iris, y = "Sepal.Length", BY = "Species", w = "Petal.Width") |>
plot()
partial_dep(mm, v = "Sepal.Width", X = iris, BY = "Species", w = "Petal.Width") |>
plot(show_points = FALSE)
ice(mm, v = "Sepal.Width", X = iris, BY = "Species") |>
plot(facet_scales = "fixed")

perm_importance(mm, X = iris[-1], y = iris[, 1], w = "Petal.Width") |>
perm_importance(mm, X = iris, y = "Sepal.Length", w = "Petal.Width") |>
plot()

# Interaction statistics (H-statistics)
Expand All @@ -24,4 +24,3 @@ H
plot(H)
h2_pairwise(H, normalize = FALSE, squared = FALSE) |>
plot()

12 changes: 7 additions & 5 deletions man/average_loss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/perm_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 16 additions & 11 deletions tests/testthat/test_average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ y <- iris$Sepal.Length

test_that("average_loss() works ungrouped for regression", {
s <- average_loss(fit, X = iris, y = y)$M
s2 <- average_loss(fit, X = iris, y = "Sepal.Length")$M
expect_equal(drop(s), mean((y - predict(fit, iris))^2))
expect_equal(s, s2)

s <- average_loss(fit, X = iris, y = y, loss = "absolute_error")$M
expect_equal(drop(s), mean(abs(y - predict(fit, iris))))
Expand All @@ -21,24 +23,24 @@ test_that("average_loss() works ungrouped for regression", {

test_that("average_loss() works with groups for regression", {
s <- average_loss(fit, X = iris, y = y, BY = iris$Species)$M
s2 <- average_loss(fit, X = iris, y = y, BY = "Species")$M
s2 <- average_loss(fit, X = iris, y = "Sepal.Length", BY = "Species")$M

xpect <- by((y - predict(fit, iris))^2, FUN = mean, INDICES = iris$Species)
expect_equal(drop(s), c(xpect))
expect_equal(s, s2)

expect_equal(dim(average_loss(fit, X = iris, y = y, BY = "Sepal.Width")$M), c(4L, 1L))
expect_equal(
dim(average_loss(fit, X = iris, y = y, BY = "Sepal.Width", by_size = 2L)$M),
dim(average_loss(fit, X = iris, y = "Sepal.Width", BY = "Sepal.Width", by_size = 2L)$M),
c(2L, 1L)
)
})

test_that("average_loss() works with weights for regression", {
s1 <- average_loss(fit, X = iris, y = y)
s2 <- average_loss(fit, X = iris, y = y, w = rep(2, times = 150))
s2 <- average_loss(fit, X = iris, y = "Sepal.Length", w = rep(2, times = 150))
s3 <- average_loss(fit, X = iris, y = y, w = "Petal.Width")
s4 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width)
s4 <- average_loss(fit, X = iris, y = "Sepal.Length", w = iris$Petal.Width)

expect_equal(s1, s2)
expect_false(identical(s2, s3))
Expand All @@ -51,9 +53,9 @@ test_that("average_loss() works with weights and grouped for regression", {
g <- iris$Species
s1 <- average_loss(fit, X = iris, y = y, BY = g)
s2 <- average_loss(
fit, X = iris, y = y, w = rep(2, times = 150), BY = "Species"
fit, X = iris, y = "Sepal.Length", w = rep(2, times = 150), BY = "Species"
)
s3 <- average_loss(fit, X = iris, y = y, w = "Petal.Width", BY = g)
s3 <- average_loss(fit, X = iris, y = "Sepal.Length", w = "Petal.Width", BY = g)
s4 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width, BY = g)

expect_equal(s1, s2)
Expand All @@ -66,11 +68,14 @@ test_that("average_loss() works with weights and grouped for regression", {
#================================================

y <- as.matrix(iris[1:2])
yy <- colnames(y)
fit <- lm(y ~ Petal.Length + Species, data = iris)

test_that("average_loss() works ungrouped (multi regression)", {
s <- average_loss(fit, X = iris, y = y)$M
expect_equal(drop(s), colMeans((y - predict(fit, iris))^2))
s2 <- average_loss(fit, X = iris, y = yy)$M
expect_equal(s, s2)

s <- average_loss(fit, X = iris, y = y, loss = "absolute_error")$M
expect_equal(drop(s), colMeans(abs(y - predict(fit, iris))))
Expand All @@ -81,7 +86,7 @@ test_that("average_loss() works ungrouped (multi regression)", {
s <- average_loss(fit, X = iris, y = y, loss = "poisson")$M
expect_equal(drop(s), colMeans(poisson()$dev.resid(y, predict(fit, iris), 1)))

s <- average_loss(fit, X = iris, y = y, loss = "gamma")$M
s <- average_loss(fit, X = iris, y = yy, loss = "gamma")$M
expect_equal(drop(s), colMeans(Gamma()$dev.resid(y, predict(fit, iris), 1)))
})

Expand All @@ -93,9 +98,9 @@ test_that("average_loss() works with groups (multi regression)", {

test_that("average_loss() works with weights (multi regression)", {
s1 <- average_loss(fit, X = iris, y = y)
s2 <- average_loss(fit, X = iris, y = y, w = rep(2, times = 150))
s2 <- average_loss(fit, X = iris, y = yy, w = rep(2, times = 150))
s3 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width)
s4 <- average_loss(fit, X = iris, y = y, w = "Petal.Width")
s4 <- average_loss(fit, X = iris, y = yy, w = "Petal.Width")

expect_equal(s1, s2)
expect_false(identical(s2, s3))
Expand All @@ -105,9 +110,9 @@ test_that("average_loss() works with weights (multi regression)", {
test_that("average_loss() works with weights and grouped (multi regression)", {
g <- iris$Species
s1 <- average_loss(fit, X = iris, y = y, BY = g)
s2 <- average_loss(fit, X = iris, y = y, w = rep(2, times = 150), BY = g)
s2 <- average_loss(fit, X = iris, y = yy, w = rep(2, times = 150), BY = g)
s3 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width, BY = g)
s4 <- average_loss(fit, X = iris, y = y, w = "Petal.Width", BY = g)
s4 <- average_loss(fit, X = iris, y = yy, w = "Petal.Width", BY = g)

expect_equal(s1, s2)
expect_false(identical(s2, s3))
Expand Down
Loading

0 comments on commit 6aaac0a

Please sign in to comment.