From 5bbd53e18d9b56705bdf85ed6bda16a74e1f9c38 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 20 Apr 2020 18:08:06 -0700 Subject: [PATCH] Define `as_numba_cuda_array` --- distributed/comm/ucx.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 6979cdd9342..8d30287cbf5 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -35,6 +35,7 @@ ucp = None host_array = None device_array = None +as_numba_device_array = None def synchronize_stream(stream=0): @@ -47,7 +48,7 @@ def synchronize_stream(stream=0): def init_once(): - global ucp, host_array, device_array + global ucp, host_array, device_array, as_numba_device_array if ucp is not None: return @@ -100,6 +101,16 @@ def device_array(n): "In order to send/recv CUDA arrays, Numba or RMM is required" ) + # Find the function, `as_numba_device_array()` + try: + import numba.cuda + + as_numba_device_array = lambda a: numba.cuda.as_cuda_array(a) + except ImportError: + + def as_numba_device_array(n): + raise RuntimeError("In order to send/recv CUDA arrays, Numba is required") + pool_size_str = dask.config.get("rmm.pool-size") if pool_size_str is not None: pool_size = parse_bytes(pool_size_str)