diff --git a/distributed/core.py b/distributed/core.py index b17177e280..df6d3cb8e9 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -26,7 +26,6 @@ from dask.utils import parse_timedelta from distributed import profile, protocol -from distributed.collections import LRU from distributed.comm import ( Comm, CommClosedError, @@ -40,7 +39,6 @@ from distributed.counter import Counter from distributed.diskutils import WorkDir, WorkSpace from distributed.metrics import context_meter, time -from distributed.protocol import pickle from distributed.system_monitor import SystemMonitor from distributed.utils import ( NoOpAwaitable, @@ -64,21 +62,6 @@ Coro = Coroutine[Any, Any, T] -cache_loads: LRU[bytes, Callable[..., Any]] = LRU(maxsize=100) - - -def loads_function(bytes_object): - """Load a function from bytes, cache bytes""" - if len(bytes_object) < 100000: - try: - result = cache_loads[bytes_object] - except KeyError: - result = pickle.loads(bytes_object) - cache_loads[bytes_object] = result - return result - return pickle.loads(bytes_object) - - class Status(Enum): """ This Enum contains the various states a cluster, worker, scheduler and nanny can be @@ -519,7 +502,6 @@ def func(data): if load: try: import_file(out_filename) - cache_loads.data.clear() except Exception as e: logger.exception(e) raise e diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py index a8351dad70..d76aaaf566 100644 --- a/distributed/recreate_tasks.py +++ b/distributed/recreate_tasks.py @@ -5,9 +5,9 @@ from dask.utils import stringify from distributed.client import futures_of, wait +from distributed.protocol.serialize import ToPickle from distributed.utils import sync from distributed.utils_comm import pack_data -from distributed.worker import _deserialize logger = logging.getLogger(__name__) @@ -42,7 +42,10 @@ def get_error_cause(self, *args, keys=(), **kwargs): def get_runspec(self, *args, key=None, **kwargs): key = self._process_key(key) ts = self.scheduler.tasks.get(key) - return {"task": ts.run_spec, "deps": [dts.key for dts in ts.dependencies]} + return { + "task": ToPickle(ts.run_spec), + "deps": [dts.key for dts in ts.dependencies], + } class ReplayTaskClient: @@ -83,13 +86,7 @@ async def _get_raw_components_from_future(self, future): await wait(future) key = future.key spec = await self.scheduler.get_runspec(key=key) - deps, task = spec["deps"], spec["task"] - if isinstance(task, dict): - function, args, kwargs = _deserialize(**task) - return (function, args, kwargs, deps) - else: - function, args, kwargs = _deserialize(task=task) - return (function, args, kwargs, deps) + return (*spec["task"], spec["deps"]) async def _prepare_raw_components(self, raw_components): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9e1b16fd76..36d771fc81 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -126,7 +126,7 @@ ) from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.variable import VariableExtension -from distributed.worker import dumps_task +from distributed.worker import _normalize_task if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) @@ -156,6 +156,8 @@ # (recommendations, client messages, worker messages) RecsMsgs: TypeAlias = tuple[Recs, Msgs, Msgs] +T_runspec: TypeAlias = tuple[Callable, tuple, dict[str, Any]] + logger = logging.getLogger(__name__) LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") DEFAULT_DATA_SIZE = parse_bytes( @@ -1176,7 +1178,7 @@ class TaskState: #: "pure data" (such as, for example, a piece of data loaded in the scheduler using #: :meth:`Client.scatter`). A "pure data" task cannot be computed again if its #: value is lost. - run_spec: object + run_spec: T_runspec | None #: The priority provides each task with a relative ranking which is used to break #: ties when many tasks are being considered for execution. @@ -1375,7 +1377,7 @@ class TaskState: def __init__( self, key: str, - run_spec: object, + run_spec: T_runspec | None, state: TaskStateState, ): self.key = key @@ -1787,7 +1789,7 @@ def __pdict__(self) -> dict[str, Any]: def new_task( self, key: str, - spec: object, + spec: T_runspec | None, state: TaskStateState, computation: Computation | None = None, ) -> TaskState: @@ -3343,10 +3345,7 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies }, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, - "run_spec": None, - "function": None, - "args": None, - "kwargs": None, + "run_spec": ToPickle(ts.run_spec), "resource_restrictions": ts.resource_restrictions, "actor": ts.actor, "annotations": ts.annotations, @@ -3355,11 +3354,6 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: if self.validate: assert all(msg["who_has"].values()) - if isinstance(ts.run_spec, dict): - msg.update(ts.run_spec) - else: - msg["run_spec"] = ts.run_spec - return msg @@ -4606,7 +4600,11 @@ async def update_graph( self.digest_metric("update-graph-duration", end - start) def _generate_taskstates( - self, keys: set[str], dsk: dict, dependencies: dict, computation: Computation + self, + keys: set[str], + dsk: dict[str, T_runspec], + dependencies: dict[str, set[str]], + computation: Computation, ) -> tuple: # Get or create task states runnable = [] @@ -8479,8 +8477,8 @@ def transition( def _materialize_graph( - graph: HighLevelGraph, global_annotations: dict -) -> tuple[dict, dict, dict]: + graph: HighLevelGraph, global_annotations: dict[str, Any] +) -> tuple[dict[str, T_runspec], dict[str, set[str]], dict[str, Any]]: dsk = dask.utils.ensure_dict(graph) annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict) for annotations_type, value in global_annotations.items(): @@ -8536,6 +8534,6 @@ def _materialize_graph( for k in list(dsk): if dsk[k] is k: del dsk[k] - dsk = valmap(dumps_task, dsk) + dsk = valmap(_normalize_task, dsk) return dsk, dependencies, annotations_by_type diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 1609972fa8..8daa83bf87 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -1111,12 +1111,11 @@ def test_workerstate_resumed_waiting_to_flight(ws): assert ws.tasks["x"].state == "flight" -@pytest.mark.parametrize("critical_section", ["execute", "deserialize_task"]) @pytest.mark.parametrize("resume_inside_critical_section", [False, True]) @pytest.mark.parametrize("resumed_status", ["executing", "resumed"]) @gen_cluster(client=True, nthreads=[("", 1)]) async def test_execute_preamble_early_cancel( - c, s, b, critical_section, resume_inside_critical_section, resumed_status + c, s, b, resume_inside_critical_section, resumed_status ): """Test multiple race conditions in the preamble of Worker.execute(), which used to cause a task to remain permanently in resumed state or to crash the worker through @@ -1129,15 +1128,8 @@ async def test_execute_preamble_early_cancel( test_worker.py::test_execute_preamble_abort_retirement """ async with BlockedExecute(s.address) as a: - if critical_section == "execute": - in_ev = a.in_execute - block_ev = a.block_execute - a.block_deserialize_task.set() - else: - assert critical_section == "deserialize_task" - in_ev = a.in_deserialize_task - block_ev = a.block_deserialize_task - a.block_execute.set() + in_ev = a.in_execute + block_ev = a.block_execute async def resume(): if resumed_status == "executing": diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index af476c0915..754eadaff5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -24,7 +24,7 @@ import dask from dask import delayed from dask.highlevelgraph import HighLevelGraph, MaterializedLayer -from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename +from dask.utils import parse_timedelta, stringify, tmpfile, typename from distributed import ( CancelledError, @@ -74,7 +74,7 @@ varying, wait_for_state, ) -from distributed.worker import dumps_function, dumps_task, get_worker, secede +from distributed.worker import dumps_function, get_worker, secede pytestmark = pytest.mark.ci1 @@ -345,7 +345,26 @@ async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a): await wait(xs + ys) -@pytest.mark.slow +from distributed import WorkerPlugin + + +class CountData(WorkerPlugin): + def __init__(self, keys): + self.keys = keys + self.worker = None + self.count = 0 + + def setup(self, worker): + self.worker = worker + + def transition(self, start, finish, *args, **kwargs): + count = 0 + for k in self.worker.data: + if k in self.keys: + count += 1 + self.count = max(self.count, count) + + @gen_cluster( nthreads=[("", 2)] * 4, client=True, @@ -359,33 +378,18 @@ async def test_graph_execution_width(c, s, *workers): The number of parallel work streams match the number of threads. """ - class Refcount: - "Track how many instances of this class exist; logs the count at creation and deletion" - - count = 0 - lock = dask.utils.SerializableLock() - log = [] - - def __init__(self): - with self.lock: - type(self).count += 1 - self.log.append(self.count) - - def __del__(self): - with self.lock: - self.log.append(self.count) - type(self).count -= 1 - - roots = [delayed(Refcount)() for _ in range(32)] + roots = [delayed(inc)(ix) for ix in range(32)] passthrough1 = [delayed(slowidentity)(r, delay=0) for r in roots] passthrough2 = [delayed(slowidentity)(r, delay=0) for r in passthrough1] done = [delayed(lambda r: None)(r) for r in passthrough2] - + await c.register_worker_plugin( + CountData(keys=[f.key for f in roots]), name="count-roots" + ) fs = c.compute(done) await wait(fs) - # NOTE: the max should normally equal `total_nthreads`. But some macOS CI machines - # are slow enough that they aren't able to reach the full parallelism of 8 threads. - assert max(Refcount.log) <= s.total_nthreads + + res = await c.run(lambda dask_worker: dask_worker.plugins["count-roots"].count) + assert all(0 < count <= 2 for count in res.values()) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -953,24 +957,6 @@ def test_dumps_function(): assert a != c -def test_dumps_task(): - d = dumps_task((inc, 1)) - assert set(d) == {"function", "args"} - - def f(x, y=2): - return x + y - - d = dumps_task((apply, f, (1,), {"y": 10})) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert cloudpickle.loads(d["kwargs"]) == {"y": 10} - - d = dumps_task((apply, f, (1,))) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert set(d) == {"function", "args"} - - @pytest.mark.parametrize("worker_saturation", [1.0, float("inf")]) @gen_cluster(client=True) async def test_ready_remove_worker(c, s, a, b, worker_saturation): @@ -1357,9 +1343,9 @@ async def test_update_graph_culls(s, a, b): layers={ "foo": MaterializedLayer( { - "x": dumps_task((inc, 1)), - "y": dumps_task((inc, "x")), - "z": dumps_task((inc, 2)), + "x": (inc, 1), + "y": (inc, "x"), + "z": (inc, 2), } ) }, diff --git a/distributed/tests/test_spans.py b/distributed/tests/test_spans.py index 23ca4789d3..b6210a1710 100644 --- a/distributed/tests/test_spans.py +++ b/distributed/tests/test_spans.py @@ -518,14 +518,12 @@ async def test_worker_metrics(c, s, a, b): # metrics for foo include self and its child bar assert list(foo_metrics) == [ - ("execute", "x", "deserialize", "seconds"), ("execute", "x", "thread-cpu", "seconds"), ("execute", "x", "thread-noncpu", "seconds"), ("execute", "x", "executor", "seconds"), ("execute", "x", "other", "seconds"), ("execute", "x", "memory-read", "count"), ("execute", "x", "memory-read", "bytes"), - ("execute", "y", "deserialize", "seconds"), ("execute", "y", "thread-cpu", "seconds"), ("execute", "y", "thread-noncpu", "seconds"), ("execute", "y", "executor", "seconds"), @@ -536,7 +534,6 @@ async def test_worker_metrics(c, s, a, b): list(bar0_metrics) == list(bar1_metrics) == [ - ("execute", "y", "deserialize", "seconds"), ("execute", "y", "thread-cpu", "seconds"), ("execute", "y", "thread-noncpu", "seconds"), ("execute", "y", "executor", "seconds"), diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 29b789618c..4422e2dd51 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -26,7 +26,7 @@ from tornado.ioloop import IOLoop import dask -from dask import delayed, istask +from dask import delayed from dask.system import CPU_COUNT from dask.utils import tmpfile @@ -95,8 +95,6 @@ ExecuteFailureEvent, ExecuteSuccessEvent, RemoveReplicasEvent, - SerializedTask, - StealRequestEvent, ) pytestmark = pytest.mark.ci1 @@ -2053,7 +2051,7 @@ def __setstate__(self, state): self._thread_ident = threading.get_ident() return self - monkeypatch.setattr("distributed.worker.OFFLOAD_THRESHOLD", 1) + monkeypatch.setattr("distributed.comm.utils.OFFLOAD_THRESHOLD", 1) async with Worker(s.address, executor="offload") as w: from distributed.utils import _offload_executor @@ -2090,7 +2088,7 @@ async def test_stimulus_story(c, s, a): assert isinstance(story[0], ComputeTaskEvent) assert story[0].key == "f1" - assert story[0].run_spec == SerializedTask(task=None) # Not logged + assert story[0].run_spec is None # Not logged assert isinstance(story[1], ExecuteSuccessEvent) assert story[1].key == "f1" @@ -2100,7 +2098,7 @@ async def test_stimulus_story(c, s, a): assert isinstance(story[2], ComputeTaskEvent) assert story[2].key == "f2" assert story[2].who_has == {"f1": (a.address,)} - assert story[2].run_spec == SerializedTask(task=None) # Not logged + assert story[2].run_spec is None # Not logged assert story[2].handled >= story[1].handled assert isinstance(story[3], ExecuteFailureEvent) @@ -2789,39 +2787,6 @@ async def test_forget_dependents_after_release(c, s, a): assert fut2.key not in {d.key for d in a.state.tasks[fut.key].dependents} -@pytest.mark.filterwarnings("ignore:Sending large graph of size") -@pytest.mark.filterwarnings("ignore:Large object of size") -@gen_cluster(client=True) -async def test_steal_during_task_deserialization(c, s, a, b, monkeypatch): - stealing_ext = s.extensions["stealing"] - await stealing_ext.stop() - - in_deserialize = asyncio.Event() - wait_in_deserialize = asyncio.Event() - - async def custom_worker_offload(func, *args): - res = func(*args) - if not istask(args) and istask(res): - in_deserialize.set() - await wait_in_deserialize.wait() - return res - - monkeypatch.setattr("distributed.worker.offload", custom_worker_offload) - obj = random.randbytes(OFFLOAD_THRESHOLD + 1) - fut = c.submit(lambda _: 41, obj, workers=[a.address], allow_other_workers=True) - - await in_deserialize.wait() - ts = s.tasks[fut.key] - a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test")) - stealing_ext.scheduler.send_task_to_worker(b.address, ts) - - fut2 = c.submit(inc, fut, workers=[a.address]) - fut3 = c.submit(inc, fut2, workers=[a.address]) - wait_in_deserialize.set() - assert await fut2 == 42 - await fut3 - - @gen_cluster(client=True, config=NO_AMM) async def test_acquire_replicas(c, s, a, b): fut = c.submit(inc, 1, workers=[a.address]) @@ -3652,7 +3617,6 @@ async def test_execute_preamble_abort_retirement(c, s): """ async with BlockedExecute(s.address) as a: await c.wait_for_workers(1) - a.block_deserialize_task.set() # Uninteresting in this test x = await c.scatter({"x": 1}, workers=[a.address]) y = c.submit(inc, 1, key="y", workers=[a.address]) diff --git a/distributed/tests/test_worker_metrics.py b/distributed/tests/test_worker_metrics.py index 8d823cc98c..059eba8566 100644 --- a/distributed/tests/test_worker_metrics.py +++ b/distributed/tests/test_worker_metrics.py @@ -150,7 +150,6 @@ async def test_custom_executor(c, s, a): await c.submit(sleep, 0.1) assert list(get_digests(a, "execute")) == [ - ("execute", span_id(s), "sleep", "deserialize", "seconds"), ("execute", span_id(s), "sleep", "executor", "seconds"), ("execute", span_id(s), "sleep", "other", "seconds"), ] @@ -160,18 +159,10 @@ async def test_custom_executor(c, s, a): ) -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_run_spec_deserialization(c, s, a): - """Test that deserialization of run_spec is metered""" - await c.submit(inc, 1, key="x") - assert 0 < a.digests_total["execute", span_id(s), "x", "deserialize", "seconds"] < 1 - - @gen_cluster(client=True) async def test_offload(c, s, a, b, monkeypatch): """Test that functions wrapped by offload() are metered""" monkeypatch.setattr(distributed.comm.utils, "OFFLOAD_THRESHOLD", 1) - monkeypatch.setattr(distributed.worker, "OFFLOAD_THRESHOLD", 1) x = c.submit(inc, 1, key="x", workers=[a.address]) y = c.submit(lambda x: None, x, key="y", workers=[b.address]) @@ -180,8 +171,6 @@ async def test_offload(c, s, a, b, monkeypatch): assert list(get_digests(b, {"offload", "serialize", "deserialize"})) == [ ("gather-dep", "offload", "seconds"), ("gather-dep", "deserialize", "seconds"), - ("execute", span_id(s), "y", "offload", "seconds"), - ("execute", span_id(s), "y", "deserialize", "seconds"), ("get-data", "offload", "seconds"), ("get-data", "serialize", "seconds"), ] @@ -364,7 +353,6 @@ def f(): await wait(c.submit(f, key="x")) assert list(get_digests(a)) == [ - ("execute", span_id(s), "x", "deserialize", "seconds"), ("execute", span_id(s), "x", "I/O", "seconds"), ("execute", span_id(s), "x", "thread-cpu", "seconds"), ("execute", span_id(s), "x", "thread-noncpu", "seconds"), @@ -387,7 +375,6 @@ async def f(): await wait(c.submit(f, key="x")) assert list(get_digests(a)) == [ - ("execute", span_id(s), "x", "deserialize", "seconds"), ("execute", span_id(s), "x", "I/O", "seconds"), ("execute", span_id(s), "x", "thread-noncpu", "seconds"), ("execute", span_id(s), "x", "other", "seconds"), @@ -431,7 +418,6 @@ def f(): a_metrics = get_digests(a) assert list(s_metrics) == [ - ("execute", "x", "deserialize", "seconds"), ("execute", "x", ("foo", 1), "seconds"), ("execute", "x", None, "custom"), ("execute", "x", "thread-cpu", "seconds"), @@ -536,7 +522,6 @@ async def test_send_metrics_to_scheduler(c, s, a, b): s_metrics = get_digests(s) expect_worker = [ - ("execute", None, "x", "deserialize", "seconds"), ("execute", None, "x", "thread-cpu", "seconds"), ("execute", None, "x", "thread-noncpu", "seconds"), ("execute", None, "x", "executor", "seconds"), @@ -583,7 +568,6 @@ async def test_no_spans_extension(c, s, a): s_metrics = get_digests(s) expect_worker = [ ("execute", None, "x", "failed", "seconds"), - ("execute", None, "y", "deserialize", "seconds"), ("execute", None, "y", "thread-cpu", "seconds"), ("execute", None, "y", "thread-noncpu", "seconds"), ("execute", None, "y", "executor", "seconds"), diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index b06ab6548e..6b21f849b3 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -57,7 +57,6 @@ RetryBusyWorkerEvent, RetryBusyWorkerLater, SecedeEvent, - SerializedTask, StateMachineEvent, TaskErredMsg, TaskState, @@ -375,28 +374,28 @@ def test_event_to_dict_without_annotations(): def test_computetask_to_dict(): """The potentially very large ComputeTaskEvent.run_spec is not stored in the log""" + + def f(arg): + pass + ev = ComputeTaskEvent( key="x", who_has={"y": ["w1"]}, nbytes={"y": 123}, priority=(0,), duration=123.45, - run_spec=None, + run_spec=(f, "arg", {}), resource_restrictions={}, actor=False, annotations={}, span_id=None, stimulus_id="test", - function=b"blob", - args=b"blob", - kwargs=None, run_id=5, ) - assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob") + assert ev.run_spec is not None ev2 = ev.to_loggable(handled=11.22) assert ev2.handled == 11.22 - assert ev2.run_spec == SerializedTask(task=None) - assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob") + assert ev2.run_spec is None d = recursive_to_dict(ev2) assert d == { "cls": "ComputeTaskEvent", @@ -404,7 +403,7 @@ def test_computetask_to_dict(): "who_has": {"y": ["w1"]}, "nbytes": {"y": 123}, "priority": [0], - "run_spec": [None, None, None, None], + "run_spec": None, "duration": 123.45, "resource_restrictions": {}, "actor": False, @@ -412,14 +411,11 @@ def test_computetask_to_dict(): "span_id": None, "stimulus_id": "test", "handled": 11.22, - "function": None, - "args": None, - "kwargs": None, "run_id": 5, } ev3 = StateMachineEvent.from_dict(d) assert isinstance(ev3, ComputeTaskEvent) - assert ev3.run_spec == SerializedTask(task=None) + assert ev3.run_spec is None assert ev3.priority == (0,) # List is automatically converted back to tuple @@ -431,15 +427,12 @@ def test_computetask_dummy(): nbytes={}, priority=(0,), duration=1.0, - run_spec=None, + run_spec=ComputeTaskEvent.dummy_runspec(), resource_restrictions={}, actor=False, annotations={}, span_id=None, stimulus_id="s", - function=None, - args=None, - kwargs=None, run_id=0, ) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index ee2300c7aa..8fc4ec4d38 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2253,10 +2253,6 @@ class BlockedExecute(Worker): method and then does not proceed, thus leaving the task in executing state indefinitely, until the test sets `block_execute`. - After that, the worker sets `in_deserialize_task` to simulate the moment when a - large run_spec is being deserialized in a separate thread. The worker will block - again until the test sets `block_deserialize_task`. - Finally, the worker sets `in_execute_exit` when execute() terminates, but before the worker state has processed its exit callback. The worker will block one last time until the test sets `block_execute_exit`. @@ -2287,8 +2283,6 @@ def f(in_task, block_task): def __init__(self, *args, **kwargs): self.in_execute = asyncio.Event() self.block_execute = asyncio.Event() - self.in_deserialize_task = asyncio.Event() - self.block_deserialize_task = asyncio.Event() self.in_execute_exit = asyncio.Event() self.block_execute_exit = asyncio.Event() @@ -2303,13 +2297,6 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: self.in_execute_exit.set() await self.block_execute_exit.wait() - async def _maybe_deserialize_task( - self, ts: WorkerTaskState - ) -> tuple[Callable, tuple, dict[str, Any]]: - self.in_deserialize_task.set() - await self.block_deserialize_task.wait() - return await super()._maybe_deserialize_task(ts) - @contextmanager def freeze_data_fetching(w: Worker, *, jump_start: bool = False) -> Iterator[None]: diff --git a/distributed/worker.py b/distributed/worker.py index 27dbc5166f..3af7117909 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -66,7 +66,6 @@ from distributed.comm import Comm, connect, get_address_host, parse_address from distributed.comm import resolve_address as comm_resolve_address from distributed.comm.addressing import address_from_user_args -from distributed.comm.utils import OFFLOAD_THRESHOLD from distributed.compatibility import PeriodicCallback from distributed.core import ( ConnectionPool, @@ -75,7 +74,6 @@ coerce_to_address, context_meter_to_server_digest, error_message, - loads_function, pingpong, ) from distributed.core import rpc as RPCType @@ -125,7 +123,6 @@ WorkerMemoryManager, ) from distributed.worker_state_machine import ( - NO_VALUE, AcquireReplicasEvent, BaseWorker, CancelComputeEvent, @@ -162,6 +159,7 @@ from distributed.client import Client from distributed.diagnostics.plugin import WorkerPlugin from distributed.nanny import Nanny + from distributed.scheduler import T_runspec P = ParamSpec("P") T = TypeVar("T") @@ -2226,24 +2224,6 @@ def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]: except Exception as ex: return {"status": "error", "exception": to_serialize(ex)} - async def _maybe_deserialize_task( - self, ts: TaskState - ) -> tuple[Callable, tuple, dict[str, Any]]: - assert ts.run_spec is not None - start = time() - # Offload deserializing large tasks - if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD: - function, args, kwargs = await offload(_deserialize, *ts.run_spec) - else: - function, args, kwargs = _deserialize(*ts.run_spec) - stop = time() - - if stop - start > 0.010: - ts.startstops.append( - {"action": "deserialize", "start": start, "stop": stop} - ) - return function, args, kwargs - @fail_hard async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: """Execute a task. Implements BaseWorker abstract method. @@ -2262,23 +2242,13 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: ts = self.state.tasks[key] run_id = ts.run_id - try: - function, args, kwargs = await self._maybe_deserialize_task(ts) - except Exception as exc: - logger.error("Could not deserialize task %s", key, exc_info=True) - return ExecuteFailureEvent.from_exception( - exc, - key=key, - run_id=run_id, - stimulus_id=f"run-spec-deserialize-failed-{time()}", - ) - try: if self.state.validate: assert not ts.waiting_for_data assert ts.state in ("executing", "cancelled", "resumed"), ts - assert ts.run_spec is not None + assert ts.run_spec is not None + function, args, kwargs = ts.run_spec args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) assert ts.annotations is not None @@ -2947,27 +2917,14 @@ async def get_data_from_worker( rpc.reuse(worker, comm) -job_counter = [0] - - -@context_meter.meter("deserialize") -def _deserialize(function=None, args=None, kwargs=None, task=NO_VALUE): - """Deserialize task inputs and regularize to func, args, kwargs""" - # Some objects require threadlocal state during deserialization, e.g. to - # detect the current worker - if function is not None: - function = loads_function(function) - if args and isinstance(args, bytes): - args = pickle.loads(args) - if kwargs and isinstance(kwargs, bytes): - kwargs = pickle.loads(kwargs) - - if task is not NO_VALUE: - assert not function and not args and not kwargs - function = execute_task - args = (task,) +def _normalize_task(task: Any) -> T_runspec: + if istask(task): + if task[0] is apply and not any(map(_maybe_complex, task[2:])): + return task[1], task[2], task[3] if len(task) == 4 else {} + elif not any(map(_maybe_complex, task[1:])): + return task[0], task[1:], {} - return function, args or (), kwargs or {} + return execute_task, (task,), {} def execute_task(task): @@ -3008,62 +2965,6 @@ def dumps_function(func) -> bytes: return result -def dumps_task(task): - """Serialize a dask task - - Returns a dict of bytestrings that can each be loaded with ``loads`` - - Examples - -------- - Either returns a task as a function, args, kwargs dict - - >>> from operator import add - >>> dumps_task((add, 1)) # doctest: +SKIP - {'function': b'\x80\x04\x95\x00\x8c\t_operator\x94\x8c\x03add\x94\x93\x94.' - 'args': b'\x80\x04\x95\x07\x00\x00\x00K\x01K\x02\x86\x94.'} - - Or as a single task blob if it can't easily decompose the result. This - happens either if the task is highly nested, or if it isn't a task at all - - >>> dumps_task(1) # doctest: +SKIP - {'task': b'\x80\x04\x95\x03\x00\x00\x00\x00\x00\x00\x00K\x01.'} - """ - if istask(task): - if task[0] is apply and not any(map(_maybe_complex, task[2:])): - d = {"function": dumps_function(task[1]), "args": warn_dumps(task[2])} - if len(task) == 4: - d["kwargs"] = warn_dumps(task[3]) - return d - elif not any(map(_maybe_complex, task[1:])): - return {"function": dumps_function(task[0]), "args": warn_dumps(task[1:])} - return to_serialize(task) - - -_warn_dumps_warned = [False] - - -def warn_dumps(obj, dumps=pickle.dumps, limit=1e6): - """Dump an object to bytes, warn if those bytes are large""" - b = dumps(obj) - if not _warn_dumps_warned[0] and len(b) > limit: - _warn_dumps_warned[0] = True - s = str(obj) - if len(s) > 70: - s = s[:50] + " ... " + s[-15:] - warnings.warn( - "Large object of size %s detected in task graph: \n" - " %s\n" - "Consider scattering large objects ahead of time\n" - "with client.scatter to reduce scheduler burden and \n" - "keep data on workers\n\n" - " future = client.submit(func, big_data) # bad\n\n" - " big_future = client.scatter(big_data) # good\n" - " future = client.submit(func, big_future) # good" - % (format_bytes(len(b)), s) - ) - return b - - def apply_function( function, args, diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 985db35d24..7d22d7075d 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -27,16 +27,7 @@ from dataclasses import dataclass, field from functools import lru_cache, partial, singledispatchmethod, wraps from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Literal, - NamedTuple, - TypedDict, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, Union, cast from tlz import peekn @@ -49,7 +40,7 @@ from distributed.core import ErrorMessage, error_message from distributed.metrics import DelayedMetricsLedger, monotonic, time from distributed.protocol import pickle -from distributed.protocol.serialize import Serialize +from distributed.protocol.serialize import Serialize, ToPickle from distributed.sizeof import safe_sizeof as sizeof from distributed.utils import recursive_to_dict @@ -64,6 +55,7 @@ # Circular imports from distributed.diagnostics.plugin import WorkerPlugin + from distributed.scheduler import T_runspec from distributed.worker import Worker # Not to be confused with distributed.scheduler.TaskStateState @@ -112,19 +104,6 @@ RUN_ID_SENTINEL = -1 -class SerializedTask(NamedTuple): - """Info from distributed.scheduler.TaskState.run_spec - Input to distributed.worker._deserialize - - (function, args kwargs) and task are mutually exclusive - """ - - function: bytes | None = None - args: bytes | tuple | list | None = None - kwargs: bytes | dict[str, Any] | None = None - task: object = NO_VALUE - - class StartStop(TypedDict): action: Literal["compute", "transfer", "disk-read", "disk-write", "deserialize"] start: float @@ -239,11 +218,11 @@ class TaskState: prefix: str = field(init=False) #: Task run ID. run_id: int = RUN_ID_SENTINEL - #: A named tuple containing the ``function``, ``args``, ``kwargs`` and ``task`` + #: A tuple containing the ``function``, ``args``, ``kwargs`` and ``task`` #: associated with this `TaskState` instance. This defaults to ``None`` and can #: remain empty if it is a dependency that this worker will receive from another #: worker. - run_spec: SerializedTask | None = None + run_spec: T_runspec | None = None #: The data needed by this key to run dependencies: set[TaskState] = field(default_factory=set) @@ -763,10 +742,7 @@ class ComputeTaskEvent(StateMachineEvent): nbytes: dict[str, int] priority: tuple[int, ...] duration: float - run_spec: SerializedTask | None - function: bytes | None - args: bytes | tuple | list | None | None - kwargs: bytes | dict[str, Any] | None + run_spec: T_runspec | None resource_restrictions: dict[str, float] actor: bool annotations: dict @@ -778,24 +754,17 @@ def __post_init__(self) -> None: # Fixes after msgpack decode if isinstance(self.priority, list): # type: ignore[unreachable] self.priority = tuple(self.priority) # type: ignore[unreachable] - - if self.function is not None: - assert self.run_spec is None - self.run_spec = SerializedTask( - function=self.function, args=self.args, kwargs=self.kwargs - ) - elif not isinstance(self.run_spec, SerializedTask): - self.run_spec = SerializedTask(task=self.run_spec) + if isinstance(self.run_spec, ToPickle): + # FIXME Sometimes the protocol is not unpacking this + # E.g. distributed/tests/test_client.py::test_async_with + self.run_spec = self.run_spec.data # type: ignore[unreachable] def _to_dict(self, *, exclude: Container[str] = ()) -> dict: return StateMachineEvent._to_dict(self._clean(), exclude=exclude) def _clean(self) -> StateMachineEvent: out = copy(self) - out.function = None - out.kwargs = None - out.args = None - out.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None) + out.run_spec = None return out def to_loggable(self, *, handled: float) -> StateMachineEvent: @@ -804,7 +773,15 @@ def to_loggable(self, *, handled: float) -> StateMachineEvent: return out def _after_from_dict(self) -> None: - self.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None) + self.run_spec = None + + @classmethod + def _f(cls) -> None: + return # pragma: nocover + + @classmethod + def dummy_runspec(cls) -> tuple[Callable, tuple, dict]: + return (cls._f, (), {}) @staticmethod def dummy( @@ -830,10 +807,7 @@ def dummy( nbytes=nbytes or {k: 1 for k in who_has or ()}, priority=priority, duration=duration, - run_spec=None, - function=None, - args=None, - kwargs=None, + run_spec=ComputeTaskEvent.dummy_runspec(), resource_restrictions=resource_restrictions or {}, actor=actor, annotations=annotations or {},