Skip to content

Commit

Permalink
Delegate OHE of response to losses
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Jul 6, 2024
1 parent 7960188 commit f003dc3
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 86 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
- Factor-valued predictions are no longer possible.
- Consequently, also removed "classification_error" loss.

## Minor changes

- Code simplifications.

# hstats 1.1.2

## ICE plots
Expand Down
18 changes: 11 additions & 7 deletions R/losses.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ check_loss <- function(actual, predicted) {
stopifnot(
is.vector(actual) || is.matrix(actual),
is.vector(predicted) || is.matrix(predicted),
is.numeric(actual),
is.numeric(predicted),
is.numeric(actual) || is.logical(actual),
is.numeric(predicted) || is.logical(predicted),
NROW(actual) == NROW(predicted),
NCOL(actual) == 1L || NCOL(actual) == NCOL(predicted)
)
Expand All @@ -27,10 +27,14 @@ check_loss <- function(actual, predicted) {
#' @noRd
#' @keywords internal
#'
#' @param actual A numeric vector or matrix.
#' @param actual A numeric vector or matrix, or a factor with levels in the same order
#' as the column names of `predicted`.
#' @param predicted A numeric vector or matrix.
#' @returns Vector or matrix of numeric losses.
loss_squared_error <- function(actual, predicted) {
if (is.factor(actual)) {
actual <- fdummy(actual)
}
check_loss(actual, predicted)

return((drop(actual) - predicted)^2)
Expand Down Expand Up @@ -147,11 +151,10 @@ loss_mlogloss <- function(actual, predicted) {
is.matrix(actual),
is.matrix(predicted),

is.numeric(actual),
is.numeric(predicted),
is.numeric(actual) || is.logical(actual),
is.numeric(predicted) || is.logical(predicted),

nrow(actual) == nrow(predicted),
ncol(actual) == ncol(predicted),
dim(actual) == dim(predicted),
ncol(predicted) >= 2L,

all(predicted >= 0),
Expand All @@ -176,6 +179,7 @@ loss_mlogloss <- function(actual, predicted) {
xlogy <- function(x, y) {
out <- x * log(y)
out[x == 0] <- 0

return(out)
}

Expand Down
8 changes: 4 additions & 4 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ perm_importance.default <- function(object, X, y, v = NULL,
if (nrow(X) > n_max) {
ix <- sample(nrow(X), n_max)
X <- X[ix, , drop = FALSE]
if (is.vector(y)) {
if (is.vector(y) || is.factor(y)) {
y <- y[ix]
} else {
} else { # matrix case
y <- y[ix, , drop = FALSE]
}
if (!is.null(w)) {
Expand All @@ -126,9 +126,9 @@ perm_importance.default <- function(object, X, y, v = NULL,
if (m_rep > 1L) {
ind <- rep.int(seq_len(n), m_rep)
X <- rep_rows(X, ind)
if (is.vector(y)) {
if (is.vector(y) || is.factor(y)) {
y <- y[ind]
} else {
} else { # matrix case
y <- y[ind, , drop = FALSE]
}
}
Expand Down
9 changes: 3 additions & 6 deletions R/utils_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ prepare_pred <- function(x) {
if (!is.vector(x) && !is.matrix(x)) {
x <- as.matrix(x)
}
if (!is.numeric(x)) {
if (!is.numeric(x) && !is.logical(x)) {
stop("Predictions must be numeric!")
}
return(x)
Expand Down Expand Up @@ -116,13 +116,10 @@ prepare_y <- function(y, X) {
stopifnot(NROW(y) == nrow(X))
y_names <- NULL
}
if (is.factor(y)) {
y <- fdummy(y)
}
if (!is.vector(y) && !is.matrix(y)) {
if (!is.vector(y) && !is.matrix(y) && !is.factor(y)) {
y <- as.matrix(y)
}
if (!is.numeric(y)) {
if (!is.numeric(y) && !is.logical(y) && !is.factor(y)) {
stop("Response must be numeric (or factor.)")
}
list(y = y, y_names = y_names)
Expand Down
80 changes: 42 additions & 38 deletions backlog/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,11 @@ bench::mark(
check = FALSE,
min_iterations = 3
)

# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# 1 iml 1.72s 1.75s 0.574 210.6MB 1.34 3 7 5.23s <NULL>
# 2 dalex 744.82ms 760.02ms 1.31 35.2MB 0.877 3 2 2.28s <NULL>
# 3 flashlight 1.29s 1.35s 0.742 63MB 0.990 3 4 4.04s <NULL>
# 4 hstats 407.26ms 412.31ms 2.43 26.5MB 0 3 0 1.23s <NULL>
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result
# 1 iml 1.76s 1.76s 0.565 211.6MB 3.39 3 18 5.31s <NULL>
# 2 dalex 688.54ms 697.71ms 1.44 35.2MB 1.91 3 4 2.09s <NULL>
# 3 flashlight 667.51ms 676.07ms 1.47 28.1MB 1.96 3 4 2.04s <NULL>
# 4 hstats 392.15ms 414.41ms 2.39 26.6MB 0.796 3 1 1.26s <NULL>

# Partial dependence (cont)
v <- "tot_lvg_area"
Expand All @@ -132,12 +131,12 @@ bench::mark(
check = FALSE,
min_iterations = 3
)
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# 1 iml 1.14s 1.16s 0.861 376.7MB 3.73 3 13 3.48s <NULL>
# 2 dalex 653.24ms 654.51ms 1.35 192.8MB 2.24 3 5 2.23s <NULL>
# 3 flashlight 352.34ms 361.79ms 2.72 66.7MB 0.906 3 1 1.1s <NULL>
# 4 hstats 239.03ms 242.79ms 4.04 14.2MB 1.35 3 1 743.43ms <NULL>

# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result
# <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm> <list>
# 1 iml 1.2s 1.4s 0.726 376.9MB 4.12 3 17 4.13s <NULL>
# 2 dalex 759.3ms 760.6ms 1.28 192.8MB 2.55 3 6 2.35s <NULL>
# 3 flashlight 369.1ms 403.1ms 2.55 66.8MB 2.55 3 3 1.18s <NULL>
# 4 hstats 242.1ms 243.8ms 4.03 14.2MB 0 3 0 744.25ms <NULL>#
# Partial dependence (discrete)
v <- "structure_quality"
bench::mark(
Expand All @@ -148,30 +147,31 @@ bench::mark(
check = FALSE,
min_iterations = 3
)
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# 1 iml 100.6ms 103.6ms 9.46 13.34MB 0 5 0 529ms <NULL>
# 2 dalex 172.4ms 177.9ms 5.62 20.55MB 2.81 2 1 356ms <NULL>
# 3 flashlight 43.5ms 45.5ms 21.9 6.36MB 2.19 10 1 457ms <NULL>
# 4 hstats 25.3ms 25.8ms 37.9 1.54MB 2.10 18 1 475ms <NULL>

# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result
# <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm> <list>
# 1 iml 107.9ms 108ms 9.26 13.64MB 9.26 2 2 216ms <NULL>
# 2 dalex 172ms 172.2ms 5.81 21.14MB 2.90 2 1 344ms <NULL>
# 3 flashlight 40.3ms 41.6ms 23.8 8.61MB 2.16 11 1 462ms <NULL>
# 4 hstats 24.5ms 25.9ms 35.5 1.64MB 0 18 0 507ms <NULL>

# H-Stats -> we use a subset of 500 rows
X_v500 <- X_valid[1:500, ]
mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf)
fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))

# iml # 225s total, using slow exact calculations
system.time( # 90s
# iml # 243s total, using slow exact calculations
system.time( # 110s
iml_overall <- Interaction$new(mod500, grid.size = 500)
)
system.time( # 135s for all combinations of latitude
system.time( # 133s for all combinations of latitude
iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude")
)

# flashlight: 13s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 11.5s
# flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 11.7s
fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
)
system.time( # 2.4s
system.time( # 2.3s
fl_pairwise <- light_interaction(
fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
)
Expand All @@ -185,34 +185,38 @@ system.time({
}
)

# Using 50 quantiles to approximate dense numerics: 0.9s
# Using 50 quantiles to approximate dense numerics: 0.8s
system.time(
H_approx <- hstats(fit, v = x, X = X_v500, n_max = Inf, approx = TRUE)
)

# Overall statistics correspond exactly
iml_overall$results |> filter(.interaction > 1e-6)
iml_overall$results |>
filter(.interaction > 1e-6)
# .feature .interaction
# 1: latitude 0.2791144
# 2: longitude 0.2791144
# 1: latitude 0.2458269
# 2: longitude 0.2458269

fl_overall$data |> subset(value > 0, select = c(variable, value))
# variable value
# 1 latitude 0.279
# 2 longitude 0.279
fl_overall$data |>
subset(value_ > 0, select = c(variable_, value_))
# variable_ value_
# 3 latitude 0.2458269
# 4 longitude 0.2458269

hstats_overall
# longitude latitude
# 0.2791144 0.2791144
# 0.2458269 0.2458269

# Pairwise results match as well
iml_pairwise$results |> filter(.interaction > 1e-6)
iml_pairwise$results |>
filter(.interaction > 1e-6)
# .feature .interaction
# 1: longitude:latitude 0.4339574
# 1: longitude:latitude 0.3942526

fl_pairwise$data |> subset(value > 0, select = c(variable, value))
# latitude:longitude 0.434
fl_pairwise$data |>
subset(value_ > 0, select = c(variable_, value_))
# latitude:longitude 0.3942526

hstats_pairwise
# latitude:longitude
# 0.4339574
# 0.3942526
27 changes: 0 additions & 27 deletions backlog/colMeans_factors.R

This file was deleted.

2 changes: 1 addition & 1 deletion tests/testthat/test_average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ test_that("mlogloss works with either matrix y or factor y", {
})

test_that("loss_mlogloss() is in line with loss_logloss() in binary case", {
y <- (iris$Species == "setosa") * 1
y <- (iris$Species == "setosa")
Y <- cbind(no = 1 - y, yes = y)
fit <- glm(y ~ Sepal.Length, data = iris, family = binomial())
pf <- function(m, X, multi = FALSE) {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ test_that("prepare_by() works", {
test_that("prepare_y() works", {
# "Vector" interface
expect_equal(prepare_y(iris[1:4], X = iris)$y, data.matrix(iris[1:4]))
expect_equal(prepare_y(iris["Species"], X = iris)$y, fdummy(iris$Species))
expect_equal(prepare_y(iris["Species"], X = iris)$y, iris$Species)
expect_equal(prepare_y(iris$Sepal.Width, X = iris)$y, iris$Sepal.Width)
expect_equal(prepare_y(iris["Sepal.Width"], X = iris)$y, iris$Sepal.Width)

Expand All @@ -35,7 +35,7 @@ test_that("prepare_y() works", {
expect_equal(out$y_names, cn)

out <- prepare_y("Species", X = iris)
expect_equal(out$y, fdummy(iris$Species))
expect_equal(out$y, iris$Species)
expect_equal(out$y_names, "Species")

out <- prepare_y("Sepal.Width", X = iris)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ test_that("Single output multiple models works without recycling y", {
})

test_that("loss_mlogloss() is in line with loss_logloss() in binary case", {
y <- (iris$Species == "setosa") * 1
y <- (iris$Species == "setosa")
Y <- cbind(no = 1 - y, yes = y)
fit <- glm(y ~ Sepal.Length, data = iris, family = binomial())
pf <- function(m, X, multi = FALSE) {
Expand Down

0 comments on commit f003dc3

Please sign in to comment.