diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 3e2c868e3d..6979cdd934 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -191,8 +191,9 @@ async def write( # Send meta data await self.ep.send(struct.pack("Q", nframes)) - await self.ep.send(struct.pack(nframes * "?", *cuda_frames)) - await self.ep.send(struct.pack(nframes * "Q", *sizes)) + await self.ep.send( + struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) + ) # Send frames @@ -226,15 +227,11 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): await self.ep.recv(nframes) (nframes,) = struct.unpack(nframes_fmt, nframes) - cuda_frames_fmt = nframes * "?" - cuda_frames = host_array(struct.calcsize(cuda_frames_fmt)) - await self.ep.recv(cuda_frames) - cuda_frames = struct.unpack(cuda_frames_fmt, cuda_frames) - - sizes_fmt = nframes * "Q" - sizes = host_array(struct.calcsize(sizes_fmt)) - await self.ep.recv(sizes) - sizes = struct.unpack(sizes_fmt, sizes) + header_fmt = nframes * "?" + nframes * "Q" + header = host_array(struct.calcsize(header_fmt)) + await self.ep.recv(header) + header = struct.unpack(header_fmt, header) + cuda_frames, sizes = header[:nframes], header[nframes:] except (ucp.exceptions.UCXBaseException, CancelledError): self.abort() raise CommClosedError("While reading, the connection was closed")