diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py
index 7761afef7a..d9c610d785 100644
--- a/distributed/comm/ucx.py
+++ b/distributed/comm/ucx.py
@@ -5,6 +5,7 @@
 .. _UCX: https://github.com/openucx/ucx
 """
+import itertools
 import logging
 import struct
 import weakref
@@ -34,7 +35,11 @@
 # variables to be set before being imported.
 ucp = None
 host_array = None
+host_concat = None
+host_split = None
 device_array = None
+device_concat = None
+device_split = None
 
 
 def synchronize_stream(stream=0):
@@ -47,7 +52,7 @@ def synchronize_stream(stream=0):
 
 
 def init_once():
-    global ucp, host_array, device_array
+    global ucp, host_array, host_concat, host_split, device_array, device_concat, device_split
     if ucp is not None:
         return
@@ -61,12 +66,45 @@ def init_once():
     ucp.init(options=ucx_config, env_takes_precedence=True)
 
     # Find the function, `host_array()`, to use when allocating new host arrays
+    # Also find `host_concat()` and `host_split()` to merge/split frames
     try:
         import numpy
 
-        host_array = lambda n: numpy.empty((n,), dtype="u1")
+        host_array = lambda n: numpy.empty((n,), dtype="u1").data
+        host_concat = lambda arys: numpy.concatenate(
+            [numpy.asarray(memoryview(e)).view("u1") for e in arys], axis=None
+        ).data
+        host_split = lambda ary, indices: [
+            e.copy().data
+            for e in numpy.split(numpy.asarray(memoryview(ary)).view("u1"), indices)
+        ]
     except ImportError:
-        host_array = lambda n: bytearray(n)
+        host_array = lambda n: memoryview(bytearray(n))
+
+        def host_concat(arys):
+            arys = [memoryview(a).cast("B") for a in arys]
+            sizes = [nbytes(a) for a in arys]
+            r = host_array(sum(sizes))
+            r_view = memoryview(r)
+            for each_ary, each_size in zip(arys, sizes):
+                if each_size:
+                    r_view[:each_size] = each_ary
+                    r_view = r_view[each_size:]
+            return r
+
+        def host_split(a, indices):
+            arys = []
+            a_view = memoryview(a)
+            indices = list(indices)
+            for each_ij in zip([0] + indices, indices + [nbytes(a)]):
+                each_size = each_ij[1] - each_ij[0]
+                each_slice = slice(*each_ij)
+                each_ary = host_array(each_size)
+                if each_size:
+                    each_ary_view = memoryview(each_ary)
+                    each_ary_view[:] = a_view[each_slice]
+                arys.append(each_ary)
+            return arys
 
     # Find the function, `cuda_array()`, to use when allocating new CUDA arrays
     try:
@@ -100,6 +138,80 @@ def device_array(n):
                     "In order to send/recv CUDA arrays, Numba or RMM is required"
                 )
 
+    # Find the functions `device_concat` and `device_split`
+    try:
+        import cupy
+
+        def dask_cupy_allocator(nbytes):
+            a = device_array(nbytes)
+            ptr = a.__cuda_array_interface__["data"][0]
+            dev_id = -1 if ptr else cupy.cuda.device.get_device_id()
+            mem = cupy.cuda.UnownedMemory(
+                ptr=ptr, size=nbytes, owner=a, device_id=dev_id
+            )
+            return cupy.cuda.memory.MemoryPointer(mem, 0)
+
+        def device_concat(arys):
+            with cupy.cuda.using_allocator(dask_cupy_allocator):
+                arys = [cupy.asarray(e).view("u1") for e in arys]
+                result = cupy.concatenate(arys, axis=None)
+                result_buffer = result.data.mem._owner
+            return result_buffer
+
+        def device_split(ary, indices):
+            ary = cupy.asarray(ary).view("u1")
+            ary_split = cupy.split(ary, indices)
+            results = []
+            result_buffers = []
+            for e in ary_split:
+                b2 = device_array(e.nbytes)
+                e2 = cupy.asarray(b2)
+                cupy.copyto(e2, e)
+                results.append(e2)
+                result_buffers.append(b2)
+            return result_buffers
+
+    except ImportError:
+        try:
+            import numba.cuda
+
+            def device_concat(arys):
+                arys = [numba.cuda.as_cuda_array(a).view("u1") for a in arys]
+                sizes = [nbytes(a) for a in arys]
+                r = device_array(sum(sizes))
+                r_view = numba.cuda.as_cuda_array(r)
+                for each_ary, each_size in zip(arys, sizes):
+                    if each_size:
+                        r_view[:each_size] = each_ary[:]
+                        r_view = r_view[each_size:]
+                return r
+
+            def device_split(a, indices):
+                arys = []
+                a_view = numba.cuda.as_cuda_array(a).view("u1")
+                indices = list(indices)
+                for each_ij in zip([0] + indices, indices + [a.size]):
+                    each_size = each_ij[1] - each_ij[0]
+                    each_slice = slice(*each_ij)
+                    each_ary = device_array(each_size)
+                    if each_size:
+                        each_ary_view = numba.cuda.as_cuda_array(each_ary)
+                        each_ary_view[:] = a_view[each_slice]
+                    arys.append(each_ary)
+                return arys
+
+        except ImportError:
+
+            def device_concat(arys):
+                raise RuntimeError(
+                    "In order to send/recv CUDA arrays, CuPy or Numba is required"
+                )
+
+            def device_split(a, indices):
+                raise RuntimeError(
+                    "In order to send/recv CUDA arrays, CuPy or Numba is required"
+                )
+
     pool_size_str = dask.config.get("rmm.pool-size")
     if pool_size_str is not None:
         pool_size = parse_bytes(pool_size_str)
@@ -178,16 +290,22 @@ async def write(
             frames = await to_frames(
                 msg, serializers=serializers, on_error=on_error
             )
+            nframes = len(frames)
 
-            cuda_frames = tuple(
-                hasattr(f, "__cuda_array_interface__") for f in frames
-            )
-            sizes = tuple(nbytes(f) for f in frames)
-            send_frames = [
-                each_frame
-                for each_frame, each_size in zip(frames, sizes)
-                if each_size
-            ]
+            cuda_frames = []
+            sizes = []
+            device_frames = []
+            host_frames = []
+            for each_frame in frames:
+                is_cuda = hasattr(each_frame, "__cuda_array_interface__")
+                each_size = nbytes(each_frame)
+                cuda_frames.append(is_cuda)
+                sizes.append(each_size)
+                if each_size:
+                    if is_cuda:
+                        device_frames.append(each_frame)
+                    else:
+                        host_frames.append(each_frame)
 
             # Send meta data
@@ -201,16 +319,24 @@
 
             # Send frames
 
-            # It is necessary to first synchronize the default stream before start sending
-            # We synchronize the default stream because UCX is not stream-ordered and
-            # syncing the default stream will wait for other non-blocking CUDA streams.
-            # Note this is only sufficient if the memory being sent is not currently in use on
-            # non-blocking CUDA streams.
-            if any(cuda_frames):
+            if host_frames:
+                if len(host_frames) == 1:
+                    host_frames = host_frames[0]
+                else:
+                    host_frames = host_concat(host_frames)
+                await self.ep.send(host_frames)
+            if device_frames:
+                if len(device_frames) == 1:
+                    device_frames = device_frames[0]
+                else:
+                    device_frames = device_concat(device_frames)
+                # It is necessary to first synchronize the default stream before starting to send
+                # We synchronize the default stream because UCX is not stream-ordered and
+                # syncing the default stream will wait for other non-blocking CUDA streams.
+                # Note this is only sufficient if the memory being sent is not currently in use on
+                # non-blocking CUDA streams.
                 synchronize_stream(0)
-
-            for each_frame in send_frames:
-                await self.ep.send(each_frame)
+                await self.ep.send(device_frames)
             return sum(sizes)
         except (ucp.exceptions.UCXBaseException):
             self.abort()
@@ -245,21 +371,52 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
                 raise CommClosedError("While reading, the connection was closed")
             else:
                 # Recv frames
-                frames = [
-                    device_array(each_size) if is_cuda else host_array(each_size)
-                    for is_cuda, each_size in zip(cuda_frames, sizes)
-                ]
-                recv_frames = [
-                    each_frame for each_frame in frames if len(each_frame) > 0
-                ]
-
-                # It is necessary to first populate `frames` with CUDA arrays and synchronize
-                # the default stream before starting receiving to ensure buffers have been allocated
-                if any(cuda_frames):
-                    synchronize_stream(0)
+                host_frame_sizes = []
+                device_frame_sizes = []
+                for is_cuda, each_size in zip(cuda_frames, sizes):
+                    if is_cuda:
+                        device_frame_sizes.append(each_size)
+                    else:
+                        host_frame_sizes.append(each_size)
+
+                if host_frame_sizes:
+                    host_frames = host_array(sum(host_frame_sizes))
+                    if host_frames.nbytes:
+                        await self.ep.recv(host_frames)
+                if device_frame_sizes:
+                    device_frames = device_array(sum(device_frame_sizes))
+                    if device_frames.nbytes:
+                        # It is necessary to first populate `frames` with CUDA arrays and synchronize
+                        # the default stream before starting to receive to ensure buffers have been allocated
+                        synchronize_stream(0)
+                        await self.ep.recv(device_frames)
+
+                if len(host_frame_sizes) == 0:
+                    host_frames = []
+                elif len(host_frame_sizes) == 1:
+                    host_frames = [host_frames]
+                else:
+                    host_frames = host_split(
+                        host_frames, list(itertools.accumulate(host_frame_sizes[:-1]))
+                    )
+
+                if len(device_frame_sizes) == 0:
+                    device_frames = []
+                elif len(device_frame_sizes) == 1:
+                    device_frames = [device_frames]
+                else:
+                    device_frames = device_split(
+                        device_frames,
+                        list(itertools.accumulate(device_frame_sizes[:-1])),
+                    )
+
+                frames = []
+                for is_cuda in cuda_frames:
+                    if is_cuda:
+                        frames.append(device_frames.pop(0))
+                    else:
+                        frames.append(host_frames.pop(0))
 
-                for each_frame in recv_frames:
-                    await self.ep.recv(each_frame)
                 msg = await from_frames(
                     frames, deserialize=self.deserialize, deserializers=deserializers
                 )
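
Reviewer note (not part of the patch): the core idea above is to coalesce all host frames into a single contiguous buffer before one `ep.send()`, and to rebuild the individual frames on the receiving side from their cumulative sizes, the same offsets that `itertools.accumulate(...[:-1])` produces in `read()`. Below is a minimal, dependency-free sketch of that round trip; the helper names `concat_frames` and `split_frames` are illustrative only, zero-length frames are carried through here for simplicity (the patch skips them on the wire), and the device path uses the CuPy/Numba equivalents instead.

import itertools


def concat_frames(frames):
    # Pack several byte-like frames into one contiguous buffer (send side).
    views = [memoryview(f).cast("B") for f in frames]
    out = memoryview(bytearray(sum(v.nbytes for v in views)))
    offset = 0
    for v in views:
        out[offset : offset + v.nbytes] = v
        offset += v.nbytes
    return out


def split_frames(buffer, sizes):
    # Recover the original frames from one buffer, given their sizes (recv side).
    view = memoryview(buffer).cast("B")
    # Split points are the cumulative sizes of all but the last frame,
    # which is what `itertools.accumulate(sizes[:-1])` yields in `read()`.
    offsets = [0] + list(itertools.accumulate(sizes[:-1]))
    return [bytes(view[o : o + n]) for o, n in zip(offsets, sizes)]


frames = [b"header", b"", b"payload-bytes"]
packed = concat_frames(frames)
assert split_frames(packed, [len(f) for f in frames]) == frames

Splitting by offsets on the receiving end, rather than receiving frame by frame, is what lets `read()` issue at most one host and one device `ep.recv()` per message.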