diff --git a/NEWS.md b/NEWS.md index 274798c..9303e4e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,10 +1,14 @@ # hstats 1.1.1 +## Performance improvements + +- For pure data.frames (no tibbles, data.tables etc.), most functions are significantly faster ([#110](https://github.com/mayer79/hstats/pull/110)). +- Slight speed-up of permutation importance for non-matrix `X` ([#109](https://github.com/mayer79/hstats/pull/109)). + ## Other changes - In multivariate cases, it was possible that normalized H-statistics could equal `0/0 (= NaN)`. Such values are now replaced by 0 ([#107](https://github.com/mayer79/hstats/issues/107)). - Removed an unnecessary special case when calculating column means ([#106](https://github.com/mayer79/hstats/pull/106)). -- Slight speed-up of permutation importance for non-matrix `X` ([#109](https://github.com/mayer79/hstats/pull/109)). # hstats 1.1.0 diff --git a/R/pd_raw.R b/R/pd_raw.R index 5d895e2..3bff4d5 100644 --- a/R/pd_raw.R +++ b/R/pd_raw.R @@ -64,11 +64,11 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict, n_grid <- NROW(grid) # Explode everything to n * n_grid rows - X_pred <- X[rep(seq_len(n), times = n_grid), , drop = FALSE] + X_pred <- rep_rows(X, rep.int(seq_len(n), n_grid)) if (D1) { grid_pred <- rep(grid, each = n) } else { - grid_pred <- grid[rep(seq_len(n_grid), each = n), ] + grid_pred <- rep_rows(grid, rep_each(n_grid, n)) } # Vary v @@ -119,7 +119,7 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict, # Compensate via w if (is.null(w)) { - w <- rep(1.0, times = nrow(X)) + w <- rep.int(1.0, nrow(X)) } if (anyNA(x_not_v)) { # rowsum() warns about NA in group = x_not_v -> integer encode @@ -153,8 +153,8 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict, } out <- list(grid = ugrid) if (NCOL(grid) >= 2L) { # Non-vector case - grid <- apply(grid, MARGIN = 1L, FUN = paste, collapse = "_:_") - ugrid <- apply(ugrid, MARGIN = 1L, FUN = paste, collapse = "_:_") + grid <- 
do.call(paste, c(as.data.frame(grid), sep = "_:_")) + ugrid <- do.call(paste, c(as.data.frame(ugrid), sep = "_:_")) if (anyDuplicated(ugrid)) { stop("String '_:_' found in grid values at unlucky position.") } diff --git a/R/perm_importance.R b/R/perm_importance.R index 4e1243f..51a2d86 100644 --- a/R/perm_importance.R +++ b/R/perm_importance.R @@ -124,17 +124,19 @@ perm_importance.default <- function(object, X, y, v = NULL, # Stack y and X m times if (m_rep > 1L) { - ind <- rep(seq_len(n), times = m_rep) - X <- X[ind, , drop = FALSE] + ind <- rep.int(seq_len(n), m_rep) + X <- rep_rows(X, ind) if (is.vector(y) || is.factor(y)) { y <- y[ind] } else { - y <- y[ind, , drop = FALSE] + y <- rep_rows(y, ind) } } shuffle_perf <- function(z, XX) { - ind <- c(replicate(m_rep, sample(seq_len(n)))) # shuffle within n rows + # Shuffle within n rows (could be slightly sped-up via lapply()) + ind <- c(replicate(m_rep, sample.int(n))) + if (is.matrix(XX) || length(z) > 1L) { XX[, z] <- XX[ind, z] } else { diff --git a/R/utils_calculate.R b/R/utils_calculate.R index 57a686f..49c50f1 100644 --- a/R/utils_calculate.R +++ b/R/utils_calculate.R @@ -1,6 +1,6 @@ #' Fast Index Generation #' -#' For not too small m, much faster than `rep(seq_len(m), each = each)`. +#' For not too small m much faster than `rep(seq_len(m), each = each)`. #' #' @noRd #' @keywords internal @@ -16,6 +16,14 @@ rep_each <- function(m, each) { dim(out) <- NULL out } +# +# # Same as rep.int(seq_len(m), times) +# rep_times <- function(m, times) { +# out <- .row(dim = c(m, times)) +# dim(out) <- NULL +# out +# } + #' Fast OHE #' @@ -263,3 +271,26 @@ wcenter <- function(x, w = NULL) { # sweep(x, MARGIN = 2L, STATS = wcolMeans(x, w = w)) # Slower x - matrix(wcolMeans(x, w = w), nrow = nrow(x), ncol = ncol(x), byrow = TRUE) } + +#' Fast Row Subsetting (from kernelshap) +#' +#' Internal function used to row-subset data.frames. +#' Brings a massive speed-up for data.frames. 
All other classes (tibble, data.table, +#' matrix) are subset in the usual way via `x[i, , drop = FALSE]`. +#' +#' @noRd +#' @keywords internal +#' +#' @param x A matrix-like object. +#' @param i Integer vector of row positions to pick (repeats allowed). NOTE(review): the data.frame branch takes the row count from `length(i)`, so logical or negative `i` looks unsupported there; confirm against callers. +#' @returns Row-subsetted version of `x`; for plain data.frames the row names are reset to automatic ones for `length(i)` rows. +rep_rows <- function(x, i) { + if (!(all(class(x) == "data.frame"))) { + return(x[i, , drop = FALSE]) # matrix, tibble, data.table, ... + } + # Plain data.frame: subset each column separately (2-d columns by row), then rebuild the data.frame + out <- lapply(x, function(z) if (length(dim(z)) != 2L) z[i] else z[i, , drop = FALSE]) + attr(out, "row.names") <- .set_row_names(length(i)) + class(out) <- "data.frame" + out +} diff --git a/README.md b/README.md index eae0da1..6321649 100644 --- a/README.md +++ b/README.md @@ -218,7 +218,7 @@ perm_importance(ex) # Permutation importance # Petal.Length Petal.Width Sepal.Width Species -# 0.59836442 0.11625137 0.08246635 0.03982554 +# 0.59836442 0.11625137 0.07966910 0.03982554 ``` ![](man/figures/dalex_hstats.svg) diff --git a/tests/testthat/test_calculate.R b/tests/testthat/test_calculate.R index 7c039c2..b3278ee 100644 --- a/tests/testthat/test_calculate.R +++ b/tests/testthat/test_calculate.R @@ -4,6 +4,26 @@ test_that("rep_each() works", { expect_true(is.integer(rep_each(100, 100))) }) +# The next two checks copied from {kernelshap} +test_that("rep_rows() gives the same as usual subsetting (except rownames)", { + setrn <- function(x) {rownames(x) <- 1:nrow(x); x} + + expect_equal(rep_rows(iris, 1), iris[1, ]) + expect_equal(rep_rows(iris, 2:1), setrn(iris[2:1, ])) + expect_equal(rep_rows(iris, c(1, 1, 1)), setrn(iris[c(1, 1, 1), ])) + + ir <- iris[1, ] + ir$y <- list(list(a = 1, b = 2)) + expect_equal(rep_rows(ir, c(1, 1)), setrn(ir[c(1, 1), ])) +}) + +test_that("rep_rows() gives the same as usual subsetting for matrices", { + ir <- data.matrix(iris[1:4]) + + expect_equal(rep_rows(ir, c(1, 1, 2)), ir[c(1, 1, 2), ]) + expect_equal(rep_rows(ir, 1), ir[1, , drop = FALSE]) +}) + test_that("fdummy() works", { x <- c("A", "A", "C", "D") mm <- 
matrix(model.matrix(~ x + 0), ncol = 3, dimnames = list(NULL, c("A", "C", "D")))