diff --git a/dask_cuda/plugins.py b/dask_cuda/plugins.py index 4e3726e4..c2844278 100644 --- a/dask_cuda/plugins.py +++ b/dask_cuda/plugins.py @@ -3,7 +3,11 @@ from distributed import WorkerPlugin -from .utils import get_rmm_log_file_name, parse_device_memory_limit, enable_rmm_memory_for_library +from .utils import ( + enable_rmm_memory_for_library, + get_rmm_log_file_name, + parse_device_memory_limit, +) class CPUAffinity(WorkerPlugin): @@ -64,7 +68,6 @@ def __init__( self.rmm_track_allocations = track_allocations self.external_lib_list = external_lib_list - def setup(self, worker=None): if self.initial_pool_size is not None: self.initial_pool_size = parse_device_memory_limit( @@ -125,7 +128,7 @@ def setup(self, worker=None): mr = rmm.mr.get_current_device_resource() rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr)) - + if self.external_lib_list is not None: for lib in self.external_lib_list: enable_rmm_memory_for_library(lib)