
Commit

Merge remote-tracking branch 'upstream/branch-24.12' into reenable-ucxx-ci
pentschev committed Oct 14, 2024
2 parents 481720d + 8d88006 commit 81dfaba
Showing 5 changed files with 115 additions and 1 deletion.
14 changes: 13 additions & 1 deletion dask_cuda/cli.py
@@ -13,7 +13,7 @@
from distributed.utils import import_term

from .cuda_worker import CUDAWorker
from .utils import print_cluster_config
from .utils import CommaSeparatedChoice, print_cluster_config

logger = logging.getLogger(__name__)

@@ -164,6 +164,16 @@ def cuda():
    incompatible with RMM pools and managed memory, trying to enable both will
    result in failure.""",
)
@click.option(
    "--set-rmm-allocator-for-libs",
    "rmm_allocator_external_lib_list",
    type=CommaSeparatedChoice(["cupy", "torch"]),
    default=None,
    show_default=True,
    help="""
    Set RMM as the allocator for external libraries. Provide a comma-separated
    list of libraries to set, e.g., "torch,cupy".""",
)
@click.option(
    "--rmm-release-threshold",
    default=None,
@@ -351,6 +361,7 @@ def worker(
    rmm_maximum_pool_size,
    rmm_managed_memory,
    rmm_async,
    rmm_allocator_external_lib_list,
    rmm_release_threshold,
    rmm_log_directory,
    rmm_track_allocations,
@@ -425,6 +436,7 @@ def worker(
        rmm_maximum_pool_size,
        rmm_managed_memory,
        rmm_async,
        rmm_allocator_external_lib_list,
        rmm_release_threshold,
        rmm_log_directory,
        rmm_track_allocations,
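For example, with this change a worker can be launched with the new flag (an invocation sketch; it assumes the `dask cuda worker` entry point shown by the `def cuda():` / `def worker(` hunks above and a scheduler already listening at the given address):

dask cuda worker tcp://scheduler:8786 --set-rmm-allocator-for-libs cupy,torch --rmm-pool-size 2GB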
2 changes: 2 additions & 0 deletions dask_cuda/cuda_worker.py
@@ -47,6 +47,7 @@ def __init__(
        rmm_maximum_pool_size=None,
        rmm_managed_memory=False,
        rmm_async=False,
        rmm_allocator_external_lib_list=None,
        rmm_release_threshold=None,
        rmm_log_directory=None,
        rmm_track_allocations=False,
@@ -231,6 +232,7 @@ def del_pid_file():
                        release_threshold=rmm_release_threshold,
                        log_directory=rmm_log_directory,
                        track_allocations=rmm_track_allocations,
                        external_lib_list=rmm_allocator_external_lib_list,
                    ),
                    PreImport(pre_import),
                    CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats),
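The same knob is exposed on the programmatic worker API. A minimal sketch, assuming a scheduler is already running at the illustrative address (`CUDAWorker` is the class imported in cli.py above):

from dask_cuda import CUDAWorker

# Pass the already-parsed list form; the CLI's CommaSeparatedChoice produces
# exactly this shape before forwarding it here.
worker = CUDAWorker(
    "tcp://scheduler:8786",
    rmm_pool_size="2GB",
    rmm_allocator_external_lib_list=["cupy", "torch"],
)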
22 changes: 22 additions & 0 deletions dask_cuda/local_cuda_cluster.py
@@ -143,6 +143,11 @@ class LocalCUDACluster(LocalCluster):
        The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
        incompatible with RMM pools and managed memory. Trying to enable both will
        result in an exception.
    rmm_allocator_external_lib_list: str, list or None, default None
        List of external libraries for which to set RMM as the allocator.
        Supported options are: ``["torch", "cupy"]``. Can be a comma-separated string
        (like ``"torch,cupy"``) or a list of strings (like ``["torch", "cupy"]``).
        If ``None``, no external libraries will use RMM as their allocator.
    rmm_release_threshold: int, str or None, default None
        When ``rmm_async is True`` and the pool size grows beyond this value, unused
        memory held by the pool will be released at the next synchronization point.
@@ -231,6 +236,7 @@ def __init__(
        rmm_maximum_pool_size=None,
        rmm_managed_memory=False,
        rmm_async=False,
        rmm_allocator_external_lib_list=None,
        rmm_release_threshold=None,
        rmm_log_directory=None,
        rmm_track_allocations=False,
@@ -265,6 +271,19 @@ def __init__(
            n_workers = len(CUDA_VISIBLE_DEVICES)
        if n_workers < 1:
            raise ValueError("Number of workers cannot be less than 1.")

        if rmm_allocator_external_lib_list is not None:
            if isinstance(rmm_allocator_external_lib_list, str):
                rmm_allocator_external_lib_list = [
                    v.strip() for v in rmm_allocator_external_lib_list.split(",")
                ]
            elif not isinstance(rmm_allocator_external_lib_list, list):
                raise ValueError(
                    "rmm_allocator_external_lib_list must be either a comma-separated "
                    "string or a list of strings. Examples: 'torch,cupy' "
                    "or ['torch', 'cupy']"
                )

        # Set nthreads=1 when parsing mem_limit since it only depends on n_workers
        logger = logging.getLogger(__name__)
        self.memory_limit = parse_memory_limit(
@@ -284,6 +303,8 @@ def __init__(
        self.rmm_managed_memory = rmm_managed_memory
        self.rmm_async = rmm_async
        self.rmm_release_threshold = rmm_release_threshold
        self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list

        if rmm_pool_size is not None or rmm_managed_memory or rmm_async:
            try:
                import rmm  # noqa F401
@@ -437,6 +458,7 @@ def new_worker_spec(self):
                    release_threshold=self.rmm_release_threshold,
                    log_directory=self.rmm_log_directory,
                    track_allocations=self.rmm_track_allocations,
                    external_lib_list=self.rmm_allocator_external_lib_list,
                ),
                PreImport(self.pre_import),
                CUDFSetup(self.enable_cudf_spill, self.cudf_spill_stats),
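Putting the new keyword together, a minimal usage sketch (pool size and library choice are illustrative; both the list form and the comma-separated string form are accepted, per the validation above):

from dask.distributed import Client

from dask_cuda import LocalCUDACluster

# Each worker sets RMM as CuPy's allocator during startup, so CuPy
# allocations on workers are served from the RMM pool.
cluster = LocalCUDACluster(
    rmm_pool_size="2GB",
    rmm_allocator_external_lib_list=["cupy"],
)
client = Client(cluster)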
67 changes: 67 additions & 0 deletions dask_cuda/plugins.py
@@ -1,5 +1,6 @@
import importlib
import os
from typing import Callable, Dict

from distributed import WorkerPlugin

@@ -39,6 +40,7 @@ def __init__(
        release_threshold,
        log_directory,
        track_allocations,
        external_lib_list,
    ):
        if initial_pool_size is None and maximum_pool_size is not None:
            raise ValueError(
@@ -61,6 +63,7 @@ def __init__(
        self.logging = log_directory is not None
        self.log_directory = log_directory
        self.rmm_track_allocations = track_allocations
        self.external_lib_list = external_lib_list

    def setup(self, worker=None):
        if self.initial_pool_size is not None:
@@ -123,6 +126,70 @@ def setup(self, worker=None):
            mr = rmm.mr.get_current_device_resource()
            rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))

        if self.external_lib_list is not None:
            for lib in self.external_lib_list:
                enable_rmm_memory_for_library(lib)


def enable_rmm_memory_for_library(lib_name: str) -> None:
    """Enable RMM memory pool support for a specified third-party library.

    This function allows the given library to utilize RMM's memory pool if it supports
    integration with RMM. The library name is passed as a string argument, and if the
    library is compatible, its memory allocator will be configured to use RMM.

    Parameters
    ----------
    lib_name : str
        The name of the third-party library to enable RMM memory pool support for.
        Supported libraries are "cupy" and "torch".

    Raises
    ------
    ValueError
        If the library name is not supported or does not have RMM integration.
    ImportError
        If the required library is not installed.
    """

    # Mapping of supported libraries to their respective setup functions
    setup_functions: Dict[str, Callable[[], None]] = {
        "torch": _setup_rmm_for_torch,
        "cupy": _setup_rmm_for_cupy,
    }

    if lib_name not in setup_functions:
        supported_libs = ", ".join(setup_functions.keys())
        raise ValueError(
            f"The library '{lib_name}' is not supported for RMM integration. "
            f"Supported libraries are: {supported_libs}."
        )

    # Call the setup function for the specified library
    setup_functions[lib_name]()


def _setup_rmm_for_torch() -> None:
    try:
        import torch
    except ImportError as e:
        raise ImportError("PyTorch is not installed.") from e

    from rmm.allocators.torch import rmm_torch_allocator

    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


def _setup_rmm_for_cupy() -> None:
    try:
        import cupy
    except ImportError as e:
        raise ImportError("CuPy is not installed.") from e

    from rmm.allocators.cupy import rmm_cupy_allocator

    cupy.cuda.set_allocator(rmm_cupy_allocator)


class PreImport(WorkerPlugin):
    def __init__(self, libraries):
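Outside of the plugin, the new helper can also be exercised directly. A sketch, assuming `rmm` and `cupy` are installed and an RMM pool is configured before any allocations:

import rmm

from dask_cuda.plugins import enable_rmm_memory_for_library

# Create a pooled memory resource, then route CuPy allocations through it.
rmm.reinitialize(pool_allocator=True, initial_pool_size=2**30)
enable_rmm_memory_for_library("cupy")

# Unsupported names raise ValueError, e.g. enable_rmm_memory_for_library("numpy").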
11 changes: 11 additions & 0 deletions dask_cuda/utils.py
@@ -9,6 +9,7 @@
from multiprocessing import cpu_count
from typing import Optional

import click
import numpy as np
import pynvml
import toolz
@@ -764,3 +765,13 @@ def get_rmm_memory_resource_stack(mr) -> list:
    if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
        return mr.allocation_counts["current_bytes"]
    return None


class CommaSeparatedChoice(click.Choice):
    def convert(self, value, param, ctx):
        values = [v.strip() for v in value.split(",")]
        for v in values:
            if v not in self.choices:
                choices_str = ", ".join(f"'{c}'" for c in self.choices)
                self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})")
        return values
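For illustration, the new click type converts a comma-separated value into a validated list (a sketch; the `--libs` option and command are hypothetical):

import click

from dask_cuda.utils import CommaSeparatedChoice

@click.command()
@click.option("--libs", type=CommaSeparatedChoice(["cupy", "torch"]), default=None)
def show(libs):
    # "--libs cupy,torch" arrives here as ["cupy", "torch"]; any value
    # outside the allowed choices fails with the message built in convert().
    click.echo(libs)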
