From 5ef16f022f0dc5100c317bf915003488fea7b128 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 27 Jun 2022 11:17:37 +0100 Subject: [PATCH] XGBoost - Use DaskDeviceQuantileDMatrix with GPU Training (#528) * Use the DaskDeviceQuantileDMatrix data type when using GPU training * Add test for XGBoost with gpu_hist tree_method * Add use_quantile parameter to XGBoost.fit method This allows optionally disabling the default functionality which uses a DeviceQuantile matrix when using GPU training. --- merlin/models/xgb/__init__.py | 16 +++++++++++++++- tests/unit/xgb/test_xgboost.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/merlin/models/xgb/__init__.py b/merlin/models/xgb/__init__.py index 70f357b2f2..087a43ea29 100644 --- a/merlin/models/xgb/__init__.py +++ b/merlin/models/xgb/__init__.py @@ -76,6 +76,8 @@ def dask_client(self) -> Optional[distributed.Client]: def fit( self, train: Dataset, + *, + use_quantile=True, **train_kwargs, ) -> xgb.Booster: """Trains the XGBoost Model. @@ -90,6 +92,11 @@ def fit( The training dataset to use to fit the model. We will use the column(s) tagged with merlin.schema.Tags.TARGET that match the objective as the label(s). + use_quantile : bool + This param is only relevant when using GPU (with + tree_method="gpu_hist"). If set to False, will use a + `DaskDMatrix`, instead of the default + `DaskDeviceQuantileDMatrix`, which is preferred for GPU training. **train_kwargs Additional keyword arguments passed to the xgboost.train function @@ -108,7 +115,14 @@ def fit( self.qid_column, ) - dtrain = xgb.dask.DaskDMatrix(self.dask_client, X, label=y, qid=qid) + dmatrix_cls = xgb.dask.DaskDMatrix + if self.params.get("tree_method") == "gpu_hist" and use_quantile: + # `DaskDeviceQuantileDMatrix` is a data type specialized + # for the `gpu_hist` tree method that reduces memory overhead. + # When training on a GPU pipeline, it's preferred over `DaskDMatrix`. 
+ dmatrix_cls = xgb.dask.DaskDeviceQuantileDMatrix + + dtrain = dmatrix_cls(self.dask_client, X, label=y, qid=qid) watchlist = [(dtrain, "train")] booster: xgb.Booster = xgb.dask.train( diff --git a/tests/unit/xgb/test_xgboost.py b/tests/unit/xgb/test_xgboost.py index 2b54ac2d9c..a5689ee853 100644 --- a/tests/unit/xgb/test_xgboost.py +++ b/tests/unit/xgb/test_xgboost.py @@ -13,8 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from unittest.mock import patch + import pytest +import xgboost +from merlin.core.dispatch import HAS_GPU from merlin.io import Dataset from merlin.models.xgb import XGBoost @@ -101,3 +105,33 @@ def test_pairwise(self, social_data: Dataset): model.fit(social_data) model.predict(social_data) model.evaluate(social_data) + + +@pytest.mark.skipif(not HAS_GPU, reason="No GPU available") +@pytest.mark.parametrize( + ["fit_kwargs", "expected_dtrain_cls"], + [ + ({}, xgboost.dask.DaskDeviceQuantileDMatrix), + ({"use_quantile": False}, xgboost.dask.DaskDMatrix), + ], +) +@patch("xgboost.dask.train", side_effect=xgboost.dask.train) +def test_gpu_hist_dmatrix( + mock_train, fit_kwargs, expected_dtrain_cls, dask_client, music_streaming_data: Dataset +): + schema = music_streaming_data.schema + model = XGBoost(schema, objective="reg:logistic", tree_method="gpu_hist") + model.fit(music_streaming_data, **fit_kwargs) + model.predict(music_streaming_data) + metrics = model.evaluate(music_streaming_data) + assert "rmse" in metrics + + assert mock_train.called + assert mock_train.call_count == 1 + + train_call = mock_train.call_args_list[0] + client, params, dtrain = train_call.args + assert dask_client == client + assert params["tree_method"] == "gpu_hist" + assert params["objective"] == "reg:logistic" + assert isinstance(dtrain, expected_dtrain_cls)