Skip to content

Commit dc092ae

Browse files
authored
Validate reference dataset for training. (#11105)
1 parent f06dcf8 commit dc092ae

File tree

4 files changed

+56
-14
lines changed

4 files changed

+56
-14
lines changed

python-package/xgboost/core.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,20 @@ def _ref_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
14511451
)
14521452

14531453

1454-
class QuantileDMatrix(DMatrix):
1454+
class _RefMixIn:
    """Mix-in exposing an optional weak reference through the ``ref`` property."""

    @property
    def ref(self) -> Optional[weakref.ReferenceType]:
        """Weak reference to the training DMatrix, or ``None`` if none was set."""
        # The backing ``_ref`` attribute only exists once the setter has run, so
        # a missing attribute is reported as ``None`` instead of raising.
        return getattr(self, "_ref", None)

    @ref.setter
    def ref(self, ref: weakref.ReferenceType) -> None:
        self._ref = ref
1465+
1466+
1467+
class QuantileDMatrix(DMatrix, _RefMixIn):
14551468
"""A DMatrix variant that generates quantilized data directly from input for the
14561469
``hist`` tree method. This DMatrix is primarily designed to save memory in training
14571470
by avoiding intermediate storage. Set ``max_bin`` to control the number of bins
@@ -1640,8 +1653,11 @@ def _init(
16401653
_check_call(ret)
16411654
self.handle = handle
16421655

1656+
if ref is not None:
1657+
self.ref = weakref.ref(ref)
1658+
16431659

1644-
class ExtMemQuantileDMatrix(DMatrix):
1660+
class ExtMemQuantileDMatrix(DMatrix, _RefMixIn):
16451661
"""The external memory version of the :py:class:`QuantileDMatrix`.
16461662
16471663
See :doc:`/tutorials/external_memory` for explanation and usage examples, and
@@ -1739,6 +1755,9 @@ def _init(
17391755
_check_call(ret)
17401756
self.handle = handle
17411757

1758+
if ref is not None:
1759+
self.ref = weakref.ref(ref)
1760+
17421761

17431762
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
17441763
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]

python-package/xgboost/training.py

+15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""Training Library containing training routines."""
44
import copy
55
import os
6+
import weakref
67
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast
78

89
import numpy as np
@@ -147,6 +148,20 @@ def train(
147148
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
148149
evals = list(evals) if evals else []
149150

151+
for va, _ in evals:
    # Every evaluation entry must be a DMatrix (or a subclass of it).
    if not isinstance(va, DMatrix):
        raise TypeError("Invalid type for the `evals`.")

    # Matrices without a ``ref`` attribute are not checked, and the training
    # matrix itself is always acceptable as an evaluation set.
    if va is dtrain or not hasattr(va, "ref"):
        continue

    # Dereference the stored weak reference and compare its referent with the
    # training matrix by identity. Testing ``va.ref is weakref.ref(dtrain)``
    # only succeeds when the interpreter happens to return the very same
    # weakref object for repeated ``weakref.ref`` calls, which is an
    # implementation detail rather than a guarantee.
    va_ref = va.ref
    if va_ref is None or va_ref() is not dtrain:
        raise ValueError(
            "Training dataset should be used as a reference when constructing "
            "the `QuantileDMatrix` for evaluation."
        )
164+
150165
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
151166
start_iteration = 0
152167

tests/python-gpu/test_device_quantile_dmatrix.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_ref_dmatrix(self) -> None:
175175
import cupy as cp
176176

177177
rng = cp.random.RandomState(np.uint64(1994))
178-
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
178+
self.cputest.run_ref_dmatrix(rng, "cuda", False)
179179

180180
@given(
181181
strategies.integers(1, 1000),

tests/python/test_quantile_dmatrix.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,13 @@ def test_training(self, sparsity: float) -> None:
170170
}
171171
xgb.train(parameters, Xy)
172172

173-
def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
173+
def run_ref_dmatrix(self, rng: Any, device: str, enable_cat: bool) -> None:
174174
n_samples, n_features = 2048, 17
175175
if enable_cat:
176176
X, y = make_categorical(
177177
n_samples, n_features, n_categories=13, onehot=False
178178
)
179-
if tree_method == "gpu_hist":
179+
if device == "cuda":
180180
import cudf
181181

182182
X = cudf.from_pandas(X)
@@ -189,10 +189,12 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
189189

190190
# Use ref
191191
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
192-
Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)
192+
Xy_valid: xgb.DMatrix = xgb.QuantileDMatrix(
193+
X, y, ref=Xy, enable_categorical=enable_cat
194+
)
193195
qdm_results: Dict[str, Dict[str, List[float]]] = {}
194196
xgb.train(
195-
{"tree_method": tree_method},
197+
{"tree_method": "hist", "device": device},
196198
Xy,
197199
evals=[(Xy, "Train"), (Xy_valid, "valid")],
198200
evals_result=qdm_results,
@@ -201,10 +203,10 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
201203
qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"]
202204
)
203205
# No ref
204-
Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
206+
Xy_valid = xgb.DMatrix(X, y, enable_categorical=enable_cat)
205207
qdm_results = {}
206208
xgb.train(
207-
{"tree_method": tree_method},
209+
{"tree_method": "hist", "device": device},
208210
Xy,
209211
evals=[(Xy, "Train"), (Xy_valid, "valid")],
210212
evals_result=qdm_results,
@@ -229,7 +231,7 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
229231
n_samples, n_features = 256, 17
230232
if enable_cat:
231233
X, y = make_categorical(n_samples, n_features, 13, onehot=False)
232-
if tree_method == "gpu_hist":
234+
if device == "cuda":
233235
import cudf
234236

235237
X = cudf.from_pandas(X)
@@ -246,15 +248,15 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
246248

247249
qdm_results = {}
248250
xgb.train(
249-
{"tree_method": tree_method},
251+
{"tree_method": "hist", "device": device},
250252
Xy,
251253
evals=[(Xy, "Train"), (Xy_valid, "valid")],
252254
evals_result=qdm_results,
253255
)
254256

255257
dm_results: Dict[str, Dict[str, List[float]]] = {}
256258
xgb.train(
257-
{"tree_method": tree_method},
259+
{"tree_method": "hist", "device": device},
258260
dXy,
259261
evals=[(dXy, "Train"), (dXy_valid, "valid"), (Xy_valid_d, "dvalid")],
260262
evals_result=dm_results,
@@ -269,13 +271,19 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
269271
dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"]
270272
)
271273

274+
Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
275+
with pytest.raises(ValueError, match="should be used as a reference"):
276+
xgb.train(
277+
{"device": device}, dXy, evals=[(dXy, "Train"), (Xy_valid, "Valid")]
278+
)
279+
272280
def test_ref_quantile_cut(self) -> None:
    """Run the ``check_ref_quantile_cut`` helper with ``"cpu"`` as its argument."""
    device = "cpu"
    check_ref_quantile_cut(device)
274282

275283
def test_ref_dmatrix(self) -> None:
276284
rng = np.random.RandomState(1994)
277-
self.run_ref_dmatrix(rng, "hist", True)
278-
self.run_ref_dmatrix(rng, "hist", False)
285+
self.run_ref_dmatrix(rng, "cpu", True)
286+
self.run_ref_dmatrix(rng, "cpu", False)
279287

280288
@pytest.mark.parametrize("sparsity", [0.0, 0.5])
281289
def test_predict(self, sparsity: float) -> None:

0 commit comments

Comments
 (0)