diff --git a/README.md b/README.md index 0d3eee51..187437f6 100644 --- a/README.md +++ b/README.md @@ -277,6 +277,126 @@ perm_importance(fit, X = iris, y = "Species", loss = "mlogloss") ![](man/figures/multivariate_ice.svg) +## XGBoost and LightGBM + +Here, we provide simple examples for working with XGBoost and LightGBM classification. + +Their predict function requires to pass `reshape = TRUE`. + +### LightGBM + +```r +library(hstats) +library(lightgbm) + +set.seed(1) + +ix <- c(1:40, 51:90, 101:140) +X <- data.matrix(iris[, -5]) +y <- as.integer(iris[, 5]) - 1 +X_train <- X[ix, ] +X_valid <- X[-ix, ] +y_train <- y[ix] +y_valid <- y[-ix] + +params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2) +dtrain <- lgb.Dataset(X_train, label = y_train) +dvalid <- lgb.Dataset(X_valid, label = y_valid) + +fit <- lgb.train( + params = params, + data = dtrain, + valids = list(valid = dvalid), + early_stopping_rounds = 20, + nrounds = 1000 +) + +# Check that predictions require reshape = TRUE to be a matrix +predict(fit, head(X_pred, 2), reshape = TRUE) +# [,1] [,2] [,3] +# [1,] 0.9999997 2.918695e-07 2.858720e-14 +# [2,] 0.9999999 1.038470e-07 7.337221e-10 + +# mlogloss: 9.331699e-05 +average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE) + +partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> + plot(show_points = FALSE) + +ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> + plot(swap_dim = TRUE, alpha = 0.05) + +perm_importance( + fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100 +) +# Permutation importance regarding mlogloss +# Petal.Length Petal.Width Sepal.Width Sepal.Length +# 2.61783760 1.00647382 0.08414687 0.01011645 + +# Interaction statistics (H-statistics) +(H <- hstats(fit, X = X_train, reshape = TRUE)) # 0.3010446 0.4167927 0.1623982 +plot(H, normalize = FALSE, squared = FALSE) +``` + +![](man/figures/lightgbm.svg) + +### XGBoost + +```r +library(hstats) +library(xgboost) + +set.seed(1) + +ix <- c(1:40, 51:90, 101:140) +X <- data.matrix(iris[, -5]) +y <- as.integer(iris[, 5]) - 1 +X_train <- X[ix, ] +X_valid <- X[-ix, ] +y_train <- y[ix] +y_valid <- y[-ix] + +params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2) +dtrain <- xgb.DMatrix(X_train, label = y_train) +dvalid <- xgb.DMatrix(X_valid, label = y_valid) + +fit <- xgb.train( + params = params, + data = dtrain, + watchlist = list(valid = dvalid), + early_stopping_rounds = 20, + nrounds = 1000 +) + +# We need to pass reshape = TRUE to get a beautiful matrix +predict(fit, head(X_pred, 2), reshape = TRUE) +# [,1] [,2] [,3] +# [1,] 0.9974016 0.002130089 0.0004682819 +# [2,] 0.9971375 0.002129525 0.0007328897 + +# mlogloss: 0.006689544 +average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE) + +partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> + plot(show_points = FALSE) + +ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> + plot(alpha = 0.05) + +perm_importance( + fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100 +) +# Permutation importance regarding mlogloss +# Petal.Length Petal.Width Sepal.Length Sepal.Width +# 1.731532873 0.276671377 0.009158659 0.005717263 + +# Interaction statistics (H-statistics) +(H <- hstats(fit, X = X_train, reshape = TRUE)) # 0.02714399 0.16067364 0.11606973 +plot(H, normalize = FALSE, squared = FALSE) +``` + +![](man/figures/xgboost.svg) + ## Meta-learning packages Here, we provide some working examples for "tidymodels", "caret", and "mlr3". diff --git a/man/figures/lightgbm.svg b/man/figures/lightgbm.svg new file mode 100644 index 00000000..10483137 --- /dev/null +++ b/man/figures/lightgbm.svg @@ -0,0 +1,181 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +Overall + + + + + + + + + + +Pairwise + + + + + + + +0.00 +0.05 +0.10 +0.15 +0.20 + + + + + + +0.00 +0.05 +0.10 +0.15 +0.20 +0.25 +Sepal.Length:Sepal.Width +Sepal.Width:Petal.Width +Sepal.Length:Petal.Width +Sepal.Width:Petal.Length +Sepal.Length:Petal.Length +Petal.Length:Petal.Width + + + + + + +Sepal.Length +Sepal.Width +Petal.Width +Petal.Length + + + + +H (unnormalized) + + + + + + + +y1 +y2 +y3 + + diff --git a/man/figures/xgboost.svg b/man/figures/xgboost.svg new file mode 100644 index 00000000..aff22978 --- /dev/null +++ b/man/figures/xgboost.svg @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +Overall + + + + + + + + + + +Pairwise + + + + + + +0.00 +0.05 +0.10 +0.15 + + + + +0.00 +0.05 +0.10 +0.15 +Sepal.Length:Sepal.Width +Sepal.Length:Petal.Width +Sepal.Length:Petal.Length +Sepal.Width:Petal.Length +Sepal.Width:Petal.Width +Petal.Length:Petal.Width + + + + + + +Sepal.Length +Sepal.Width +Petal.Width +Petal.Length + + + + +H (unnormalized) + + + + + + + +y1 +y2 +y3 + +