diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 57fa8262e3..eddb6d4655 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -201,11 +201,34 @@ async def write( hasattr(f, "__cuda_array_interface__") for f in frames ) sizes = tuple(nbytes(f) for f in frames) - send_frames = [ - each_frame - for each_frame, each_size in zip(frames, sizes) - if each_size - ] + host_frames = host_array( + sum( + each_size + for is_cuda, each_size in zip(cuda_frames, sizes) + if not is_cuda + ) + ) + device_frames = device_array( + sum( + each_size + for is_cuda, each_size in zip(cuda_frames, sizes) + if is_cuda + ) + ) + + # Pack frames + host_frames_view = memoryview(host_frames) + device_frames_view = as_device_array(device_frames) + for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes): + if each_size: + if is_cuda: + each_frame_view = as_device_array(each_frame) + device_frames_view[:each_size] = each_frame_view[:] + device_frames_view = device_frames_view[each_size:] + else: + each_frame_view = memoryview(each_frame).cast("B") + host_frames_view[:each_size] = each_frame_view[:] + host_frames_view = host_frames_view[each_size:] # Send meta data @@ -227,8 +250,10 @@ async def write( if any(cuda_frames): synchronize_stream(0) - for each_frame in send_frames: - await self.ep.send(each_frame) + if nbytes(host_frames): + await self.ep.send(host_frames) + if nbytes(device_frames): + await self.ep.send(device_frames) return sum(sizes) except (ucp.exceptions.UCXBaseException): self.abort() @@ -263,21 +288,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): raise CommClosedError("While reading, the connection was closed") else: # Recv frames - frames = [ - device_array(each_size) if is_cuda else host_array(each_size) - for is_cuda, each_size in zip(cuda_frames, sizes) - ] - recv_frames = [ - each_frame for each_frame in frames if len(each_frame) > 0 - ] + host_frames = host_array( + sum( + each_size + for is_cuda, each_size in zip(cuda_frames, sizes) + if not is_cuda + ) + ) + device_frames = device_array( + sum( + each_size + for is_cuda, each_size in zip(cuda_frames, sizes) + if is_cuda + ) + ) # It is necessary to first populate `frames` with CUDA arrays and synchronize # the default stream before starting receiving to ensure buffers have been allocated if any(cuda_frames): synchronize_stream(0) - for each_frame in recv_frames: - await self.ep.recv(each_frame) + if nbytes(host_frames): + await self.ep.recv(host_frames) + if nbytes(device_frames): + await self.ep.recv(device_frames) + + frames = [ + device_array(each_size) if is_cuda else host_array(each_size) + for is_cuda, each_size in zip(cuda_frames, sizes) + ] + host_frames_view = memoryview(host_frames) + device_frames_view = as_device_array(device_frames) + for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes): + if each_size: + if is_cuda: + each_frame_view = as_device_array(each_frame) + each_frame_view[:] = device_frames_view[:each_size] + device_frames_view = device_frames_view[each_size:] + else: + each_frame_view = memoryview(each_frame) + each_frame_view[:] = host_frames_view[:each_size] + host_frames_view = host_frames_view[each_size:] + msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers )