Skip to content

Commit

Permalink
Adapt to non-string task keys in distributed (#1225)
Browse files Browse the repository at this point in the history
Now that keys are no longer strings there are two places we must adapt here.

1. Explicit comms must no longer manually stringify task keys before staging and intersection with the on-worker data (since that data mapping doesn't use the stringified version)
2. The `zict.File`-backed slow buffer in `DeviceHostFile` needs to translate non-string keys to string keys before writing to disk, to do this, use the same implementation that distributed uses for its own spilling buffer.

- Closes #1224

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #1225
  • Loading branch information
wence- authored Aug 25, 2023
1 parent 2e7b6c0 commit 390ad36
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 6 additions & 2 deletions dask_cuda/device_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time

import numpy
from zict import Buffer, File, Func
from zict import Buffer, Func
from zict.common import ZictBase

import dask
Expand All @@ -17,6 +17,7 @@
serialize_bytelist,
)
from distributed.sizeof import safe_sizeof
from distributed.spill import CustomFile as KeyAsStringFile
from distributed.utils import nbytes

from .is_device_object import is_device_object
Expand Down Expand Up @@ -201,7 +202,10 @@ def __init__(
self.disk_func = Func(
_serialize_bytelist,
deserialize_bytes,
File(self.disk_func_path),
# Task keys are not strings, so this takes care of
# converting arbitrary tuple keys into a string before
# handing off to zict.File
KeyAsStringFile(self.disk_func_path),
)

host_buffer_kwargs = {}
Expand Down
4 changes: 1 addition & 3 deletions dask_cuda/explicit_comms/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Dict, Hashable, Iterable, List, Optional

import distributed.comm
from dask.utils import stringify
from distributed import Client, Worker, default_client, get_worker
from distributed.comm.addressing import parse_address, parse_host_port, unparse_address

Expand Down Expand Up @@ -305,8 +304,7 @@ def stage_keys(self, name: str, keys: Iterable[Hashable]) -> Dict[int, set]:
dict
dict that maps each worker-rank to the workers set of staged keys
"""
key_set = {stringify(k) for k in keys}
return dict(self.run(_stage_keys, name, key_set))
return dict(self.run(_stage_keys, name, set(keys)))


def pop_staging_area(session_state: dict, name: str) -> Dict[str, Any]:
Expand Down

0 comments on commit 390ad36

Please sign in to comment.