diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 61c6c0dc03..e47b61a4fa 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -5,7 +5,6 @@ .. _UCX: https://github.com/openucx/ucx """ -import asyncio import functools import logging import os @@ -279,6 +278,16 @@ async def write( # Send meta data + # Send close flag and number of frames (_Bool, int64) + await self.ep.send(struct.pack("?Q", False, nframes)) + # Send which frames are CUDA (bool) and + # how large each frame is (uint64) + await self.ep.send( + struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) + ) + + # Send frames + # It is necessary to first synchronize the default stream before start # sending We synchronize the default stream because UCX is not # stream-ordered and syncing the default stream will wait for other @@ -287,22 +296,8 @@ async def write( if any(cuda_send_frames): synchronize_stream(0) - tasks = [] - - # Send close flag and number of frames (_Bool, int64) - tasks.append(self.ep.send(struct.pack("?Q", False, nframes))) - # Send which frames are CUDA (bool) and - # how large each frame is (uint64) - tasks.append( - self.ep.send( - struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) - ) - ) - - # Send frames for each_frame in send_frames: - tasks.append(self.ep.send(each_frame)) - await asyncio.gather(*tasks) + await self.ep.send(each_frame) return sum(sizes) except (ucp.exceptions.UCXBaseException): self.abort() @@ -359,10 +354,8 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): if any(cuda_recv_frames): synchronize_stream(0) - tasks = [] for each_frame in recv_frames: - tasks.append(self.ep.recv(each_frame)) - await asyncio.gather(*tasks) + await self.ep.recv(each_frame) msg = await from_frames( frames, deserialize=self.deserialize,