Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two-dim PDP with geom_point() #77

Merged
merged 1 commit into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ print.partial_dep <- function(x, n = 3L, ...) {
#' To change the global option, use `options(stats.color = new value)`.
#' @param show_points Logical flag indicating whether to show points (default) or not.
#' No effect for 2D PDPs.
#' @param d2_geom The geometry used for 2D PDPs, by default "tile". The other option is
#' "point", which is useful, e.g., when the grid represents spatial points.
#' @param ... Arguments passed to geometries.
#' @inheritParams plot.hstats_matrix
#' @export
Expand All @@ -295,7 +297,8 @@ plot.partial_dep <- function(x,
swap_dim = FALSE,
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "fixed",
rotate_x = FALSE, show_points = TRUE, ...) {
rotate_x = FALSE, show_points = TRUE,
d2_geom = c("tile", "point"), ...) {
v <- x[["v"]]
by_name <- x[["by_name"]]
K <- x[["K"]]
Expand Down Expand Up @@ -347,15 +350,20 @@ plot.partial_dep <- function(x,
}
} else if (length(v) == 2L) {
# Heat maps
d2_geom <- match.arg(d2_geom)
if (K > 1L || !is.null(by_name)) { # Only one is possible
wrp <- if (K > 1L) "varying_" else by_name
}
p <- ggplot2::ggplot(
data, ggplot2::aes(x = .data[[v[1L]]], y = .data[[v[2L]]], fill = value_)
) +
ggplot2::geom_tile(...) +
do.call(ggplot2::scale_fill_viridis_c, viridis_args) +
ggplot2::labs(fill = "PD")
p <- ggplot2::ggplot(data, ggplot2::aes(x = .data[[v[1L]]], y = .data[[v[2L]]]))
if (d2_geom == "tile") {
p <- p + ggplot2::geom_tile(ggplot2::aes(fill = value_), ...) +
do.call(ggplot2::scale_fill_viridis_c, viridis_args) +
ggplot2::labs(fill = "PD")
} else if (d2_geom == "point") {
p <- p + ggplot2::geom_point(ggplot2::aes(color = value_), ...) +
do.call(ggplot2::scale_color_viridis_c, viridis_args) +
ggplot2::labs(color = "PD")
}
}
if (!is.null(wrp)) {
p <- p + ggplot2::facet_wrap(wrp, scales = facet_scales)
Expand Down
2 changes: 1 addition & 1 deletion R/utils_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ get_color_scale <- function(x) {
#'
#' @returns A data.frame with reverted factor levels.
barplot_reverter <- function(df, group = TRUE) {
x <- c("variable_", if (group) "varying_")
x <- c("variable_", if (group) "varying_")
for (z in x) {
f <- df[[z]]
df[[z]] <- factor(f, levels = rev(levels(f)))
Expand Down
201 changes: 201 additions & 0 deletions backlog/benchmark.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
library(hstats)
library(iml)
library(DALEX)
library(ingredients)
library(flashlight)

library(shapviz)
library(xgboost)
library(ggplot2)
library(microbenchmark)

# future::plan(multisession, workers = 1)

colnames(miami) <- tolower(colnames(miami))
miami <- transform(miami, log_price = log(sale_prc))
x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
"structure_quality", "age", "month_sold")
coord <- c("longitude", "latitude")

# Train/valid split
set.seed(1)
ix <- sample(nrow(miami), 0.8 * nrow(miami))
train <- data.frame(miami[ix, ])
valid <- data.frame(miami[-ix, ])
y_train <- train$log_price
y_valid <- valid$log_price
X_train <- data.matrix(train[x])
X_valid <- data.matrix(valid[x])

dtrain <- xgb.DMatrix(X_train, label = y_train)
dvalid <- xgb.DMatrix(X_valid, label = y_valid)

ic <- c(
list(which(x %in% coord) - 1),
as.list(which(!x %in% coord) - 1)
)

# Fit via early stopping
fit <- xgb.train(
params = list(
learning_rate = 0.15,
objective = "reg:squarederror",
max_depth = 5,
interaction_constraints = ic
),
data = dtrain,
watchlist = list(valid = dvalid),
early_stopping_rounds = 20,
nrounds = 1000,
callbacks = list(cb.print.evaluation(period = 100))
)

average_loss(fit, X = X_valid, y = y_valid) # 0.0247 MSE -> 0.157 RMSE
perm_importance(fit, X = X_valid, y = y_valid) |>
plot()
# Or combining some features
v_groups <- list(
coord = c("longitude", "latitude"),
size = c("lnd_sqfoot", "tot_lvg_area"),
condition = c("age", "structure_quality")
)
perm_importance(fit, v = v_groups, X = X_valid, y = y_valid) |>
plot()
H <- hstats(fit, v = x, X = X_valid)
H
plot(H)
plot(H, zero = FALSE)
h2_pairwise(H, zero = FALSE, squared = FALSE, normalize = FALSE)
partial_dep(fit, v = "tot_lvg_area", X = X_valid) |>
plot()
partial_dep(fit, v = "tot_lvg_area", X = X_valid, BY = "structure_quality") |>
plot(show_points = FALSE)
plot(ii <- ice(fit, v = "tot_lvg_area", X = X_valid))
plot(ii, center = TRUE)

# Spatial plots
g <- unique(X_valid[, coord])
pp <- partial_dep(fit, v = coord, X = X_valid, grid = g)
plot(pp, d2_geom = "point", alpha = 0.5, size = 1) +
coord_equal()
partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
plot(pp, d2_geom = "point", alpha = 0.5) +
coord_equal()

#=====================================
# Naive benchmark
#=====================================

# iml
predf <- function(object, newdata) predict(object, data.matrix(newdata[x]))
mod <- Predictor$new(fit, data = as.data.frame(X_valid), y = y_valid,
predict.function = predf)

# DALEX
ex <- DALEX::explain(fit, data = X_valid, y = y_valid)

# flashlight (my slightly old fashioned package)
fl <- flashlight(
model = fit, data = valid, y = "log_price", predict_function = predf, label = "lm"
)

# Permutation importance: 10 repeats over full validation data (~2700 rows)
microbenchmark(
iml = FeatureImp$new(mod, n.repetitions = 10, loss = "mse", compare = "difference"),
dalex = feature_importance(ex, B = 10, type = "difference", n_sample = Inf),
flashlight = light_importance(fl, v = x, n_max = Inf, m_repetitions = 10),
hstats = perm_importance(fit, X = X_valid, y = y_valid, perms = 10),
times = 4
)
#
# Unit: milliseconds
# expr min lq mean median uq max neval cld
# iml 1558.3352 1585.3964 1651.9098 1625.5042 1718.4233 1798.2958 4 a
# dalex 556.1398 573.8428 594.5660 592.1752 615.2893 637.7739 4 b
# flashlight 1207.8085 1238.2424 1347.5105 1340.0633 1456.7787 1502.1071 4 c
# hstats 146.0656 146.9564 151.3652 149.4352 155.7741 160.5249 4 d

# Partial dependence (cont)
v <- "tot_lvg_area"
microbenchmark(
iml = FeatureEffect$new(mod, feature = v, grid.size = 50, method = "pdp"),
dalex = partial_dependence(ex, variables = v, N = Inf, grid_points = 50),
flashlight = light_profile(fl, v = v, pd_n_max = Inf, n_bins = 50),
hstats = partial_dep(fit, v = v, X = X_valid, grid_size = 50, n_max = Inf),
times = 4
)
# Unit: milliseconds
# expr min lq mean median uq max neval cld
# iml 941.7763 968.5576 993.0481 1002.5849 1017.5386 1025.2462 4 a
# dalex 694.8007 740.1619 767.1501 788.6172 794.1384 796.5654 4 b
# flashlight 327.6056 328.7617 330.4069 330.5388 332.0522 332.9445 4 c
# hstats 216.4040 217.0602 217.5606 217.8603 218.0611 218.1179 4 d

# Partial dependence (discrete)
v <- "structure_quality"
microbenchmark(
iml = FeatureEffect$new(mod, feature = v, method = "pdp", grid.points = 1:5),
dalex = partial_dependence(ex, variables = v, N = Inf, variable_type = "categorical", grid_points = 5),
flashlight = light_profile(fl, v = v, pd_n_max = Inf),
hstats = partial_dep(fit, v = v, X = X_valid, n_max = Inf),
times = 4
)

# Unit: milliseconds
# expr min lq mean median uq max neval cld
# iml 90.3690 91.08965 94.18403 92.57250 97.27840 101.2221 4 a
# dalex 174.2517 174.97330 179.43483 175.87115 183.89635 191.7453 4 b
# flashlight 43.9318 45.05070 48.09375 46.64275 51.13680 55.1577 4 c
# hstats 24.5972 24.64975 25.01325 24.94085 25.37675 25.5741 4 d

# H-Stats -> we use a subset of 500 rows
X_v500 <- X_valid[1:500, ]
mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), y = y_valid[1:500],
predict.function = predf)
fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))

