diff --git a/R/partial_dep.R b/R/partial_dep.R index 0ea5fc3e..a805be6f 100644 --- a/R/partial_dep.R +++ b/R/partial_dep.R @@ -285,6 +285,8 @@ print.partial_dep <- function(x, n = 3L, ...) { #' To change the global option, use `options(stats.color = new value)`. #' @param show_points Logical flag indicating whether to show points (default) or not. #' No effect for 2D PDPs. +#' @param d2_geom The geometry used for 2D PDPs, by default "tile". The other option is +#' "point", which is useful, e.g., when the grid represents spatial points. #' @param ... Arguments passed to geometries. #' @inheritParams plot.hstats_matrix #' @export @@ -295,7 +297,8 @@ plot.partial_dep <- function(x, swap_dim = FALSE, viridis_args = getOption("hstats.viridis_args"), facet_scales = "fixed", - rotate_x = FALSE, show_points = TRUE, ...) { + rotate_x = FALSE, show_points = TRUE, + d2_geom = c("tile", "point"), ...) { v <- x[["v"]] by_name <- x[["by_name"]] K <- x[["K"]] @@ -347,15 +350,20 @@ plot.partial_dep <- function(x, } } else if (length(v) == 2L) { # Heat maps + d2_geom <- match.arg(d2_geom) if (K > 1L || !is.null(by_name)) { # Only one is possible wrp <- if (K > 1L) "varying_" else by_name } - p <- ggplot2::ggplot( - data, ggplot2::aes(x = .data[[v[1L]]], y = .data[[v[2L]]], fill = value_) - ) + - ggplot2::geom_tile(...) + - do.call(ggplot2::scale_fill_viridis_c, viridis_args) + - ggplot2::labs(fill = "PD") + p <- ggplot2::ggplot(data, ggplot2::aes(x = .data[[v[1L]]], y = .data[[v[2L]]])) + if (d2_geom == "tile") { + p <- p + ggplot2::geom_tile(ggplot2::aes(fill = value_), ...) + + do.call(ggplot2::scale_fill_viridis_c, viridis_args) + + ggplot2::labs(fill = "PD") + } else if (d2_geom == "point") { + p <- p + ggplot2::geom_point(ggplot2::aes(color = value_), ...) 
+ + do.call(ggplot2::scale_color_viridis_c, viridis_args) + + ggplot2::labs(color = "PD") + } } if (!is.null(wrp)) { p <- p + ggplot2::facet_wrap(wrp, scales = facet_scales) diff --git a/R/utils_plot.R b/R/utils_plot.R index 42b82f9b..a545f86c 100644 --- a/R/utils_plot.R +++ b/R/utils_plot.R @@ -34,7 +34,7 @@ get_color_scale <- function(x) { #' #' @returns A data.frame with reverted factor levels. barplot_reverter <- function(df, group = TRUE) { - x <- c("variable_", if (group) "varying_") + x <- c("variable_", if (group) "varying_") for (z in x) { f <- df[[z]] df[[z]] <- factor(f, levels = rev(levels(f)))
diff --git a/backlog/benchmark.R b/backlog/benchmark.R
new file mode 100644
index 00000000..5d40f10b
--- /dev/null
+++ b/backlog/benchmark.R
@@ -0,0 +1,207 @@
+# Benchmark of hstats against iml, DALEX and flashlight, using an XGBoost model
+# fitted to the `miami` data. Recorded timings are kept as comments below each run.
+
+library(hstats)
+library(iml)
+library(DALEX)
+library(ingredients)
+library(flashlight)
+
+library(shapviz)
+library(xgboost)
+library(ggplot2)
+library(microbenchmark)
+
+# future::plan(multisession, workers = 1)
+
+colnames(miami) <- tolower(colnames(miami))
+miami <- transform(miami, log_price = log(sale_prc))
+x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
+       "structure_quality", "age", "month_sold")
+coord <- c("longitude", "latitude")
+
+# Train/valid split
+set.seed(1)
+ix <- sample(nrow(miami), 0.8 * nrow(miami))
+train <- data.frame(miami[ix, ])
+valid <- data.frame(miami[-ix, ])
+y_train <- train$log_price
+y_valid <- valid$log_price
+X_train <- data.matrix(train[x])
+X_valid <- data.matrix(valid[x])
+
+dtrain <- xgb.DMatrix(X_train, label = y_train)
+dvalid <- xgb.DMatrix(X_valid, label = y_valid)
+
+# Constraint groups (0-based positions in x): both coordinates together, all others alone
+ic <- c(
+  list(which(x %in% coord) - 1),
+  as.list(which(!x %in% coord) - 1)
+)
+
+# Fit via early stopping
+fit <- xgb.train(
+  params = list(
+    learning_rate = 0.15,
+    objective = "reg:squarederror",
+    max_depth = 5,
+    interaction_constraints = ic
+  ),
+  data = dtrain,
+  watchlist = list(valid = dvalid),
+  early_stopping_rounds = 20,
+  nrounds = 1000,
+  callbacks = list(cb.print.evaluation(period = 100))
+)
+
+average_loss(fit, X = X_valid, y = y_valid) # 0.0247 MSE -> 0.157 RMSE
+perm_importance(fit, X = X_valid, y = y_valid) |>
+  plot()
+# Or combining some features
+v_groups <- list(
+  coord = c("longitude", "latitude"),
+  size = c("lnd_sqfoot", "tot_lvg_area"),
+  condition = c("age", "structure_quality")
+)
+perm_importance(fit, v = v_groups, X = X_valid, y = y_valid) |>
+  plot()
+H <- hstats(fit, v = x, X = X_valid)
+H
+plot(H)
+plot(H, zero = FALSE)
+h2_pairwise(H, zero = FALSE, squared = FALSE, normalize = FALSE)
+partial_dep(fit, v = "tot_lvg_area", X = X_valid) |>
+  plot()
+partial_dep(fit, v = "tot_lvg_area", X = X_valid, BY = "structure_quality") |>
+  plot(show_points = FALSE)
+plot(ii <- ice(fit, v = "tot_lvg_area", X = X_valid))
+plot(ii, center = TRUE)
+
+# Spatial plots
+g <- unique(X_valid[, coord])
+pp <- partial_dep(fit, v = coord, X = X_valid, grid = g)
+plot(pp, d2_geom = "point", alpha = 0.5, size = 1) +
+  coord_equal()
+# Same plot per structure quality (the piped result is already plot()'s first argument)
+partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
+  plot(d2_geom = "point", alpha = 0.5) +
+  coord_equal()
+
+#=====================================
+# Naive benchmark
+#=====================================
+
+# iml
+predf <- function(object, newdata) predict(object, data.matrix(newdata[x]))
+mod <- Predictor$new(fit, data = as.data.frame(X_valid), y = y_valid,
+                     predict.function = predf)
+
+# DALEX
+ex <- DALEX::explain(fit, data = X_valid, y = y_valid)
+
+# flashlight (my slightly old fashioned package)
+fl <- flashlight(
+  model = fit, data = valid, y = "log_price", predict_function = predf, label = "lm"
+)
+
+# Permutation importance: 10 repeats over full validation data (~2700 rows)
+microbenchmark(
+  iml = FeatureImp$new(mod, n.repetitions = 10, loss = "mse", compare = "difference"),
+  dalex = feature_importance(ex, B = 10, type = "difference", n_sample = Inf),
+  flashlight = light_importance(fl, v = x, n_max = Inf, m_repetitions = 10),
+  hstats = perm_importance(fit, X = X_valid, y = y_valid, perms = 10),
+  times = 4
+)
+#
+# Unit: milliseconds
+# expr min lq mean median uq max neval cld
+# iml 1558.3352 1585.3964 1651.9098 1625.5042 1718.4233 1798.2958 4 a
+# dalex 556.1398 573.8428 594.5660 592.1752 615.2893 637.7739 4 b
+# flashlight 1207.8085 1238.2424 1347.5105 1340.0633 1456.7787 1502.1071 4 c
+# hstats 146.0656 146.9564 151.3652 149.4352 155.7741 160.5249 4 d
+
+# Partial dependence (cont)
+v <- "tot_lvg_area"
+microbenchmark(
+  iml = FeatureEffect$new(mod, feature = v, grid.size = 50, method = "pdp"),
+  dalex = partial_dependence(ex, variables = v, N = Inf, grid_points = 50),
+  flashlight = light_profile(fl, v = v, pd_n_max = Inf, n_bins = 50),
+  hstats = partial_dep(fit, v = v, X = X_valid, grid_size = 50, n_max = Inf),
+  times = 4
+)
+# Unit: milliseconds
+# expr min lq mean median uq max neval cld
+# iml 941.7763 968.5576 993.0481 1002.5849 1017.5386 1025.2462 4 a
+# dalex 694.8007 740.1619 767.1501 788.6172 794.1384 796.5654 4 b
+# flashlight 327.6056 328.7617 330.4069 330.5388 332.0522 332.9445 4 c
+# hstats 216.4040 217.0602 217.5606 217.8603 218.0611 218.1179 4 d
+
+# Partial dependence (discrete)
+v <- "structure_quality"
+microbenchmark(
+  iml = FeatureEffect$new(mod, feature = v, method = "pdp", grid.points = 1:5),
+  dalex = partial_dependence(ex, variables = v, N = Inf, variable_type = "categorical", grid_points = 5),
+  flashlight = light_profile(fl, v = v, pd_n_max = Inf),
+  hstats = partial_dep(fit, v = v, X = X_valid, n_max = Inf),
+  times = 4
+)
+
+# Unit: milliseconds
+# expr min lq mean median uq max neval cld
+# iml 90.3690 91.08965 94.18403 92.57250 97.27840 101.2221 4 a
+# dalex 174.2517 174.97330 179.43483 175.87115 183.89635 191.7453 4 b
+# flashlight 43.9318 45.05070 48.09375 46.64275 51.13680 55.1577 4 c
+# hstats 24.5972 24.64975 25.01325 24.94085 25.37675 25.5741 4 d
+
+# H-Stats -> we use a subset of 500 rows
+X_v500 <- X_valid[1:500, ]
+mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), y = y_valid[1:500],
+                        predict.function = predf)
+fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))
+
+# iml # 77 s (no pairwise possible)
+system.time(
+  iml_overall <- Interaction$new(mod500, grid.size = 500)
+)
+
+# flashlight: 12s total, doing only one pairwise calculation, otherwise would take 63s
+system.time( # 10s
+  fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
+)
+system.time( # 2s
+  fl_pairwise <- light_interaction(
+    fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
+  )
+)
+
+# hstats: 3s total
+system.time({
+  H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
+  hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
+  hstats_pairwise <- h2_pairwise(H, squared = FALSE, zero = FALSE)
+}
+)
+
+# Overall statistics correspond exactly
+# subset() rather than filter(): dplyr is not attached, so filter() is stats::filter()
+iml_overall$results |> subset(.interaction > 1e-6)
+# .feature .interaction
+# 1: latitude 0.2458269
+# 2: longitude 0.2458269
+
+fl_overall$data |> subset(value > 0, select = c(variable, value))
+# variable value
+# 1 latitude 0.246
+# 2 longitude 0.246
+
+hstats_overall
+# longitude latitude
+# 0.2458269 0.2458269
+
+# Pairwise results match as well
+fl_pairwise$data |> subset(value > 0, select = c(variable, value))
+# latitude:longitude 0.394
+
+hstats_pairwise
+# latitude:longitude
+# 0.3942526
\ No newline at end of file
diff --git a/backlog/modeltuner.R b/backlog/modeltuner.R index ddd40db0..eb716ba6 100644 --- a/backlog/modeltuner.R +++ b/backlog/modeltuner.R @@ -1,3 +1,4 @@ +devtools::install_github("mayer79/hstats") # -> will become release 1.0.0 library(hstats) library(modeltuner) @@ -20,8 +21,7 @@ perm_importance(mm, X = iris, y = "Sepal.Length", w = "Petal.Width") |> plot() # Interaction statistics (H-statistics) -H <- hstats(mm, X = iris[-1], w = "Petal.Width") -H +(H <- hstats(mm, X = iris[-1], w = "Petal.Width")) plot(H) h2_pairwise(H, normalize = FALSE, squared = FALSE) |> plot(swap_dim = TRUE) diff --git
a/man/plot.partial_dep.Rd b/man/plot.partial_dep.Rd index 888aea0a..94194b4d 100644 --- a/man/plot.partial_dep.Rd +++ b/man/plot.partial_dep.Rd @@ -12,6 +12,7 @@ facet_scales = "fixed", rotate_x = FALSE, show_points = TRUE, + d2_geom = c("tile", "point"), ... ) } @@ -38,6 +39,9 @@ E.g., to switch to a standard viridis scale, you can change the default via \item{show_points}{Logical flag indicating whether to show points (default) or not. No effect for 2D PDPs.} +\item{d2_geom}{The geometry used for 2D PDPs, by default "tile". The other option is +"point", which is useful, e.g., when the grid represents spatial points.} + \item{...}{Arguments passed to geometries.} } \value{ diff --git a/tests/testthat/test_partial_dep.R b/tests/testthat/test_partial_dep.R index 841f3b55..69e4d1fb 100644 --- a/tests/testthat/test_partial_dep.R +++ b/tests/testthat/test_partial_dep.R @@ -274,7 +274,7 @@ test_that("partial_dep() works on matrices and dfs", { # Some plots test_that("Plots give 'ggplot' objects", { - fit <- lm(Sepal.Length ~ . 
+ Species * Petal.Length, data = iris) + fit <- lm(Sepal.Length ~ ., data = iris) # One v, no by, univariate expect_s3_class(plot(partial_dep(fit, v = "Species", X = iris)), "ggplot") @@ -285,20 +285,29 @@ test_that("Plots give 'ggplot' objects", { expect_s3_class(plot(pd, swap_dim = TRUE), "ggplot") # Two v, no by, univariate - v <- c("Species", "Petal.Width") + v <- c("Petal.Length", "Petal.Width") pd <- partial_dep(fit, v = v, X = iris) expect_s3_class(plot(pd), "ggplot") + # Two v, no by, univariate, prespecified grid + g <- unique(iris[v]) + pd <- partial_dep(fit, v = v, X = iris, grid = g) + expect_s3_class(plot(pd, d2_geom = "point"), "ggplot") + # Two v, with by, univariate - pd <- partial_dep(fit, v = v, X = iris, BY = "Sepal.Width") + pd <- partial_dep(fit, v = v, X = iris, BY = "Species") expect_s3_class(plot(pd), "ggplot") + # Two v, with by, univariate, prespecified grid + pd <- partial_dep(fit, v = v, X = iris, BY = "Species", grid = g) + expect_s3_class(plot(pd, d2_geom = "point"), "ggplot") + # Three v gives error - pd <- partial_dep(fit, v = c(v, "Petal.Length"), X = iris) + pd <- partial_dep(fit, v = c(v, "Species"), X = iris) expect_error(plot(pd)) # Now multioutput - fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris) + fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris) # One v, no by, multivariate pd <- partial_dep(fit, v = "Species", X = iris) @@ -314,6 +323,10 @@ test_that("Plots give 'ggplot' objects", { pd <- partial_dep(fit, v = v, X = iris) expect_s3_class(plot(pd, rotate_x = TRUE), "ggplot") + # Two v, no by, multivariate, prespecified grid + pd <- partial_dep(fit, v = v, X = iris, grid = g) + expect_s3_class(plot(pd, d2_geom = "point", alpha = 0.5), "ggplot") + # Two v, with by, multivariate gives error pd <- partial_dep(fit, v = v, X = iris, BY = "Petal.Width") expect_error(plot(pd))