Skip to content

Commit

Permalink
XGBoost - Use DaskDeviceQuantileDMatrix with GPU Training (#528)
Browse files Browse the repository at this point in the history
* Use the DaskDeviceQuantileDMatrix data type when using GPU training

* Add test for XGBoost with gpu_hist tree_method

* Add use_quantile parameter to XGBoost.fit method

This allows optionally disabling the default functionality which uses
a DeviceQuantile matrix when using GPU training.
  • Loading branch information
oliverholworthy authored Jun 27, 2022
1 parent 10755a8 commit 5ef16f0
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
16 changes: 15 additions & 1 deletion merlin/models/xgb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def dask_client(self) -> Optional[distributed.Client]:
def fit(
self,
train: Dataset,
*,
use_quantile=True,
**train_kwargs,
) -> xgb.Booster:
"""Trains the XGBoost Model.
Expand All @@ -90,6 +92,11 @@ def fit(
The training dataset to use to fit the model.
We will use the column(s) tagged with merlin.schema.Tags.TARGET that match the
objective as the label(s).
use_quantile : bool
This param is only relevant when using GPU. (with
tree_method="gpu_hist"). If set to False, will use a
`DaskDMatrix`, instead of the default
`DaskDeviceQuantileDMatrix`, which is preferred for GPU training.
**train_kwargs
Additional keyword arguments passed to the xgboost.train function
Expand All @@ -108,7 +115,14 @@ def fit(
self.qid_column,
)

dtrain = xgb.dask.DaskDMatrix(self.dask_client, X, label=y, qid=qid)
dmatrix_cls = xgb.dask.DaskDMatrix
if self.params.get("tree_method") == "gpu_hist" and use_quantile:
# `DaskDeviceQuantileDMatrix` is a data type specialized
# for the `gpu_hist` tree method that reduces memory overhead.
# When training on GPU pipeline, it's preferred over `DaskDMatrix`.
dmatrix_cls = xgb.dask.DaskDeviceQuantileDMatrix

dtrain = dmatrix_cls(self.dask_client, X, label=y, qid=qid)
watchlist = [(dtrain, "train")]

booster: xgb.Booster = xgb.dask.train(
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/xgb/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import patch

import pytest
import xgboost

from merlin.core.dispatch import HAS_GPU
from merlin.io import Dataset
from merlin.models.xgb import XGBoost

Expand Down Expand Up @@ -101,3 +105,33 @@ def test_pairwise(self, social_data: Dataset):
model.fit(social_data)
model.predict(social_data)
model.evaluate(social_data)


@pytest.mark.skipif(not HAS_GPU, reason="No GPU available")
@pytest.mark.parametrize(
["fit_kwargs", "expected_dtrain_cls"],
[
({}, xgboost.dask.DaskDeviceQuantileDMatrix),
({"use_quantile": False}, xgboost.dask.DaskDMatrix),
],
)
@patch("xgboost.dask.train", side_effect=xgboost.dask.train)
def test_gpu_hist_dmatrix(
mock_train, fit_kwargs, expected_dtrain_cls, dask_client, music_streaming_data: Dataset
):
schema = music_streaming_data.schema
model = XGBoost(schema, objective="reg:logistic", tree_method="gpu_hist")
model.fit(music_streaming_data, **fit_kwargs)
model.predict(music_streaming_data)
metrics = model.evaluate(music_streaming_data)
assert "rmse" in metrics

assert mock_train.called
assert mock_train.call_count == 1

train_call = mock_train.call_args_list[0]
client, params, dtrain = train_call.args
assert dask_client == client
assert params["tree_method"] == "gpu_hist"
assert params["objective"] == "reg:logistic"
assert isinstance(dtrain, expected_dtrain_cls)

0 comments on commit 5ef16f0

Please sign in to comment.