Skip to content

Commit

Permalink
Update all UCX tests to use asyncio marker (#5484)
Browse files Browse the repository at this point in the history
* Update all UCX tests to use asyncio marker

By marking all UCX tests as asyncio we ensure that they all use the
fixture that properly releases UCX resources once the event loop is
closed.

* Mark test_ucx_specific as asyncio

* Run UCX stress test in multiprocess mode
  • Loading branch information
pentschev authored Nov 2, 2021
1 parent 8cc4284 commit 7649596
Showing 1 changed file with 64 additions and 64 deletions.
128 changes: 64 additions & 64 deletions distributed/comm/tests/test_ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from distributed.comm.registry import backends, get_backend
from distributed.deploy.local import LocalCluster
from distributed.protocol import to_serialize
from distributed.utils_test import gen_cluster, gen_test, inc
from distributed.utils_test import inc

try:
HOST = ucp.get_address()
Expand Down Expand Up @@ -89,7 +89,8 @@ async def test_comm_objs():
assert comm.peer_address == serv_comm.local_address


def test_ucx_specific():
@pytest.mark.asyncio
async def test_ucx_specific():
"""
Test concrete UCX API.
"""
Expand All @@ -98,52 +99,48 @@ def test_ucx_specific():
# 2. Use dict in read / write, put seralization there.
# 3. Test peer_address
# 4. Test cleanup
async def f():
address = f"ucx://{HOST}:{0}"

async def handle_comm(comm):
msg = await comm.read()
msg["op"] = "pong"
await comm.write(msg)
await comm.read()
assert comm.closed() is False
await comm.close()
assert comm.closed

listener = await ucx.UCXListener(address, handle_comm)
host, port = listener.get_host_port()
assert host.count(".") == 3
assert port > 0

l = []

async def client_communicate(key, delay=0):
addr = "%s:%d" % (host, port)
comm = await connect(listener.contact_address)
# TODO: peer_address
# assert comm.peer_address == 'ucx://' + addr
assert comm.extra_info == {}
msg = {"op": "ping", "data": key}
await comm.write(msg)
if delay:
await asyncio.sleep(delay)
msg = await comm.read()
assert msg == {"op": "pong", "data": key}
await comm.write({"op": "client closed"})
l.append(key)
return comm

comm = await client_communicate(key=1234, delay=0.5)

# Many clients at once
N = 2
futures = [client_communicate(key=i, delay=0.05) for i in range(N)]
await asyncio.gather(*futures)
assert set(l) == {1234} | set(range(N))

listener.stop()

asyncio.run(f())
address = f"ucx://{HOST}:{0}"

async def handle_comm(comm):
msg = await comm.read()
msg["op"] = "pong"
await comm.write(msg)
await comm.read()
await comm.close()
assert comm.closed() is True

listener = await ucx.UCXListener(address, handle_comm)
host, port = listener.get_host_port()
assert host.count(".") == 3
assert port > 0

l = []

async def client_communicate(key, delay=0):
addr = "%s:%d" % (host, port)
comm = await connect(listener.contact_address)
# TODO: peer_address
# assert comm.peer_address == 'ucx://' + addr
assert comm.extra_info == {}
msg = {"op": "ping", "data": key}
await comm.write(msg)
if delay:
await asyncio.sleep(delay)
msg = await comm.read()
assert msg == {"op": "pong", "data": key}
await comm.write({"op": "client closed"})
l.append(key)
return comm

comm = await client_communicate(key=1234, delay=0.5)

# Many clients at once
N = 2
futures = [client_communicate(key=i, delay=0.05) for i in range(N)]
await asyncio.gather(*futures)
assert set(l) == {1234} | set(range(N))

listener.stop()


@pytest.mark.asyncio
Expand All @@ -169,7 +166,7 @@ async def test_ping_pong_data():
await serv_com.close()


@gen_test()
@pytest.mark.asyncio
async def test_ucx_deserialize():
# Note we see this error on some systems with this test:
# `socket.gaierror: [Errno -5] No address associated with hostname`
Expand Down Expand Up @@ -256,7 +253,7 @@ async def test_large_cupy(n, cleanup):
await serv_com.close()


@gen_test()
@pytest.mark.asyncio
async def test_ping_pong_numba():
np = pytest.importorskip("numpy")
numba = pytest.importorskip("numba")
Expand All @@ -274,8 +271,8 @@ async def test_ping_pong_numba():
assert result["op"] == "ping"


@pytest.mark.parametrize("processes", [True, False])
@pytest.mark.asyncio
@pytest.mark.parametrize("processes", [True, False])
async def test_ucx_localcluster(processes, cleanup):
async with LocalCluster(
protocol="ucx",
Expand All @@ -296,7 +293,7 @@ async def test_ucx_localcluster(processes, cleanup):


@pytest.mark.slow
@gen_test(timeout=240)
@pytest.mark.asyncio
async def test_stress():
da = pytest.importorskip("dask.array")

Expand All @@ -306,7 +303,6 @@ async def test_stress():
protocol="ucx",
dashboard_address=":0",
asynchronous=True,
processes=False,
host=HOST,
) as cluster:
async with Client(cluster, asynchronous=True):
Expand All @@ -322,21 +318,25 @@ async def test_stress():
await wait(x)


@gen_cluster(client=True, scheduler_kwargs={"protocol": "ucx"})
async def test_simple(c, s, a, b):
assert s.address.startswith("ucx://")
assert await c.submit(lambda x: x + 1, 10) == 11
@pytest.mark.asyncio
async def test_simple():
async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
async with Client(cluster, asynchronous=True) as client:
assert cluster.scheduler_address.startswith("ucx://")
assert await client.submit(lambda x: x + 1, 10) == 11


@gen_cluster(client=True, scheduler_kwargs={"protocol": "ucx"})
async def test_transpose(c, s, a, b):
@pytest.mark.asyncio
async def test_transpose():
da = pytest.importorskip("dask.array")

assert s.address.startswith("ucx://")
x = da.ones((10000, 10000), chunks=(1000, 1000)).persist()
await x
y = (x + x.T).sum()
await y
async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
async with Client(cluster, asynchronous=True):
assert cluster.scheduler_address.startswith("ucx://")
x = da.ones((10000, 10000), chunks=(1000, 1000)).persist()
await x
y = (x + x.T).sum()
await y


@pytest.mark.asyncio
Expand Down

0 comments on commit 7649596

Please sign in to comment.