[RFC] [dask] decide and document how users should provide a Dask client at training time #3808

Closed
jameslamb opened this issue Jan 21, 2021 · 10 comments

Comments

@jameslamb
Collaborator

jameslamb commented Jan 21, 2021

Summary

The model objects in https://github.com/microsoft/LightGBM/blob/master/python-package/lightgbm/dask.py currently inherit from the classes in https://github.com/microsoft/LightGBM/blob/master/python-package/lightgbm/sklearn.py. For example, lightgbm.dask.DaskLGBMRegressor is a child of lightgbm.sklearn.LGBMRegressor.

These Dask estimators perform training on data stored in Dask collections (DataFrame and Array). To do this, they require a Dask client to talk to the cluster that those collections are stored on.

This issue describes the current state and possible alternatives for how LightGBM chooses the client to use for that task.

Option 1 - provide client in a keyword arg to .fit() (current state)

In Dask training, users provide a client connected to the cluster that they'd like to use for training. Currently, this is done in the .fit() method (which overrides the parent method from the sklearn interface).

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
    """Docstring is inherited from the LGBMModel."""
    return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs)

Good properties of this pattern:

  • allows full user control over the cluster used for training
  • the client is only used while .fit() is running and is not stored on the estimator, so DaskLGBMClassifier and DaskLGBMRegressor can be serialized with pickle / cloudpickle (I need to confirm that...but I know for sure that you cannot pickle a Dask client)

Bad properties of this pattern:

  • makes the signature of the .fit() method look different from the sklearn equivalent, because of the extra parameter client

Option 2 - don't allow users to customize client

We could choose not to expose client as a parameter at all, and just use whatever default client is available - for example, by asking people to run training inside a context manager (shown below) or, more generally, by calling distributed.get_client() internally.

from distributed import Client
with Client(cluster) as client:
    dask_reg = DaskLGBMRegressor(...)
    dask_reg.fit(X, y)
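
A minimal sketch of what the estimator could do internally under this option (the helper name here is hypothetical; distributed.get_client() is the real Dask function):

from distributed import get_client

def _pick_default_client():
    # Hypothetical helper sketching Option 2: no ``client`` parameter is exposed;
    # training just picks up whatever client already exists in the current
    # context (get_client() raises ValueError if the user never created one).
    return get_client()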

Good properties

  • keeps the constructor AND .fit() methods identical to the sklearn equivalents
  • the client can be looked up while .fit() is running, so DaskLGBMClassifier and DaskLGBMRegressor can be serialized with pickle / cloudpickle (I need to confirm that...but I know for sure that you cannot pickle a Dask client)

Bad properties

  • users could end up with the default client without knowing it, and this could cause issues that lead to confusion or poor performance

Option 3 - allow users to set .client after construction

In this option, we could allow users to set .client after constructing an estimator, and use that property if it's set. Otherwise, the default client would be used.

dask_reg = DaskLGBMRegressor(...)
dask_reg.client = Client(cluster) 
dask_reg.fit(X, y) 

This is what xgboost does in their Dask interface --> https://github.com/dmlc/xgboost/blob/7bc56fa0eda6984197152ed4aa98d2fc9e9e8cc8/python-package/xgboost/dask.py#L171

Good properties

  • keeps the constructor AND .fit() methods identical to the sklearn equivalents

Bad properties

  • users could end up with the default client without knowing it, and this could cause issues that lead to confusion or poor performance
  • this extra step of setting an attribute might feel unnatural to those who are used to the scikit-learn pattern of constructing an estimator then calling .fit()
  • might have to do a bit of extra work to be sure that a DaskLGBMRegressor object (for example) can be pickled (one possible approach is sketched after this list)
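
One common way to handle the pickling concern (a minimal sketch, not the current code; the attribute name and __getstate__ approach are assumptions) is to drop the client when pickling and fall back to the default client when it's missing:

from distributed import default_client

class _ClientAttributeSketch:
    # Hypothetical sketch of Option 3: ``client`` is a plain attribute the user
    # may set after construction; it is dropped when pickling because a live
    # Dask client itself cannot be pickled.
    client = None

    def _get_client(self):
        # Resolved at fit time: use the attribute if set, else the default client.
        return self.client if self.client is not None else default_client()

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("client", None)  # never serialize the live client connection
        return state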

Option 4 - optionally pass a client in the model object's constructor

This is what cuml does for their Dask interface, for example: https://docs.rapids.ai/api/cuml/stable/api.html#cuml.dask.cluster.KMeans. If client = None, then in cuml you get the default client based on the context.

dask_reg = DaskLGBMRegressor(client=Client(cluster))
dask_reg.fit(X, y) 
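
A minimal sketch of how Option 4 could look internally (hypothetical class, not the actual implementation): the client is accepted in the constructor, and the default is only looked up when .fit() actually runs:

from distributed import default_client

class DaskRegressorSketch:
    # Hypothetical class sketching Option 4: an optional ``client`` in the
    # constructor, with resolution deferred to fit time so the default is
    # whatever client exists at the moment training runs.
    def __init__(self, client=None, **params):
        self.client = client
        self.params = params

    def fit(self, X, y):
        client = self.client if self.client is not None else default_client()
        # ... distributed training would use ``client`` here ...
        return self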

Good properties

  • allows tight control over the client used, while also providing a you-don't-have-to-think-about-it option as the default

Bad properties

Other related questions

This project's Dask interface is very new, and there are other unresolved questions that might be related to this one, like:

  • should .predict() on Dask estimators only accept inputs that are Dask collections? [will insert issue link here soon]
    • today, it doesn't allow you to choose a client because it expects that the input will be local data (pandas or numpy)
  • what set of options do we want to support for saving models trained with the Dask estimators? [docs] [dask] Document how to save a Dask model #3838
@jameslamb
Collaborator Author

I tried to write down a useful summary of the problem above, but it's possible that I've missed things or made it unclear. Apologies if that's the case.

@ffineis @hcho3 @trivialfis @StrikerRUS whenever you have time, could you review this and let me know if you have strong opinions or additional information that might help make this decision?

I'm going to hold back my personal opinion on this until I've heard others, to not bias the outcome 😀

@StrikerRUS
Collaborator

@jameslamb I personally don't like option #3 very much...

I saw in the official examples that scikit-learn uses an approach similar to option #2 - they ask users to use a joblib context with the dask backend.

To fit it using the cluster, we just need to use a context manager provided by joblib.
https://examples.dask.org/machine-learning.html

import joblib

with joblib.parallel_backend('dask'):
    grid_search.fit(X, y)

Maybe we can set up a vote among Dask users? In the dedicated Dask repo for discussions and/or on Twitter with appropriate hashtags/mentions?

@jsignell

I think the dasky-est approach would be to not allow the user to explicitly set the client and instead assume that they have already created one. This is basically Option 2, but I want to clarify that it doesn't require that training be run inside a context manager - it is enough to have already instantiated a client object. I think you'll use default_client() for this.
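
A small illustration of that point (runnable as-is, assuming distributed is installed; Client() here starts a local cluster):

from distributed import Client, default_client

client = Client()                  # no context manager needed; registers itself as the default
assert default_client() is client  # this is what the Dask estimator would pick up at fit time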

Other questions

should .predict() on Dask estimators only accept inputs that are Dask collections

I would expect it to accept both pandas/numpy and Dask collections. This is similar to the behavior of map_blocks or map_partitions, for example.
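
A minimal sketch of that kind of dispatch (hypothetical helper; _predict_distributed is an assumed name for the Dask code path, while is_dask_collection is Dask's own generic check):

from dask.base import is_dask_collection

def _predict_any(model, X, **kwargs):
    # Hypothetical dispatch: Dask collections take the distributed prediction
    # path, while pandas/numpy inputs fall back to the plain scikit-learn
    # predict() of the parent class.
    if is_dask_collection(X):
        return model._predict_distributed(X, **kwargs)  # assumed method name
    return model.predict(X, **kwargs)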

@jameslamb
Collaborator Author

Thanks @jsignell !

To be fair, I also just added an Option 4. I forgot that cuml chose to have the client passed into the model object's constructor.

@jameslamb
Collaborator Author

Maybe we can create a voting among Dask users?

I've invited more commentary on dask/community#104

@martindurant

I mostly agree with @jsignell, but I do think that allowing an optional argument to the model constructor is a reasonable and useful addition; you might just want to have separate things fitting on different clusters simultaneously.
The default would be to use the current client (evaluated at execution time), which would be the most recently created global client or the one from the current context, if there is one. Actually this also allows for .client =, but you probably wouldn't want to document that option, to avoid confusion.

@jsignell

Actually this also allows for .client =, but you probably wouldn't want to document that option, to avoid confusion.

I forgot to mention that. Yeah, that could be the way to implement fitting on different clusters. It could be documented in an advanced section.

@jameslamb
Collaborator Author

jameslamb commented Jan 31, 2021

Thanks for the comments and suggestions, everyone! I've just opened #3883 with my proposal for this issue.

I'm proposing removing the client keyword arg from .fit() and .predict(), and supporting these options:

# never telling LightGBM to use a specific client (uses distributed.default_client())
clf = lgb.DaskLGBMClassifier()
clf.fit(X, y)

# keyword arg in constructor
clf = lgb.DaskLGBMClassifier(client=client)
clf.fit(X, y)

# setting an attribute
clf = lgb.DaskLGBMClassifier()
clf.set_params(client=client)
clf.fit(X, y)

If you have time and interest, comments on #3883 are welcome.

@jameslamb
Collaborator Author

Thanks @jsignell and @martindurant for joining this issue and giving us your expertise! I just merged #3883 with the changes described in the comment above.

@github-actions

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

github-actions bot locked as resolved and limited conversation to collaborators Aug 23, 2023