Skip to content

Commit

Permalink
Adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed Oct 10, 2024
1 parent 18685db commit d746d50
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 11 deletions.
2 changes: 2 additions & 0 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def loads(x, *, buffers=()):
return pickle.loads(x, buffers=buffers)
else:
return pickle.loads(x)
except EOFError:
raise
except Exception as e:
logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
raise pickle.UnpicklingError("Failed to deserialize") from e
8 changes: 4 additions & 4 deletions distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from distributed.protocol import deserialize, serialize
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads
from distributed.protocol.serialize import dask_deserialize, dask_serialize
from distributed.utils_test import popen, save_sys_modules
from distributed.utils_test import popen, raises_with_cause, save_sys_modules


class MemoryviewHolder:
Expand Down Expand Up @@ -231,7 +231,7 @@ def _deserialize_nopickle(header, frames):


def test_allow_pickle_if_registered_in_dask_serialize():
with pytest.raises(TypeError, match="nope"):
with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"):
dumps(NoPickle())

dask_serialize.register(NoPickle)(_serialize_nopickle)
Expand All @@ -251,9 +251,9 @@ def __init__(self) -> None:

def test_nopickle_nested():
nested_obj = [NoPickle()]
with pytest.raises(TypeError, match="nope"):
with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"):
dumps(nested_obj)
with pytest.raises(TypeError, match="nope"):
with raises_with_cause(pickle.PicklingError, "serialize", TypeError, "nope"):
dumps(NestedNoPickle())

dask_serialize.register(NoPickle)(_serialize_nopickle)
Expand Down
18 changes: 16 additions & 2 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5100,7 +5100,14 @@ def __setstate__(self, state):
future = c.submit(identity, Foo())
await wait(future)
assert future.status == "error"
with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"):
with raises_with_cause(
RuntimeError,
"deserialization",
pickle.UnpicklingError,
"deserialize",
MyException,
"hello",
):
await future

futures = c.map(inc, range(10))
Expand All @@ -5125,7 +5132,14 @@ def __call__(self, *args):
future = c.submit(Foo(), 1)
await wait(future)
assert future.status == "error"
with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"):
with raises_with_cause(
RuntimeError,
"deserialization",
pickle.UnpicklingError,
"deserialize",
MyException,
"hello",
):
await future

futures = c.map(inc, range(10))
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_spillbuffer_fail_to_serialize(tmp_path):
with pytest.raises(TypeError, match="Failed to pickle 'a'") as e:
with captured_logger("distributed.spill") as logs_bad_key:
buf["a"] = a
assert isinstance(e.value.__cause__.__cause__, MyError)
assert isinstance(e.value.__cause__.__cause__.__cause__, MyError)

# spill.py must remain silent because we're already logging in worker.py
assert not logs_bad_key.getvalue()
Expand Down
9 changes: 6 additions & 3 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import itertools
import logging
import os
import pickle
import random
import sys
import tempfile
Expand Down Expand Up @@ -38,6 +39,7 @@
get_client,
get_worker,
profile,
protocol,
wait,
)
from distributed.comm.registry import backends
Expand All @@ -46,7 +48,6 @@
from distributed.core import CommClosedError, Status, rpc
from distributed.diagnostics.plugin import ForwardOutput
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.scheduler import KilledWorker, Scheduler
from distributed.utils import get_mp_context, wait_for
from distributed.utils_test import (
Expand Down Expand Up @@ -509,13 +510,15 @@ async def test_plugin_internal_exception():
with raises_with_cause(
RuntimeError,
"Worker failed to start",
pickle.UnpicklingError,
"deserialize",
UnicodeDecodeError,
match_cause="codec can't decode",
"codec can't decode",
):
async with Worker(
s.address,
plugins={
b"corrupting pickle" + pickle.dumps(lambda: None),
b"corrupting pickle" + protocol.pickle.dumps(lambda: None),
},
) as w:
pass
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def test_fail_to_pickle_execute_1(c, s, a, b):

with pytest.raises(TypeError, match="Failed to pickle 'x'") as e:
await x
assert isinstance(e.value.__cause__.__cause__, CustomError)
assert isinstance(e.value.__cause__.__cause__.__cause__, CustomError)

await assert_basic_futures(c)

Expand Down

0 comments on commit d746d50

Please sign in to comment.