Skip to content

Commit

Permalink
Define as_numba_cuda_array
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Apr 21, 2020
1 parent 98d82dd commit 5bbd53e
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ucp = None
host_array = None
device_array = None
as_numba_device_array = None


def synchronize_stream(stream=0):
Expand All @@ -47,7 +48,7 @@ def synchronize_stream(stream=0):


def init_once():
global ucp, host_array, device_array
global ucp, host_array, device_array, as_numba_device_array
if ucp is not None:
return

Expand Down Expand Up @@ -100,6 +101,16 @@ def device_array(n):
"In order to send/recv CUDA arrays, Numba or RMM is required"
)

# Find the function, `as_numba_device_array()`
try:
import numba.cuda

as_numba_device_array = lambda a: numba.cuda.as_cuda_array(a)
except ImportError:

def as_numba_device_array(n):
raise RuntimeError("In order to send/recv CUDA arrays, Numba is required")

pool_size_str = dask.config.get("rmm.pool-size")
if pool_size_str is not None:
pool_size = parse_bytes(pool_size_str)
Expand Down

0 comments on commit 5bbd53e

Please sign in to comment.