Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update spark_resource_adaptor for rmm's device_async_resource_ref #2064

Draft
wants to merge 1 commit into
base: branch-24.12
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions src/main/cpp/src/SparkResourceAdaptorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cudf_jni_apis.hpp>
#include <pthread.h>
Expand Down Expand Up @@ -384,10 +384,10 @@ class full_thread_state {
* mitigation we might want to do to avoid killing a task with an out of
* memory error.
*/
class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
class spark_resource_adaptor final {
public:
spark_resource_adaptor(JNIEnv* env,
rmm::mr::device_memory_resource* mr,
rmm::device_async_resource_ref mr,
std::shared_ptr<spdlog::logger>& logger,
bool const is_log_enabled)
: resource{mr}, logger{logger}, is_log_enabled{is_log_enabled}
Expand All @@ -399,7 +399,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
logger->set_pattern("%H:%M:%S.%f,%v");
}

rmm::mr::device_memory_resource* get_wrapped_resource() { return resource; }
rmm::device_async_resource_ref get_wrapped_resource() { return resource; }

/**
* Update the internal state so that a specific thread is dedicated to a task.
Expand Down Expand Up @@ -870,7 +870,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
}

private:
rmm::mr::device_memory_resource* const resource;
rmm::device_async_resource_ref resource;
std::shared_ptr<spdlog::logger> logger; ///< spdlog logger object
bool const is_log_enabled;

Expand Down Expand Up @@ -1728,13 +1728,46 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
return ret;
}

void* do_allocate(std::size_t const num_bytes, rmm::cuda_stream_view stream) override
/**
* Synchronous allocation entry point, present only to satisfy the
* cuda::mr::resource concept.
*
* Synchronous memory allocations are not supported: both arguments are
* ignored and nullptr is always returned, so the result must never be
* treated as usable memory.
*
* NOTE(review): the unnamed parameters are presumably (num_bytes, alignment),
* matching the leading parameters of allocate_async — confirm against the
* concept definition.
*
* @return always nullptr
*/
void* allocate(std::size_t, std::size_t) { return nullptr; }

/**
* Synchronous deallocation entry point, present only to satisfy the
* cuda::mr::resource concept.
*
* Synchronous memory deallocations are not supported: all arguments are
* ignored and the call is a no-op.
*
* NOTE(review): the unnamed parameters are presumably
* (pointer, num_bytes, alignment) — confirm against the concept definition.
*/
void deallocate(void*, std::size_t, std::size_t) {}

/**
* Equality comparison required to satisfy the cuda::mr::resource concept.
*
* Two adaptors compare equal when their wrapped resources compare equal and
* their jvm members compare equal. No other member of the adaptor takes part
* in the comparison.
*
* @param lhs left-hand adaptor
* @param rhs right-hand adaptor
* @return true if both the wrapped resource and the jvm member match
*/
friend bool operator==(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
{
// The jvm comparison is only evaluated when the wrapped resources are equal.
return (lhs.resource == rhs.resource) && (lhs.jvm == rhs.jvm);
}

/**
* Inequality comparison required to satisfy the cuda::mr::resource concept.
*
* Defined as the exact negation of operator== so the two can never disagree.
*
* @param lhs left-hand adaptor
* @param rhs right-hand adaptor
* @return true if operator== reports the adaptors as not equal
*/
friend bool operator!=(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
{
return !(lhs == rhs);
}

/**
* Async allocation method required to satisfy cuda::mr::async_resource concept
*/
void* allocate_async(std::size_t const num_bytes,
std::size_t const alignment,
rmm::cuda_stream_view stream)
{
auto const tid = static_cast<long>(pthread_self());
while (true) {
bool const likely_spill = pre_alloc(tid);
try {
void* ret = resource->allocate(num_bytes, stream);
void* ret = resource.allocate_async(num_bytes, alignment, stream);
post_alloc_success(tid, likely_spill);
return ret;
} catch (rmm::out_of_memory const& e) {
Expand Down Expand Up @@ -1787,9 +1820,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
wake_next_highest_priority_blocked(lock, true, is_for_cpu);
}

void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) override
/**
* Async deallocation method required to satisfy cuda::mr::async_resource concept
*/
void deallocate_async(void* p,
std::size_t size,
std::size_t const alignment,
rmm::cuda_stream_view stream)
{
resource->deallocate(p, size, stream);
resource.deallocate_async(p, size, alignment, stream);
// deallocate success
if (size > 0) {
std::unique_lock<std::mutex> lock(state_mutex);
Expand Down Expand Up @@ -1818,7 +1857,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
JNI_NULL_CHECK(env, child, "child is null", 0);
try {
cudf::jni::auto_set_device(env);
auto wrapped = reinterpret_cast<rmm::mr::device_memory_resource*>(child);
auto wrapped = reinterpret_cast<rmm::device_async_resource_ref*>(child);
cudf::jni::native_jstring nlogloc(env, log_loc);
std::shared_ptr<spdlog::logger> logger;
bool is_log_enabled;
Expand All @@ -1837,7 +1876,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
}
}

auto ret = new spark_resource_adaptor(env, wrapped, logger, is_log_enabled);
auto ret = new spark_resource_adaptor(env, *wrapped, logger, is_log_enabled);
return cudf::jni::ptr_as_jlong(ret);
}
CATCH_STD(env, 0)
Expand Down
Loading