From cbd27264c450940ae75472a6f45f1ca0242bfd65 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 26 Oct 2020 11:03:09 -0500 Subject: [PATCH] Use thread-local to track CUDA device in JNI (#6597) * Use thread-local to track CUDA device in JNI * changelog --- CHANGELOG.md | 3 ++- java/src/main/native/src/CudaJni.cpp | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 60a39724266..45ffed754f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,8 @@ - PR #6581 Add JNI API to check if PTDS is enabled ## Improvements -- PR #6430 Add struct type support to `to_arrow` and `from_arrow` +- PR #6430 Add struct type support to `to_arrow` and `from_arrow` - PR #6384 Add CSV fuzz tests with varying function parameters - PR #6385 Add JSON fuzz tests with varying function parameters - PR #6398 Remove function constructor macros in parquet reader @@ -29,6 +29,7 @@ - PR #6555 Adapt JNI build to libcudf composition of multiple libraries - PR #6564 Load JNI library dependencies with a thread pool - PR #6573 Create `cudf::detail::byte_cast` for `cudf::byte_cast` +- PR #6597 Use thread-local to track CUDA device in JNI ## Bug Fixes diff --git a/java/src/main/native/src/CudaJni.cpp b/java/src/main/native/src/CudaJni.cpp index 930014d5e06..b41fae21a74 100644 --- a/java/src/main/native/src/CudaJni.cpp +++ b/java/src/main/native/src/CudaJni.cpp @@ -21,6 +21,8 @@ namespace { /** The CUDA device that should be used by all threads using cudf */ int Cudf_device{cudaInvalidDeviceId}; +thread_local int Thread_device = cudaInvalidDeviceId; + } // anonymous namespace namespace cudf { @@ -37,12 +39,10 @@ void set_cudf_device(int device) { */ void auto_set_device(JNIEnv *env) { if (Cudf_device != cudaInvalidDeviceId) { - int device; - cudaError_t cuda_status = cudaGetDevice(&device); - jni_cuda_check(env, cuda_status); - if (device != Cudf_device) { - cuda_status = cudaSetDevice(Cudf_device); + if (Thread_device != Cudf_device) { + cudaError_t cuda_status = cudaSetDevice(Cudf_device); jni_cuda_check(env, cuda_status); + Thread_device = Cudf_device; } } }