Skip to content

Commit

Permalink
Merge pull request #110 from mayer79/performance
Browse files Browse the repository at this point in the history
Performance improvements
  • Loading branch information
mayer79 authored Dec 1, 2023
2 parents fab07aa + cc86da5 commit ef25e05
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 12 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# hstats 1.1.1

## Performance improvements

- For pure data.frames (no tibbles, data.tables etc.), most functions are significantly faster ([#110](https://github.com/mayer79/hstats/pull/110)).
- Slight speed-up of permutation importance for non-matrix `X` ([#109](https://github.com/mayer79/hstats/pull/109)).

## Other changes

- In multivariate cases, it was possible that normalized H-statistics could equal `0/0 (= NaN)`. Such values are now replaced by 0 ([#107](https://github.com/mayer79/hstats/issues/107)).
- Removed an unnecessary special case when calculating column means ([#106](https://github.com/mayer79/hstats/pull/106)).
- Slight speed-up of permutation importance for non-matrix `X` ([#109](https://github.com/mayer79/hstats/pull/109)).

# hstats 1.1.0

Expand Down
10 changes: 5 additions & 5 deletions R/pd_raw.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict,
n_grid <- NROW(grid)

# Explode everything to n * n_grid rows
X_pred <- X[rep(seq_len(n), times = n_grid), , drop = FALSE]
X_pred <- rep_rows(X, rep.int(seq_len(n), n_grid))
if (D1) {
grid_pred <- rep(grid, each = n)
} else {
grid_pred <- grid[rep(seq_len(n_grid), each = n), ]
grid_pred <- rep_rows(grid, rep_each(n_grid, n))
}

# Vary v
Expand Down Expand Up @@ -119,7 +119,7 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict,

# Compensate via w
if (is.null(w)) {
w <- rep(1.0, times = nrow(X))
w <- rep.int(1.0, nrow(X))
}
if (anyNA(x_not_v)) {
# rowsum() warns about NA in group = x_not_v -> integer encode
Expand Down Expand Up @@ -153,8 +153,8 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict,
}
out <- list(grid = ugrid)
if (NCOL(grid) >= 2L) { # Non-vector case
grid <- apply(grid, MARGIN = 1L, FUN = paste, collapse = "_:_")
ugrid <- apply(ugrid, MARGIN = 1L, FUN = paste, collapse = "_:_")
grid <- do.call(paste, c(as.data.frame(grid), sep = "_:_"))
ugrid <- do.call(paste, c(as.data.frame(ugrid), sep = "_:_"))
if (anyDuplicated(ugrid)) {
stop("String '_:_' found in grid values at unlucky position.")
}
Expand Down
10 changes: 6 additions & 4 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,19 @@ perm_importance.default <- function(object, X, y, v = NULL,

# Stack y and X m times
if (m_rep > 1L) {
ind <- rep(seq_len(n), times = m_rep)
X <- X[ind, , drop = FALSE]
ind <- rep.int(seq_len(n), m_rep)
X <- rep_rows(X, ind)
if (is.vector(y) || is.factor(y)) {
y <- y[ind]
} else {
y <- y[ind, , drop = FALSE]
y <- rep_rows(y, ind)
}
}

shuffle_perf <- function(z, XX) {
ind <- c(replicate(m_rep, sample(seq_len(n)))) # shuffle within n rows
# Shuffle within n rows (could be slightly sped-up via lapply())
ind <- c(replicate(m_rep, sample.int(n)))

if (is.matrix(XX) || length(z) > 1L) {
XX[, z] <- XX[ind, z]
} else {
Expand Down
33 changes: 32 additions & 1 deletion R/utils_calculate.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Fast Index Generation
#'
#' For not too small m, much faster than `rep(seq_len(m), each = each)`.
#' For not too small m much faster than `rep(seq_len(m), each = each)`.
#'
#' @noRd
#' @keywords internal
Expand All @@ -16,6 +16,14 @@ rep_each <- function(m, each) {
dim(out) <- NULL
out
}
#
# # Same as rep.int(seq_len(m), times)
# rep_times <- function(m, times) {
# out <- .row(dim = c(m, times))
# dim(out) <- NULL
# out
# }


#' Fast OHE
#'
Expand Down Expand Up @@ -263,3 +271,26 @@ wcenter <- function(x, w = NULL) {
# sweep(x, MARGIN = 2L, STATS = wcolMeans(x, w = w)) # Slower
x - matrix(wcolMeans(x, w = w), nrow = nrow(x), ncol = ncol(x), byrow = TRUE)
}

#' Fast Row Subsetting (from kernelshap)
#'
#' Internal function used to row-subset data.frames.
#' Brings a massive speed-up for data.frames. All other classes (tibble, data.table,
#' matrix) are subsetted in the usual way.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A matrix-like object.
#' @param i Logical or integer vector of rows to pick.
#' @returns Subsetted version of `x`.
rep_rows <- function(x, i) {
if (!(all(class(x) == "data.frame"))) {
return(x[i, , drop = FALSE]) # matrix, tibble, data.table, ...
}
# data.frame
out <- lapply(x, function(z) if (length(dim(z)) != 2L) z[i] else z[i, , drop = FALSE])
attr(out, "row.names") <- .set_row_names(length(i))
class(out) <- "data.frame"
out
}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ perm_importance(ex)

# Permutation importance
# Petal.Length Petal.Width Sepal.Width Species
# 0.59836442 0.11625137 0.08246635 0.03982554
# 0.59836442 0.11625137 0.07966910 0.03982554
```

![](man/figures/dalex_hstats.svg)
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/test_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@ test_that("rep_each() works", {
expect_true(is.integer(rep_each(100, 100)))
})

# The next two checks copied from {kernelshap}
test_that("rep_rows() gives the same as usual subsetting (except rownames)", {
setrn <- function(x) {rownames(x) <- 1:nrow(x); x}

expect_equal(rep_rows(iris, 1), iris[1, ])
expect_equal(rep_rows(iris, 2:1), setrn(iris[2:1, ]))
expect_equal(rep_rows(iris, c(1, 1, 1)), setrn(iris[c(1, 1, 1), ]))

ir <- iris[1, ]
ir$y <- list(list(a = 1, b = 2))
expect_equal(rep_rows(ir, c(1, 1)), setrn(ir[c(1, 1), ]))
})

test_that("rep_rows() gives the same as usual subsetting for matrices", {
ir <- data.matrix(iris[1:4])

expect_equal(rep_rows(ir, c(1, 1, 2)), ir[c(1, 1, 2), ])
expect_equal(rep_rows(ir, 1), ir[1, , drop = FALSE])
})

test_that("fdummy() works", {
x <- c("A", "A", "C", "D")
mm <- matrix(model.matrix(~ x + 0), ncol = 3, dimnames = list(NULL, c("A", "C", "D")))
Expand Down

0 comments on commit ef25e05

Please sign in to comment.