diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index b68cc16308..3e7e3ff8c7 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #include #include @@ -384,10 +384,10 @@ class full_thread_state { * mitigation we might want to do to avoid killing a task with an out of * memory error. */ -class spark_resource_adaptor final : public rmm::mr::device_memory_resource { +class spark_resource_adaptor final { public: spark_resource_adaptor(JNIEnv* env, - rmm::mr::device_memory_resource* mr, + rmm::device_async_resource_ref mr, std::shared_ptr& logger, bool const is_log_enabled) : resource{mr}, logger{logger}, is_log_enabled{is_log_enabled} @@ -399,7 +399,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { logger->set_pattern("%H:%M:%S.%f,%v"); } - rmm::mr::device_memory_resource* get_wrapped_resource() { return resource; } + rmm::device_async_resource_ref get_wrapped_resource() { return resource; } /** * Update the internal state so that a specific thread is dedicated to a task. @@ -870,7 +870,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { } private: - rmm::mr::device_memory_resource* const resource; + rmm::device_async_resource_ref resource; std::shared_ptr logger; ///< spdlog logger object bool const is_log_enabled; @@ -1728,13 +1728,46 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { return ret; } - void* do_allocate(std::size_t const num_bytes, rmm::cuda_stream_view stream) override + /** + * Sync allocation method required to satisfy cuda::mr::resource concept + * Synchronous memory allocations are not supported + */ + void* allocate(std::size_t, std::size_t) { return nullptr; } + + /** + * Sync deallocation method required to satisfy cuda::mr::resource concept + * Asynchronous memory allocations are not supported + */ + void deallocate(void*, std::size_t, std::size_t) {} + + /** + * Equality comparison method required to satisfy cuda::mr::resource concept + */ + friend bool operator==(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs) + { + return (lhs.resource == rhs.resource) && (lhs.jvm == rhs.jvm); + } + + /** + * Equality comparison method required to satisfy cuda::mr::resource concept + */ + friend bool operator!=(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs) + { + return !(lhs == rhs); + } + + /** + * Async allocation method required to satisfy cuda::mr::async_resource concept + */ + void* allocate_async(std::size_t const num_bytes, + std::size_t const alignment, + rmm::cuda_stream_view stream) { auto const tid = static_cast(pthread_self()); while (true) { bool const likely_spill = pre_alloc(tid); try { - void* ret = resource->allocate(num_bytes, stream); + void* ret = resource.allocate_async(num_bytes, alignment, stream); post_alloc_success(tid, likely_spill); return ret; } catch (rmm::out_of_memory const& e) { @@ -1787,9 +1820,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { wake_next_highest_priority_blocked(lock, true, is_for_cpu); } - void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) override + /** + * Async deallocation method required to satisfy cuda::mr::async_resource concept + */ + void deallocate_async(void* p, + std::size_t size, + std::size_t const alignment, + rmm::cuda_stream_view stream) { - resource->deallocate(p, size, stream); + resource.deallocate_async(p, size, alignment, stream); // deallocate success if (size > 0) { std::unique_lock lock(state_mutex); @@ -1818,7 +1857,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - auto wrapped = reinterpret_cast(child); + auto wrapped = reinterpret_cast(child); cudf::jni::native_jstring nlogloc(env, log_loc); std::shared_ptr logger; bool is_log_enabled; @@ -1837,7 +1876,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr } } - auto ret = new spark_resource_adaptor(env, wrapped, logger, is_log_enabled); + auto ret = new spark_resource_adaptor(env, *wrapped, logger, is_log_enabled); return cudf::jni::ptr_as_jlong(ret); } CATCH_STD(env, 0)