Commit b202ebe

[R] Add evaluation set and early stopping for xgboost() (#11065)
1 parent 6c2d5b3 commit b202ebe

File tree

6 files changed (+331, -28 lines)


R-package/R/xgb.train.R (+16, -7)

@@ -48,13 +48,22 @@
 #' If 2, some additional information will be printed out.
 #' Note that setting `verbose > 0` automatically engages the
 #' `xgb.cb.print.evaluation(period=1)` callback function.
-#' @param print_every_n Print each nth iteration evaluation messages when `verbose>0`.
-#' Default is 1 which means all messages are printed. This parameter is passed to the
-#' [xgb.cb.print.evaluation()] callback.
-#' @param early_stopping_rounds If `NULL`, the early stopping function is not triggered.
-#' If set to an integer `k`, training with a validation set will stop if the performance
-#' doesn't improve for `k` rounds. Setting this parameter engages the [xgb.cb.early.stop()] callback.
-#' @param maximize If `custom_metric` and `early_stopping_rounds` are set, then this parameter must be set as well.
+#' @param print_every_n When passing `verbose>0`, evaluation logs (metrics calculated on the
+#' data passed under `evals`) will be printed every nth iteration according to the value passed
+#' here. The first and last iteration are always included regardless of this 'n'.
+#'
+#' Only has an effect when passing data under `evals` and when passing `verbose>0`. The parameter
+#' is passed to the [xgb.cb.print.evaluation()] callback.
+#' @param early_stopping_rounds Number of boosting rounds after which training will be stopped
+#' if there is no improvement in performance (as measured by the evaluation metric that is
+#' supplied or selected by default for the objective) on the evaluation data passed under
+#' `evals`.
+#'
+#' Must pass `evals` in order to use this functionality. Setting this parameter adds the
+#' [xgb.cb.early.stop()] callback.
+#'
+#' If `NULL`, early stopping will not be used.
+#' @param maximize If `feval` and `early_stopping_rounds` are set, then this parameter must be set as well.
 #' When it is `TRUE`, it means the larger the evaluation score the better.
 #' This parameter is passed to the [xgb.cb.early.stop()] callback.
 #' @param save_period When not `NULL`, model is saved to disk after every `save_period` rounds.
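
For reference, a minimal sketch of how the re-documented parameters fit together in an `xgb.train()` call (the bundled `agaricus` datasets and all parameter values here are illustrative, not taken from the commit):

library(xgboost)
data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")

dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
deval <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)

model <- xgb.train(
  params = list(objective = "binary:logistic", max_depth = 2),
  data = dtrain,
  nrounds = 100,
  evals = list(train = dtrain, eval = deval),  # evaluation data, logged per iteration
  early_stopping_rounds = 5,  # stop after 5 rounds with no improvement on 'eval'
  print_every_n = 10,         # log every 10th iteration, plus first and last
  verbose = 1
)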

R-package/R/xgboost.R (+178, -6)
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ prescreen.parameters <- function(params) {
 
 prescreen.objective <- function(objective) {
   if (!is.null(objective)) {
+    if (!is.character(objective) || length(objective) != 1L || is.na(objective)) {
+      stop("'objective' must be a single character/string variable.")
+    }
+
     if (objective %in% .OBJECTIVES_NON_DEFAULT_MODE()) {
       stop(
         "Objectives with non-default prediction mode (",
@@ -30,8 +34,8 @@ prescreen.objective <- function(objective) {
       )
     }
 
-    if (!is.character(objective) || length(objective) != 1L || is.na(objective)) {
-      stop("'objective' must be a single character/string variable.")
+    if (objective %in% .RANKING_OBJECTIVES()) {
+      stop("Ranking objectives are not supported in 'xgboost()'. Try 'xgb.train()'.")
     }
   }
 }
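
For illustration, the reordering means the type check now runs before the objective-specific guards, and a new guard rejects ranking objectives outright. Hypothetical direct calls to this internal helper (error texts are the ones added in the diff; membership of "rank:pairwise" in `.RANKING_OBJECTIVES()` is assumed):

prescreen.objective(c("reg:squarederror", "binary:logistic"))
# Error: 'objective' must be a single character/string variable.

prescreen.objective("rank:pairwise")
# Error: Ranking objectives are not supported in 'xgboost()'. Try 'xgb.train()'.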
@@ -501,7 +505,7 @@ check.nthreads <- function(nthreads) {
   return(as.integer(nthreads))
 }
 
-check.can.use.qdm <- function(x, params) {
+check.can.use.qdm <- function(x, params, eval_set) {
   if ("booster" %in% names(params)) {
     if (params$booster == "gblinear") {
       return(FALSE)
@@ -512,6 +516,9 @@
       return(FALSE)
     }
   }
+  if (NROW(eval_set)) {
+    return(FALSE)
+  }
   return(TRUE)
 }
 
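The `NROW()` guard lets one branch handle `NULL`, a fraction, and an index vector alike (base-R behavior, shown for reference); presumably QDM is ruled out here because the evaluation split is later carved out with `xgb.slice.DMatrix()` on the constructed matrix:

NROW(NULL)       # 0 -> no eval_set, QDM stays eligible
NROW(0.2)        # 1 -> eval_set given as a fraction, falls back to xgb.DMatrix
NROW(c(5L, 9L))  # 2 -> eval_set given as indices, also falls back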
@@ -717,6 +724,129 @@ process.x.and.col.args <- function(
   return(lst_args)
 }
 
+process.eval.set <- function(eval_set, lst_args) {
+  if (!NROW(eval_set)) {
+    return(NULL)
+  }
+  nrows <- nrow(lst_args$dmatrix_args$data)
+  is_classif <- hasName(lst_args$metadata, "y_levels")
+  processed_y <- lst_args$dmatrix_args$label
+  eval_set <- as.vector(eval_set)
+  if (length(eval_set) == 1L) {
+
+    eval_set <- as.numeric(eval_set)
+    if (is.na(eval_set) || eval_set < 0 || eval_set >= 1) {
+      stop("'eval_set' as a fraction must be a number between zero and one (non-inclusive).")
+    }
+    if (eval_set == 0) {
+      return(NULL)
+    }
+    nrow_eval <- as.integer(round(nrows * eval_set, 0))
+    if (nrow_eval < 1) {
+      warning(
+        "Desired 'eval_set' fraction amounts to zero observations.",
+        " Will not create evaluation set."
+      )
+      return(NULL)
+    }
+    nrow_train <- nrows - nrow_eval
+    if (nrow_train < 2L) {
+      stop("Desired 'eval_set' fraction would leave less than 2 observations for training data.")
+    }
+    if (is_classif && nrow_train < length(lst_args$metadata$y_levels)) {
+      stop("Desired 'eval_set' fraction would not leave enough samples for each class of 'y'.")
+    }
+
+    seed <- lst_args$params$seed
+    if (!is.null(seed)) {
+      set.seed(seed)
+    }
+
+    idx_shuffled <- sample(nrows, nrows, replace = FALSE)
+    idx_eval <- idx_shuffled[seq(1L, nrow_eval)]
+    idx_train <- idx_shuffled[seq(nrow_eval + 1L, nrows)]
+    # Here we want the training set to include all of the classes of 'y' for classification
+    # objectives. If that condition doesn't hold with the random sample, then it forcibly
+    # makes a new random selection in such a way that the condition would always hold, by
+    # first sampling one random example of 'y' for training and then choosing the evaluation
+    # set from the remaining rows. The procedure here is quite inefficient, but there aren't
+    # enough random-related functions in base R to be able to construct an efficient version.
+    if (is_classif && length(unique(processed_y[idx_train])) < length(lst_args$metadata$y_levels)) {
+      # These are defined in order to avoid NOTEs from CRAN checks
+      # when using non-standard data.table evaluation with column names.
+      idx <- NULL
+      y <- NULL
+      ranked_idx <- NULL
+      chosen <- NULL
+
+      dt <- data.table::data.table(y = processed_y, idx = seq(1L, nrows))[
+        , .(
+          ranked_idx = seq(1L, .N),
+          chosen = rep(sample(.N, 1L), .N),
+          idx
+        )
+        , by = y
+      ]
+      min_idx_train <- dt[ranked_idx == chosen, idx]
+      rem_idx <- dt[ranked_idx != chosen, idx]
+      if (length(rem_idx) == nrow_eval) {
+        idx_train <- min_idx_train
+        idx_eval <- rem_idx
+      } else {
+        rem_idx <- rem_idx[sample(length(rem_idx), length(rem_idx), replace = FALSE)]
+        idx_eval <- rem_idx[seq(1L, nrow_eval)]
+        idx_train <- c(min_idx_train, rem_idx[seq(nrow_eval + 1L, length(rem_idx))])
+      }
+    }
+
+  } else {
+
+    if (any(eval_set != floor(eval_set))) {
+      stop("'eval_set' as indices must contain only integers.")
+    }
+    eval_set <- as.integer(eval_set)
+    idx_min <- min(eval_set)
+    if (is.na(idx_min) || idx_min < 1L) {
+      stop("'eval_set' contains invalid indices.")
+    }
+    idx_max <- max(eval_set)
+    if (is.na(idx_max) || idx_max > nrows) {
+      stop("'eval_set' contains row indices beyond the size of the input data.")
+    }
+    idx_train <- seq(1L, nrows)[-eval_set]
+    if (is_classif && length(unique(processed_y[idx_train])) < length(lst_args$metadata$y_levels)) {
+      warning("'eval_set' indices will leave some classes of 'y' outside of the training data.")
+    }
+    idx_eval <- eval_set
+
+  }
+
+  # Note: slicing is done in the constructed DMatrix object instead of in the
+  # original input, because objects from 'Matrix' might change class after
+  # being sliced (e.g. 'dgRMatrix' turns into 'dgCMatrix').
+  return(list(idx_train = idx_train, idx_eval = idx_eval))
+}
+
+check.early.stopping.rounds <- function(early_stopping_rounds, eval_set) {
+  if (is.null(early_stopping_rounds)) {
+    return(NULL)
+  }
+  if (is.null(eval_set)) {
+    stop("'early_stopping_rounds' requires passing 'eval_set'.")
+  }
+  if (NROW(early_stopping_rounds) != 1L) {
+    stop("'early_stopping_rounds' must be NULL or an integer greater than zero.")
+  }
+  early_stopping_rounds <- as.integer(early_stopping_rounds)
+  if (is.na(early_stopping_rounds) || early_stopping_rounds <= 0L) {
+    stop(
+      "'early_stopping_rounds' must be NULL or an integer greater than zero. Got: ",
+      early_stopping_rounds
+    )
+  }
+  return(early_stopping_rounds)
+}
+
 #' Fit XGBoost Model
 #'
 #' @export
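
The class-coverage fallback in `process.eval.set()` is dense, so here is a standalone sketch of the same data.table trick on made-up data (all names and sizes are illustrative): one randomly chosen row per class is forced into the training set before the evaluation rows are drawn.

library(data.table)

set.seed(123)
y <- factor(sample(c("a", "b", "c"), 20L, replace = TRUE))  # made-up labels
nrows <- length(y)
nrow_eval <- 5L

# rank rows within each class and mark one randomly 'chosen' row per class
dt <- data.table(y = y, idx = seq_len(nrows))[
  , .(ranked_idx = seq_len(.N), chosen = rep(sample(.N, 1L), .N), idx)
  , by = y
]
min_idx_train <- dt[ranked_idx == chosen, idx]  # one guaranteed row per class
rem_idx <- dt[ranked_idx != chosen, idx]
rem_idx <- rem_idx[sample(length(rem_idx))]     # shuffle the remaining rows
idx_eval <- rem_idx[seq_len(nrow_eval)]
idx_train <- c(min_idx_train, rem_idx[-seq_len(nrow_eval)])

stopifnot(all(levels(y) %in% y[idx_train]))     # every class stays in training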
@@ -808,6 +938,35 @@ process.x.and.col.args <- function(
 #' 2 (info), and 3 (debug).
 #' @param monitor_training Whether to monitor objective optimization progress on the input data.
 #' Note that same 'x' and 'y' data are used for both model fitting and evaluation.
+#' @param eval_set Subset of the data to use as evaluation set. Can be passed as:
+#' - A vector of row indices (base-1 numeration) indicating the observations that are to be designated
+#' as evaluation data.
+#' - A number between zero and one indicating a random fraction of the input data to use as
+#' evaluation data. Note that the selection will be done uniformly at random, regardless of
+#' argument `weights`.
+#'
+#' If passed, this subset of the data will be excluded from the training procedure, and the
+#' evaluation metric(s) supplied under `eval_metric` will be calculated on this dataset after each
+#' boosting iteration (pass `verbosity>0` to have these metrics printed during training). If
+#' `eval_metric` is not passed, a default metric will be selected according to `objective`.
+#'
+#' If passing a fraction, in classification problems, the evaluation set will be chosen in such a
+#' way that at least one observation of each class will be kept in the training data.
+#'
+#' For more elaborate evaluation variants (e.g. custom metrics, multiple evaluation sets, etc.),
+#' one might want to use [xgb.train()] instead.
+#' @param early_stopping_rounds Number of boosting rounds after which training will be stopped
+#' if there is no improvement in performance (as measured by the last metric passed under
+#' `eval_metric`, or by the default metric for the objective if `eval_metric` is not passed) on the
+#' evaluation data from `eval_set`. Must pass `eval_set` in order to use this functionality.
+#'
+#' If `NULL`, early stopping will not be used.
+#' @param print_every_n When passing `verbosity>0` and either `monitor_training=TRUE` or `eval_set`,
+#' evaluation logs (metrics calculated on the training and/or evaluation data) will be printed every
+#' nth iteration according to the value passed here. The first and last iteration are always
+#' included regardless of this 'n'.
+#'
+#' Only has an effect when passing `verbosity>0`.
 #' @param nthreads Number of parallel threads to use. If passing zero, will use all CPU threads.
 #' @param seed Seed to use for random number generation. If passing `NULL`, will draw a random
 #' number using R's PRNG system to use as seed.
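
Put together, the documented arguments allow calls like the following. A minimal sketch against the new signature (dataset and parameter values are illustrative):

library(xgboost)

# hold out a random 20% of the rows as evaluation data; stop once the default
# metric for the (multiclass) objective stops improving on them for 5 rounds
model <- xgboost(
  x = iris[, 1:4],
  y = iris$Species,
  nrounds = 100,
  eval_set = 0.2,
  early_stopping_rounds = 5,
  print_every_n = 10
)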
@@ -893,8 +1052,11 @@ xgboost <- function(
   objective = NULL,
   nrounds = 100L,
   weights = NULL,
-  verbosity = 0L,
+  verbosity = if (is.null(eval_set)) 0L else 1L,
   monitor_training = verbosity > 0,
+  eval_set = NULL,
+  early_stopping_rounds = NULL,
+  print_every_n = 1L,
   nthreads = parallel::detectCores(),
   seed = 0L,
   monotone_constraints = NULL,
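
Note the data-dependent default: `verbosity` now switches logging on automatically whenever `eval_set` is supplied, and an explicit value still wins (hypothetical calls):

# xgboost(x, y)                                 # verbosity defaults to 0L (silent)
# xgboost(x, y, eval_set = 0.2)                 # verbosity defaults to 1L (logs metrics)
# xgboost(x, y, eval_set = 0.2, verbosity = 0)  # explicit override stays silent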
@@ -907,7 +1069,7 @@ xgboost <- function(
   params <- list(...)
   params <- prescreen.parameters(params)
   prescreen.objective(objective)
-  use_qdm <- check.can.use.qdm(x, params)
+  use_qdm <- check.can.use.qdm(x, params, eval_set)
   lst_args <- process.y.margin.and.objective(y, base_margin, objective, params)
   lst_args <- process.row.weights(weights, lst_args)
   lst_args <- process.x.and.col.args(
@@ -918,8 +1080,9 @@ xgboost <- function(
     lst_args,
     use_qdm
   )
+  eval_set <- process.eval.set(eval_set, lst_args)
 
-  if (use_qdm && "max_bin" %in% names(params)) {
+  if (use_qdm && hasName(params, "max_bin")) {
     lst_args$dmatrix_args$max_bin <- params$max_bin
   }
 
@@ -929,18 +1092,27 @@ xgboost <- function(
   lst_args$params$seed <- seed
 
   params <- c(lst_args$params, params)
+  params$verbosity <- verbosity
 
   fn_dm <- if (use_qdm) xgb.QuantileDMatrix else xgb.DMatrix
   dm <- do.call(fn_dm, lst_args$dmatrix_args)
+  if (!is.null(eval_set)) {
+    dm_eval <- xgb.slice.DMatrix(dm, eval_set$idx_eval)
+    dm <- xgb.slice.DMatrix(dm, eval_set$idx_train)
+  }
   evals <- list()
   if (monitor_training) {
     evals <- list(train = dm)
   }
+  if (!is.null(eval_set)) {
+    evals <- c(evals, list(eval = dm_eval))
+  }
   model <- xgb.train(
     params = params,
     data = dm,
     nrounds = nrounds,
     verbose = verbosity,
+    print_every_n = print_every_n,
     evals = evals
   )
   attributes(model)$metadata <- lst_args$metadata
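
The `xgb.slice.DMatrix()` calls here, and the "Note: slicing" comment in `process.eval.set()`, refer to a quirk of the Matrix package that is easy to check independently (a quick sketch, assuming the Matrix package is installed):

library(Matrix)

m <- as(matrix(c(1, 0, 0, 2, 0, 3), nrow = 3), "RsparseMatrix")
class(m)         # "dgRMatrix" (row-oriented sparse)
class(m[1:2, ])  # "dgCMatrix": slicing changes the class, hence rows are
                 # sliced on the DMatrix rather than on the original input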

R-package/man/xgb.cv.Rd (+16, -7)

(Generated file; diff not rendered by default.)

R-package/man/xgb.train.Rd (+16, -7)

(Generated file; diff not rendered by default.)
