Skip to content

Commit

Permalink
Define as_cuda_array
Browse files Browse the repository at this point in the history
Provides a function to let us coerce our underlying
`__cuda_array_interface__` objects into something that behaves more like
an array. Prefers CuPy if possible, but will fallback to Numba if its
not available.
  • Loading branch information
jakirkham committed Apr 21, 2020
1 parent 6db09f3 commit 070a19f
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ucp = None
host_array = None
device_array = None
as_device_array = None


def synchronize_stream(stream=0):
Expand All @@ -47,7 +48,7 @@ def synchronize_stream(stream=0):


def init_once():
global ucp, host_array, device_array
global ucp, host_array, device_array, as_device_array
if ucp is not None:
return

Expand Down Expand Up @@ -100,6 +101,23 @@ def device_array(n):
"In order to send/recv CUDA arrays, Numba or RMM is required"
)

# Find the function, `as_device_array()`
try:
import cupy

as_device_array = lambda a: cupy.asarray(a)
except ImportError:
try:
import numba.cuda

as_device_array = lambda a: numba.cuda.as_cuda_array(a)
except ImportError:

def as_device_array(n):
raise RuntimeError(
"In order to send/recv CUDA arrays, CuPy or Numba is required"
)

pool_size_str = dask.config.get("rmm.pool-size")
if pool_size_str is not None:
pool_size = parse_bytes(pool_size_str)
Expand Down

0 comments on commit 070a19f

Please sign in to comment.