Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dumps_task #8067

Merged
merged 5 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from dask.utils import parse_timedelta

from distributed import profile, protocol
from distributed.collections import LRU
from distributed.comm import (
Comm,
CommClosedError,
Expand All @@ -40,7 +39,6 @@
from distributed.counter import Counter
from distributed.diskutils import WorkDir, WorkSpace
from distributed.metrics import context_meter, time
from distributed.protocol import pickle
from distributed.system_monitor import SystemMonitor
from distributed.utils import (
NoOpAwaitable,
Expand All @@ -64,21 +62,6 @@
Coro = Coroutine[Any, Any, T]


cache_loads: LRU[bytes, Callable[..., Any]] = LRU(maxsize=100)


def loads_function(bytes_object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @madsbk - It looks like we were using this function in dask-cuda (rapidsai/dask-cuda#1219)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we use it for its caching feature but I don't think it is needed.

"""Load a function from bytes, cache bytes"""
if len(bytes_object) < 100000:
try:
result = cache_loads[bytes_object]
except KeyError:
result = pickle.loads(bytes_object)
cache_loads[bytes_object] = result
return result
return pickle.loads(bytes_object)


class Status(Enum):
"""
This Enum contains the various states a cluster, worker, scheduler and nanny can be
Expand Down Expand Up @@ -519,7 +502,6 @@ def func(data):
if load:
try:
import_file(out_filename)
cache_loads.data.clear()
except Exception as e:
logger.exception(e)
raise e
Expand Down
15 changes: 6 additions & 9 deletions distributed/recreate_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from dask.utils import stringify

from distributed.client import futures_of, wait
from distributed.protocol.serialize import ToPickle
from distributed.utils import sync
from distributed.utils_comm import pack_data
from distributed.worker import _deserialize

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,7 +42,10 @@ def get_error_cause(self, *args, keys=(), **kwargs):
def get_runspec(self, *args, key=None, **kwargs):
key = self._process_key(key)
ts = self.scheduler.tasks.get(key)
return {"task": ts.run_spec, "deps": [dts.key for dts in ts.dependencies]}
return {
"task": ToPickle(ts.run_spec),
"deps": [dts.key for dts in ts.dependencies],
}


class ReplayTaskClient:
Expand Down Expand Up @@ -83,13 +86,7 @@ async def _get_raw_components_from_future(self, future):
await wait(future)
key = future.key
spec = await self.scheduler.get_runspec(key=key)
deps, task = spec["deps"], spec["task"]
if isinstance(task, dict):
function, args, kwargs = _deserialize(**task)
return (function, args, kwargs, deps)
else:
function, args, kwargs = _deserialize(task=task)
return (function, args, kwargs, deps)
return (*spec["task"], spec["deps"])

async def _prepare_raw_components(self, raw_components):
"""
Expand Down
32 changes: 15 additions & 17 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
)
from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.variable import VariableExtension
from distributed.worker import dumps_task
from distributed.worker import _normalize_task

if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
Expand Down Expand Up @@ -156,6 +156,8 @@
# (recommendations, client messages, worker messages)
RecsMsgs: TypeAlias = tuple[Recs, Msgs, Msgs]

T_runspec: TypeAlias = tuple[Callable, tuple, dict[str, Any]]

logger = logging.getLogger(__name__)
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
DEFAULT_DATA_SIZE = parse_bytes(
Expand Down Expand Up @@ -1176,7 +1178,7 @@ class TaskState:
#: "pure data" (such as, for example, a piece of data loaded in the scheduler using
#: :meth:`Client.scatter`). A "pure data" task cannot be computed again if its
#: value is lost.
run_spec: object
run_spec: T_runspec | None

#: The priority provides each task with a relative ranking which is used to break
#: ties when many tasks are being considered for execution.
Expand Down Expand Up @@ -1375,7 +1377,7 @@ class TaskState:
def __init__(
self,
key: str,
run_spec: object,
run_spec: T_runspec | None,
state: TaskStateState,
):
self.key = key
Expand Down Expand Up @@ -1787,7 +1789,7 @@ def __pdict__(self) -> dict[str, Any]:
def new_task(
self,
key: str,
spec: object,
spec: T_runspec | None,
state: TaskStateState,
computation: Computation | None = None,
) -> TaskState:
Expand Down Expand Up @@ -3343,10 +3345,7 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
},
"nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
"run_spec": None,
"function": None,
"args": None,
"kwargs": None,
"run_spec": ToPickle(ts.run_spec),
"resource_restrictions": ts.resource_restrictions,
"actor": ts.actor,
"annotations": ts.annotations,
Expand All @@ -3355,11 +3354,6 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
if self.validate:
assert all(msg["who_has"].values())

if isinstance(ts.run_spec, dict):
msg.update(ts.run_spec)
else:
msg["run_spec"] = ts.run_spec

return msg


Expand Down Expand Up @@ -4606,7 +4600,11 @@ async def update_graph(
self.digest_metric("update-graph-duration", end - start)

def _generate_taskstates(
self, keys: set[str], dsk: dict, dependencies: dict, computation: Computation
self,
keys: set[str],
dsk: dict[str, T_runspec],
dependencies: dict[str, set[str]],
computation: Computation,
) -> tuple:
# Get or create task states
runnable = []
Expand Down Expand Up @@ -8479,8 +8477,8 @@ def transition(


def _materialize_graph(
graph: HighLevelGraph, global_annotations: dict
) -> tuple[dict, dict, dict]:
graph: HighLevelGraph, global_annotations: dict[str, Any]
) -> tuple[dict[str, T_runspec], dict[str, set[str]], dict[str, Any]]:
dsk = dask.utils.ensure_dict(graph)
annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict)
for annotations_type, value in global_annotations.items():
Expand Down Expand Up @@ -8536,6 +8534,6 @@ def _materialize_graph(
for k in list(dsk):
if dsk[k] is k:
del dsk[k]
dsk = valmap(dumps_task, dsk)
dsk = valmap(_normalize_task, dsk)

return dsk, dependencies, annotations_by_type
14 changes: 3 additions & 11 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,12 +1111,11 @@ def test_workerstate_resumed_waiting_to_flight(ws):
assert ws.tasks["x"].state == "flight"


@pytest.mark.parametrize("critical_section", ["execute", "deserialize_task"])
@pytest.mark.parametrize("resume_inside_critical_section", [False, True])
@pytest.mark.parametrize("resumed_status", ["executing", "resumed"])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_execute_preamble_early_cancel(
c, s, b, critical_section, resume_inside_critical_section, resumed_status
c, s, b, resume_inside_critical_section, resumed_status
):
"""Test multiple race conditions in the preamble of Worker.execute(), which used to
cause a task to remain permanently in resumed state or to crash the worker through
Expand All @@ -1129,15 +1128,8 @@ async def test_execute_preamble_early_cancel(
test_worker.py::test_execute_preamble_abort_retirement
"""
async with BlockedExecute(s.address) as a:
if critical_section == "execute":
in_ev = a.in_execute
block_ev = a.block_execute
a.block_deserialize_task.set()
else:
assert critical_section == "deserialize_task"
in_ev = a.in_deserialize_task
block_ev = a.block_deserialize_task
a.block_execute.set()
in_ev = a.in_execute
block_ev = a.block_execute

async def resume():
if resumed_status == "executing":
Expand Down
78 changes: 32 additions & 46 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import dask
from dask import delayed
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename
from dask.utils import parse_timedelta, stringify, tmpfile, typename

from distributed import (
CancelledError,
Expand Down Expand Up @@ -74,7 +74,7 @@
varying,
wait_for_state,
)
from distributed.worker import dumps_function, dumps_task, get_worker, secede
from distributed.worker import dumps_function, get_worker, secede

pytestmark = pytest.mark.ci1

Expand Down Expand Up @@ -345,7 +345,26 @@ async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a):
await wait(xs + ys)


@pytest.mark.slow
from distributed import WorkerPlugin


class CountData(WorkerPlugin):
def __init__(self, keys):
self.keys = keys
self.worker = None
self.count = 0

def setup(self, worker):
self.worker = worker

def transition(self, start, finish, *args, **kwargs):
count = 0
for k in self.worker.data:
if k in self.keys:
count += 1
self.count = max(self.count, count)


@gen_cluster(
nthreads=[("", 2)] * 4,
client=True,
Expand All @@ -359,33 +378,18 @@ async def test_graph_execution_width(c, s, *workers):
The number of parallel work streams match the number of threads.
"""

class Refcount:
"Track how many instances of this class exist; logs the count at creation and deletion"

count = 0
lock = dask.utils.SerializableLock()
log = []

def __init__(self):
with self.lock:
type(self).count += 1
self.log.append(self.count)

def __del__(self):
with self.lock:
self.log.append(self.count)
type(self).count -= 1
Comment on lines -362 to -377
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is interesting. This PR is not changing anything in terms of scheduling, ordering, etc. but this is still quite reliably failing. It seems as if Refcount is relying on explicit garbage collection. This is something I want to look into a little more since we're seeing a lot of GC warnings recently. However, for the sake of this PR I rewrote it to count keys in data instead of relying on GC. Eventually, I think both tests would make sense

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really a weird case and somehow connected to how this object is defined in a local context.
I looked pretty closely but I cannot find any cyclic references. In fact, I see actually fewer objects actually tracked by GC than this counter is let to believe. I know that CPython guarantees that __del__ is indeed called and only called once but I believe there are some caveats about when this is the case.


roots = [delayed(Refcount)() for _ in range(32)]
roots = [delayed(inc)(ix) for ix in range(32)]
passthrough1 = [delayed(slowidentity)(r, delay=0) for r in roots]
passthrough2 = [delayed(slowidentity)(r, delay=0) for r in passthrough1]
done = [delayed(lambda r: None)(r) for r in passthrough2]

await c.register_worker_plugin(
CountData(keys=[f.key for f in roots]), name="count-roots"
)
fs = c.compute(done)
await wait(fs)
# NOTE: the max should normally equal `total_nthreads`. But some macOS CI machines
# are slow enough that they aren't able to reach the full parallelism of 8 threads.
assert max(Refcount.log) <= s.total_nthreads

res = await c.run(lambda dask_worker: dask_worker.plugins["count-roots"].count)
assert all(0 < count <= 2 for count in res.values())


@gen_cluster(client=True, nthreads=[("", 1)])
Expand Down Expand Up @@ -953,24 +957,6 @@ def test_dumps_function():
assert a != c


def test_dumps_task():
d = dumps_task((inc, 1))
assert set(d) == {"function", "args"}

def f(x, y=2):
return x + y

d = dumps_task((apply, f, (1,), {"y": 10}))
assert cloudpickle.loads(d["function"])(1, 2) == 3
assert cloudpickle.loads(d["args"]) == (1,)
assert cloudpickle.loads(d["kwargs"]) == {"y": 10}

d = dumps_task((apply, f, (1,)))
assert cloudpickle.loads(d["function"])(1, 2) == 3
assert cloudpickle.loads(d["args"]) == (1,)
assert set(d) == {"function", "args"}


@pytest.mark.parametrize("worker_saturation", [1.0, float("inf")])
@gen_cluster(client=True)
async def test_ready_remove_worker(c, s, a, b, worker_saturation):
Expand Down Expand Up @@ -1357,9 +1343,9 @@ async def test_update_graph_culls(s, a, b):
layers={
"foo": MaterializedLayer(
{
"x": dumps_task((inc, 1)),
"y": dumps_task((inc, "x")),
"z": dumps_task((inc, 2)),
"x": (inc, 1),
"y": (inc, "x"),
"z": (inc, 2),
}
)
},
Expand Down
3 changes: 0 additions & 3 deletions distributed/tests/test_spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,14 +518,12 @@ async def test_worker_metrics(c, s, a, b):

# metrics for foo include self and its child bar
assert list(foo_metrics) == [
("execute", "x", "deserialize", "seconds"),
("execute", "x", "thread-cpu", "seconds"),
("execute", "x", "thread-noncpu", "seconds"),
("execute", "x", "executor", "seconds"),
("execute", "x", "other", "seconds"),
("execute", "x", "memory-read", "count"),
("execute", "x", "memory-read", "bytes"),
("execute", "y", "deserialize", "seconds"),
("execute", "y", "thread-cpu", "seconds"),
("execute", "y", "thread-noncpu", "seconds"),
("execute", "y", "executor", "seconds"),
Expand All @@ -536,7 +534,6 @@ async def test_worker_metrics(c, s, a, b):
list(bar0_metrics)
== list(bar1_metrics)
== [
("execute", "y", "deserialize", "seconds"),
("execute", "y", "thread-cpu", "seconds"),
("execute", "y", "thread-noncpu", "seconds"),
("execute", "y", "executor", "seconds"),
Expand Down
Loading
Loading