diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py
index 7295b11bb4..2a0eaa557f 100644
--- a/distributed/comm/ucx.py
+++ b/distributed/comm/ucx.py
@@ -145,7 +145,7 @@ async def write(
         serializers = ("cuda", "dask", "pickle", "error")
         # msg can also be a list of dicts when sending batched messages
         frames = await to_frames(
-            msg, serializers=serializers, on_error=on_error
+            msg, serializers=serializers, on_error=on_error,
         )
 
         # Send meta data
diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py
index b75663a14f..3c9539f5b0 100644
--- a/distributed/comm/utils.py
+++ b/distributed/comm/utils.py
@@ -30,7 +30,7 @@ def _to_frames():
         try:
             return list(
                 protocol.dumps(
-                    msg, serializers=serializers, on_error=on_error, context=context
+                    msg, serializers=serializers, on_error=on_error, context=context,
                 )
             )
         except Exception as e:
diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py
index 3bb863f78c..2bed474716 100644
--- a/distributed/protocol/core.py
+++ b/distributed/protocol/core.py
@@ -15,7 +15,21 @@ logger = logging.getLogger(__name__)
 
 
-def dumps(msg, serializers=None, on_error="message", context=None):
+def _split_and_compress(header, frames):
+    """
+    Internal helper for splitting and compressing frames
+    """
+
+    frames = frame_split_size(frames)
+    if frames:
+        compression, frames = zip(*map(maybe_compress, frames))
+    else:
+        compression = []
+    header["compression"] = compression
+    return header, frames
+
+
+def dumps(msg, serializers=None, on_error="message", context=None, split_frames=True):
     """ Transform Python message to bytestream suitable for communication """
     try:
         data = {}
@@ -48,13 +62,13 @@ def dumps(msg, serializers=None, on_error="message", context=None):
         for key, (head, frames) in data.items():
             if "lengths" not in head:
                 head["lengths"] = tuple(map(nbytes, frames))
-            if "compression" not in head:
-                frames = frame_split_size(frames)
-                if frames:
-                    compression, frames = zip(*map(maybe_compress, frames))
-                else:
-                    compression = []
-                head["compression"] = compression
+            # treat compression of collections homogeneously
+            if "is-collection" in head:
+                if "compression" not in head["sub-headers"][0]:
+                    head, frames = _split_and_compress(head, frames)
+            elif "compression" not in head:
+                head, frames = _split_and_compress(head, frames)
+            head["count"] = len(frames)
             header["headers"][key] = head
             header["keys"].append(key)
diff --git a/distributed/protocol/cuda.py b/distributed/protocol/cuda.py
index aa638f70c0..bc0be867e3 100644
--- a/distributed/protocol/cuda.py
+++ b/distributed/protocol/cuda.py
@@ -18,7 +18,9 @@ def cuda_dumps(x):
     header, frames = dumps(x)
     header["type-serialized"] = pickle.dumps(type(x))
     header["serializer"] = "cuda"
-    header["compression"] = (None,) * len(frames)  # no compression for gpu data
+    # note: when "compression" is not set in the header, Dask may split the frame
+    # into multiple chunks; see dumps() in distributed/protocol/core.py
+    header["compression"] = False
     return header, frames
diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py
index 3d07426624..be066afd0f 100644
--- a/distributed/protocol/cupy.py
+++ b/distributed/protocol/cupy.py
@@ -46,6 +46,9 @@ def cuda_serialize_cupy_ndarray(x):
         x = cupy.array(x, copy=True)
 
     header = x.__cuda_array_interface__.copy()
+    # note: when "compression" is not set in the header, Dask may split the frame
+    # into multiple chunks; see dumps() in distributed/protocol/core.py
+    header["compression"] = False
     header["strides"] = tuple(x.strides)
     frames = [
         cupy.ndarray(
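
For context, here is a minimal standalone sketch of the split-then-compress step that this patch factors out into _split_and_compress(). The frame_split_size() and maybe_compress() functions below are simplified stand-ins for the real helpers in distributed.protocol, written under the assumption that they split oversized frames into fixed-size chunks and compress a frame only when that makes it smaller; the sizes and the FRAME_LIMIT constant are illustrative only.

# Standalone sketch (not part of the patch): mimics _split_and_compress()
# with stand-in helpers so the control flow can be run in isolation.
import zlib

FRAME_LIMIT = 64  # deliberately tiny so the example actually splits a frame


def frame_split_size(frames, limit=FRAME_LIMIT):
    # Stand-in: split any frame larger than `limit` into `limit`-sized chunks.
    out = []
    for frame in frames:
        if len(frame) <= limit:
            out.append(frame)
        else:
            out.extend(frame[i : i + limit] for i in range(0, len(frame), limit))
    return out


def maybe_compress(frame):
    # Stand-in: compress a frame only when compression makes it smaller.
    compressed = zlib.compress(bytes(frame))
    if len(compressed) < len(frame):
        return "zlib", compressed
    return None, frame


def _split_and_compress(header, frames):
    # Same shape as the helper added to distributed/protocol/core.py above.
    frames = frame_split_size(frames)
    if frames:
        compression, frames = zip(*map(maybe_compress, frames))
    else:
        compression = []
    header["compression"] = compression
    return header, frames


header, frames = _split_and_compress({}, [b"a" * 200, b"xyz"])
print(header["compression"])     # e.g. ('zlib', 'zlib', 'zlib', None, None)
print([len(f) for f in frames])  # chunk sizes after splitting/compression

Note how the cuda.py and cupy.py changes above sidestep this path entirely: by pre-setting header["compression"] = False, the "compression" not in head check in dumps() no longer fires, so GPU frames are neither split nor compressed on the host.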