# iml # 77 s (no pairwise possible)
system.time(
iml_overall <- Interaction$new(mod500, grid.size = 500)
)

# flashlight: 12s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 10s
fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
)
system.time( # 2s
fl_pairwise <- light_interaction(
fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
)
)

# hstats: 3s total
system.time({
H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
hstats_pairwise <- h2_pairwise(H, squared = FALSE, zero = FALSE)
}
)

# Overall statistics correspond exactly
iml_overall$results |> filter(.interaction > 1e-6)
# .feature .interaction
# 1: latitude 0.2458269
# 2: longitude 0.2458269

fl_overall$data |> subset(value > 0, select = c(variable, value))
# variable value
# 1 latitude 0.246
# 2 longitude 0.246

hstats_overall
# longitude latitude
# 0.2458269 0.2458269

# Pairwise results match as well
fl_pairwise$data |> subset(value > 0, select = c(variable, value))
# latitude:longitude 0.394

hstats_pairwise
# latitude:longitude
# 0.3942526
4 changes: 2 additions & 2 deletions backlog/modeltuner.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
devtools::install_github("mayer79/hstats") # -> will become release 1.0.0
library(hstats)
library(modeltuner)

Expand All @@ -20,8 +21,7 @@ perm_importance(mm, X = iris, y = "Sepal.Length", w = "Petal.Width") |>
plot()

