Merge pull request #72 from mayer79/more_examples
add XGBoost and LightGBM example to README
mayer79 authored Oct 8, 2023
2 parents 0592fa5 + 4428ef3 commit 8e30f88
Showing 3 changed files with 471 additions and 0 deletions.
120 changes: 120 additions & 0 deletions README.md
@@ -277,6 +277,126 @@ perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")

![](man/figures/multivariate_ice.svg)

## XGBoost and LightGBM

Here, we provide simple examples for multiclass classification with XGBoost and LightGBM.

For multiclass models, their `predict()` functions require passing `reshape = TRUE` so that predictions are returned as a matrix (one column per class) rather than a single long vector.
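
Alternatively — a minimal sketch, not something the examples below require — the reshaping can be wrapped into a custom prediction function via the `pred_fun` argument that all hstats functions accept:

```r
# Sketch: instead of forwarding reshape = TRUE through ..., bake it into a
# pred_fun. Assumes a fitted multiclass booster `fit` and a feature matrix
# `X_train` as created in the examples below.
pred_fun <- function(m, X) predict(m, X, reshape = TRUE)

# e.g., hstats(fit, X = X_train, pred_fun = pred_fun)
```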

### LightGBM

```r
library(hstats)
library(lightgbm)

set.seed(1)

ix <- c(1:40, 51:90, 101:140)
X <- data.matrix(iris[, -5])
y <- as.integer(iris[, 5]) - 1
X_train <- X[ix, ]
X_valid <- X[-ix, ]
y_train <- y[ix]
y_valid <- y[-ix]

params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2)
dtrain <- lgb.Dataset(X_train, label = y_train)
dvalid <- lgb.Dataset(X_valid, label = y_valid)

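# Early stopping: training halts once validation mlogloss has not improved for 20 rounds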
fit <- lgb.train(
params = params,
data = dtrain,
valids = list(valid = dvalid),
early_stopping_rounds = 20,
nrounds = 1000
)

# With reshape = TRUE, predictions form an (n x 3) matrix of class probabilities
predict(fit, head(X_valid, 2), reshape = TRUE)
# [,1] [,2] [,3]
# [1,] 0.9999997 2.918695e-07 2.858720e-14
# [2,] 0.9999999 1.038470e-07 7.337221e-10

# mlogloss: 9.331699e-05
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)

partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(swap_dim = TRUE, alpha = 0.05)

perm_importance(
fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
)
# Permutation importance regarding mlogloss
# Petal.Length Petal.Width Sepal.Width Sepal.Length
# 2.61783760 1.00647382 0.08414687 0.01011645

# Interaction statistics (H-statistics)
(H <- hstats(fit, X = X_train, reshape = TRUE)) # 0.3010446 0.4167927 0.1623982
plot(H, normalize = FALSE, squared = FALSE)
```

![](man/figures/lightgbm.svg)

### XGBoost

```r
library(hstats)
library(xgboost)

set.seed(1)

ix <- c(1:40, 51:90, 101:140)
X <- data.matrix(iris[, -5])
y <- as.integer(iris[, 5]) - 1
X_train <- X[ix, ]
X_valid <- X[-ix, ]
y_train <- y[ix]
y_valid <- y[-ix]

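# "multi:softprob" returns one probability per class;
# "multi:softmax" would return only the predicted class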
params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2)
dtrain <- xgb.DMatrix(X_train, label = y_train)
dvalid <- xgb.DMatrix(X_valid, label = y_valid)

fit <- xgb.train(
params = params,
data = dtrain,
watchlist = list(valid = dvalid),
early_stopping_rounds = 20,
nrounds = 1000
)

# We need to pass reshape = TRUE to get a beautiful matrix
predict(fit, head(X_valid, 2), reshape = TRUE)
# [,1] [,2] [,3]
# [1,] 0.9974016 0.002130089 0.0004682819
# [2,] 0.9971375 0.002129525 0.0007328897

# mlogloss: 0.006689544
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)

partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(alpha = 0.05)

perm_importance(
fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
)
# Permutation importance regarding mlogloss
# Petal.Length Petal.Width Sepal.Length Sepal.Width
# 1.731532873 0.276671377 0.009158659 0.005717263

# Interaction statistics (H-statistics)
(H <- hstats(fit, X = X_train, reshape = TRUE)) # 0.02714399 0.16067364 0.11606973
plot(H, normalize = FALSE, squared = FALSE)
```

![](man/figures/xgboost.svg)

## Meta-learning packages

Here, we provide some working examples for "tidymodels", "caret", and "mlr3".
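
Before diving into those, here is a minimal sketch (our assumption of a typical setup, not one of the full examples) showing how a tidymodels regression workflow plugs into hstats via `pred_fun`, which extracts the `.pred` column from tidymodels' tibble predictions:

```r
library(hstats)
library(tidymodels)

set.seed(1)

# A simple linear-regression workflow for Sepal.Length
wf <- workflow() |>
  add_recipe(recipe(Sepal.Length ~ ., data = iris)) |>
  add_model(linear_reg() |> set_engine("lm"))
fit <- wf |> fit(iris)

# tidymodels predictions come as a tibble; pull out the .pred column
pred_fun <- function(m, X) predict(m, X)$.pred

hstats(fit, X = iris[, -1], pred_fun = pred_fun)
```

Since a linear model is additive, all H-statistics are zero here — a handy sanity check.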
