diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index e47b61a4fa..61c6c0dc03 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -5,6 +5,7 @@ .. _UCX: https://github.com/openucx/ucx """ +import asyncio import functools import logging import os @@ -278,16 +279,6 @@ async def write( # Send meta data - # Send close flag and number of frames (_Bool, int64) - await self.ep.send(struct.pack("?Q", False, nframes)) - # Send which frames are CUDA (bool) and - # how large each frame is (uint64) - await self.ep.send( - struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) - ) - - # Send frames - # It is necessary to first synchronize the default stream before start # sending We synchronize the default stream because UCX is not # stream-ordered and syncing the default stream will wait for other @@ -296,8 +287,22 @@ async def write( if any(cuda_send_frames): synchronize_stream(0) + tasks = [] + + # Send close flag and number of frames (_Bool, int64) + tasks.append(self.ep.send(struct.pack("?Q", False, nframes))) + # Send which frames are CUDA (bool) and + # how large each frame is (uint64) + tasks.append( + self.ep.send( + struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) + ) + ) + + # Send frames for each_frame in send_frames: - await self.ep.send(each_frame) + tasks.append(self.ep.send(each_frame)) + await asyncio.gather(*tasks) return sum(sizes) except (ucp.exceptions.UCXBaseException): self.abort() @@ -354,8 +359,10 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): if any(cuda_recv_frames): synchronize_stream(0) + tasks = [] for each_frame in recv_frames: - await self.ep.recv(each_frame) + tasks.append(self.ep.recv(each_frame)) + await asyncio.gather(*tasks) msg = await from_frames( frames, deserialize=self.deserialize,