# Interaction statistics (H-statistics)
H <- hstats(mm, X = iris[-1], w = "Petal.Width")
H
(H <- hstats(mm, X = iris[-1], w = "Petal.Width"))
plot(H)
h2_pairwise(H, normalize = FALSE, squared = FALSE) |>
plot(swap_dim = TRUE)
4 changes: 4 additions & 0 deletions man/plot.partial_dep.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 18 additions & 5 deletions tests/testthat/test_partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ test_that("partial_dep() works on matrices and dfs", {

# Some plots
test_that("Plots give 'ggplot' objects", {
fit <- lm(Sepal.Length ~ . + Species * Petal.Length, data = iris)
fit <- lm(Sepal.Length ~ ., data = iris)

# One v, no by, univariate
expect_s3_class(plot(partial_dep(fit, v = "Species", X = iris)), "ggplot")
Expand All @@ -285,20 +285,29 @@ test_that("Plots give 'ggplot' objects", {
expect_s3_class(plot(pd, swap_dim = TRUE), "ggplot")

# Two v, no by, univariate
v <- c("Species", "Petal.Width")
v <- c("Petal.Length", "Petal.Width")
pd <- partial_dep(fit, v = v, X = iris)
expect_s3_class(plot(pd), "ggplot")

# Two v, no by, univariate, prespecified grid
g <- unique(iris[v])
pd <- partial_dep(fit, v = v, X = iris, grid = g)
expect_s3_class(plot(pd, d2_geom = "point"), "ggplot")

# Two v, with by, univariate
pd <- partial_dep(fit, v = v, X = iris, BY = "Sepal.Width")
pd <- partial_dep(fit, v = v, X = iris, BY = "Species")
expect_s3_class(plot(pd), "ggplot")

# Two v, with by, univariate, prespecified grid
pd <- partial_dep(fit, v = v, X = iris, BY = "Species", grid = g)
expect_s3_class(plot(pd, d2_geom = "point"), "ggplot")

# Three v gives error
pd <- partial_dep(fit, v = c(v, "Petal.Length"), X = iris)
pd <- partial_dep(fit, v = c(v, "Species"), X = iris)
expect_error(plot(pd))

# Now multioutput
fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)

# One v, no by, multivariate
pd <- partial_dep(fit, v = "Species", X = iris)
Expand All @@ -314,6 +323,10 @@ test_that("Plots give 'ggplot' objects", {
pd <- partial_dep(fit, v = v, X = iris)
expect_s3_class(plot(pd, rotate_x = TRUE), "ggplot")

# Two v, no by, multivariate, prespecified grid
pd <- partial_dep(fit, v = v, X = iris, grid = g)
expect_s3_class(plot(pd, d2_geom = "point", alpha = 0.5), "ggplot")

# Two v, with by, multivariate gives error
pd <- partial_dep(fit, v = v, X = iris, BY = "Petal.Width")
expect_error(plot(pd))
Expand Down