Skip to content

Commit

Permalink
Check direct_to_workers before using get_worker in Client (#2656)
Browse files Browse the repository at this point in the history
Otherwise we would ignore direct_to_workers=True when there wasn't a
local worker
  • Loading branch information
mrocklin authored May 5, 2019
1 parent d42173b commit 09b959a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
10 changes: 5 additions & 5 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def __init__(
serializers=None,
deserializers=None,
extensions=DEFAULT_EXTENSIONS,
direct_to_workers=False,
direct_to_workers=None,
**kwargs
):
if timeout == no_default:
Expand Down Expand Up @@ -1607,6 +1607,8 @@ def _gather(self, futures, errors="raise", direct=None, local_worker=None):
bad_data = dict()
data = {}

if direct is None:
direct = self.direct_to_workers
if direct is None:
try:
w = get_worker()
Expand All @@ -1615,8 +1617,6 @@ def _gather(self, futures, errors="raise", direct=None, local_worker=None):
else:
if w.scheduler.address == self.scheduler.address:
direct = True
if direct is None:
direct = self.direct_to_workers

@gen.coroutine
def wait(k):
Expand Down Expand Up @@ -1866,6 +1866,8 @@ def _scatter(

types = valmap(type, data)

if direct is None:
direct = self.direct_to_workers
if direct is None:
try:
w = get_worker()
Expand All @@ -1874,8 +1876,6 @@ def _scatter(
else:
if w.scheduler.address == self.scheduler.address:
direct = True
if direct is None:
direct = self.direct_to_workers

if local_worker: # running within task
local_worker.update_data(data=data, report=False)
Expand Down
8 changes: 8 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5707,5 +5707,13 @@ def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b):
assert result.equals(df.astype("f8"))


def test_direct_to_workers(s, loop):
with Client(s["address"], loop=loop, direct_to_workers=True) as client:
future = client.scatter(1)
future.result()
resp = client.run_on_scheduler(lambda dask_scheduler: dask_scheduler.events)
assert "gather" not in str(resp)


if sys.version_info >= (3, 5):
from distributed.tests.py3_test_client import * # noqa F401

0 comments on commit 09b959a

Please sign in to comment.