Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split quant_approx into two arguments #83

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Major changes

- `hstats()` has received an argument `quant_approx` to speed-up calculations by quantile binning. Dense numeric variables are replaced by midpoints of `quant_approx + 1` uniform quantiles. By default, the value is `NULL` (no approximation). Even relatively high values like 50 will bring a massive speed-up for dense features, mainly for the one-way calculations. Use this option when calculations are slow, or when you want to increase `n_max`.
- Quantile approximation: `hstats()` now has the option `approx = FALSE`. Set to `TRUE` to replace values of dense numeric columns by `grid_size = 50` quantile midpoints. This will bring a massive speed-up for one-way calculations. Use this option when one-way calculations are slow, or when you want to increase `n_max`.
- `hstats()`: `n_max` has been increased from 300 to 500 rows. This will make estimates of H statistics more stable at the price of longer run time. Reduce to 300 for the old behaviour.
- `hstats()`: Three-way interactions are not anymore calculated by default. Set `threeway_m` to 5 for the old behaviour.
- Revised plots: The colors and color palettes have changed and can now also be controlled via global options. For instance, to change the fill color of all bars, set `options(hstats.fill = new value)`. Value labels are more clear, and there are more options. Varying color/fill scales now use viridis (inferno). This can be modified on the fly or via `options(hstats.viridis_args = list(...))`.
Expand Down
97 changes: 31 additions & 66 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@
#' @param threeway_m Like `pairwise_m`, but controls the feature count for
#' three-way interactions. Cannot be larger than `pairwise_m`.
#' To save computation time, the default is 0.
#' @param quant_approx Integer. Dense numeric variables in `X` are replaced by midpoints
#' of `quant_approx + 1` uniform quantiles. By default, the value is `NULL`
#' (no approximation). Even relatively high values like 50 will bring a massive
#' speed-up for dense features, mainly for one-way statistics.
#' Note that the quantiles are calculated after subsampling to `n_max` rows.
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param approx Should quantile approximation be applied to dense numeric features?
#' The default is `FALSE`. Setting this option to `TRUE` brings a massive speed-up
#' for one-way calculations. It can, e.g., be used when the number of features is
#' very large.
#' @param grid_size Integer controlling the number of quantile midpoints used to
#' approximate dense numerics. The quantile midpoints are calculated after
#' subampling via `n_max`. Only relevant if `approx = TRUE`.
#' @param n_max If `X` has more than `n_max` rows, a random sample of `n_max` rows is
#' selected from `X`. In this case, set a random seed for reproducibility.
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param w Optional vector of case weights. Can also be a column name of `X`.
#' @param verbose Should a progress bar be shown? The default is `TRUE`.
#' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
Expand Down Expand Up @@ -141,8 +143,9 @@ hstats <- function(object, ...) {
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict,
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
Expand Down Expand Up @@ -180,8 +183,8 @@ hstats.default <- function(object, X, v = NULL,
}

# Quantile approximation to speedup things for dense features
if (!is.null(quant_approx)) {
X <- approx_matrix_or_df(X = X, v = v, m = quant_approx)
if (isTRUE(approx)) {
X <- approx_matrix_or_df(X = X, v = v, m = grid_size)
}

# Predictions ("F" in Friedman and Popescu) always calculated (cheap)
Expand Down Expand Up @@ -277,18 +280,20 @@ hstats.default <- function(object, X, v = NULL,
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
hstats.default(
object = object,
X = X,
v = v,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
Expand All @@ -300,8 +305,9 @@ hstats.ranger <- function(object, X, v = NULL,
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
Expand All @@ -312,9 +318,10 @@ hstats.Learner <- function(object, X, v = NULL,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
Expand All @@ -327,19 +334,20 @@ hstats.explainer <- function(object, X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = object[["weights"]],
verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = object[["weights"]], verbose = TRUE, ...) {
hstats.default(
object = object[["model"]],
X = X,
v = v,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
Expand Down Expand Up @@ -548,46 +556,3 @@ get_v <- function(H, m) {
}
v[v %in% v_cand]
}

#' Approximate Vector
#'
#' Internal function. Approximates values by the average of the two closest quantiles.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A vector or factor.
#' @param m Number of unique values.
#' @returns An approximation of `x` (or `x` if non-numeric or discrete).
approx_vector <- function(x, m = 25L) {
if (!is.numeric(x) || length(unique(x)) <= m) {
return(x)
}
p <- seq(0, 1, length.out = m + 1L)
q <- unique(stats::quantile(x, probs = p, names = FALSE, na.rm = TRUE))
mids <- (q[-length(q)] + q[-1L]) / 2
return(mids[findInterval(x, q, rightmost.closed = TRUE)])
}

#' Approximate df or Matrix
#'
#' Internal function. Calls `approx_vector()` to each column in matrix or data.frame.
#'
#' @noRd
#' @keywords internal
#'
#' @param X A matrix or data.frame.
#' @param m Number of unique values.
#' @returns An approximation of `X` (or `X` if non-numeric or discrete).
approx_matrix_or_df <- function(X, v = colnames(X), m = 25L) {
stopifnot(
m >= 2L,
is.data.frame(X) || is.matrix(X)
)
if (is.data.frame(X)) {
X[v] <- lapply(X[v], FUN = approx_vector, m = m)
} else { # Matrix
X[, v] <- apply(X[, v, drop = FALSE], MARGIN = 2L, FUN = approx_vector, m = m)
}
return(X)
}
2 changes: 1 addition & 1 deletion R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
#' A partial dependence plot (PDP) plots the values of \eqn{\hat F_s(\mathbf{x}_s)}
#' over a grid of evaluation points \eqn{\mathbf{x}_s}.
#'
#' @inheritParams hstats
#' @inheritParams multivariate_grid
#' @inheritParams hstats
#' @param v One or more column names over which you want to calculate the partial
#' dependence.
#' @param grid Evaluation grid. A vector (if `length(v) == 1L`), or a matrix/data.frame
Expand Down
59 changes: 59 additions & 0 deletions R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,62 @@ wcenter <- function(x, w = NULL) {
# sweep(x, MARGIN = 2L, STATS = wcolMeans(x, w = w)) # Slower
x - matrix(wcolMeans(x, w = w), nrow = nrow(x), ncol = ncol(x), byrow = TRUE)
}

#' Bin into Quantiles
#'
#' Internal function. Applies [cut()] to quantile breaks.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A numeric vector.
#' @param m Number of intervals.
#' @returns A factor, representing binned `x`.
qcut <- function(x, m) {
p <- seq(0, 1, length.out = m + 1L)
g <- stats::quantile(x, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
cut(x, breaks = unique(g), include.lowest = TRUE)
}

#' Approximate Vector
#'
#' Internal function. Approximates values by the average of the two closest quantiles.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A vector or factor.
#' @param m Number of unique values.
#' @returns An approximation of `x` (or `x` if non-numeric or discrete).
approx_vector <- function(x, m = 50L) {
if (!is.numeric(x) || length(unique(x)) <= m) {
return(x)
}
p <- seq(0, 1, length.out = m + 1L)
q <- unique(stats::quantile(x, probs = p, names = FALSE, na.rm = TRUE))
mids <- (q[-length(q)] + q[-1L]) / 2
return(mids[findInterval(x, q, rightmost.closed = TRUE)])
}

#' Approximate df or Matrix
#'
#' Internal function. Calls `approx_vector()` to each column in matrix or data.frame.
#'
#' @noRd
#' @keywords internal
#'
#' @param X A matrix or data.frame.
#' @param m Number of unique values.
#' @returns An approximation of `X` (or `X` if non-numeric or discrete).
approx_matrix_or_df <- function(X, v = colnames(X), m = 50L) {
stopifnot(
m >= 2L,
is.data.frame(X) || is.matrix(X)
)
if (is.data.frame(X)) {
X[v] <- lapply(X[v], FUN = approx_vector, m = m)
} else { # Matrix
X[, v] <- apply(X[, v, drop = FALSE], MARGIN = 2L, FUN = approx_vector, m = m)
}
return(X)
}
16 changes: 0 additions & 16 deletions R/utils_input.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
#' Bin into Quantiles
#'
#' Internal function. Applies [cut()] to quantile breaks.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A numeric vector.
#' @param m Number of intervals.
#' @returns A factor, representing binned `x`.
qcut <- function(x, m) {
p <- seq(0, 1, length.out = m + 1L)
g <- stats::quantile(x, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
cut(x, breaks = unique(g), include.lowest = TRUE)
}

#' Prepares Group BY Variable
#'
#' Internal function that prepares a BY variable or BY column name.
Expand Down
35 changes: 21 additions & 14 deletions man/hstats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ build()
# build(binary = TRUE)
install(upgrade = FALSE)


# Run only if package is public(!) and should go to CRAN
if (FALSE) {
check_win_devel()
Expand Down
Loading