Skip to content

Commit

Permalink
Detect recursion in pre_allocate and skip state transitions (#1122)
Browse files Browse the repository at this point in the history
* Detect recursion in pre_allocate and skip state transitions

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina authored May 9, 2023
1 parent 6383147 commit 7bc86fd
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 21 deletions.
72 changes: 53 additions & 19 deletions src/main/cpp/src/SparkResourceAdaptorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,12 +837,29 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
* Called prior to processing an alloc attempt. This will throw any injected exception and
* wait until the thread is ready to actually do/retry the allocation. That blocking API may
* throw other exceptions if rolling back or splitting the input is considered needed.
*
* @return true if the call finds our thread in an ALLOC state, meaning that we recursively
* entered the state machine. The only known case is GPU memory required for setup in
* cuDF for a spill operation.
*/
void pre_alloc(long thread_id) {
bool pre_alloc(long thread_id) {
std::unique_lock<std::mutex> lock(state_mutex);

auto thread = threads.find(thread_id);
if (thread != threads.end()) {
switch(thread->second.state) {
// If the thread is in one of the ALLOC or ALLOC_FREE states, we have detected a loop
// likely due to spill setup required in cuDF. We will treat this allocation differently
// and skip transitions.
case TASK_ALLOC:
case SHUFFLE_ALLOC:
case TASK_ALLOC_FREE:
case SHUFFLE_ALLOC_FREE:
return true;

default: break;
}

if (thread->second.retry_oom_injected > 0) {
thread->second.retry_oom_injected--;
thread->second.num_times_retry_throw++;
Expand Down Expand Up @@ -870,8 +887,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
case SHUFFLE_RUNNING:
transition(thread->second, thread_state::SHUFFLE_ALLOC);
break;
// TODO I don't think there are other states that we need to handle, but
// this needs more testing.

// TODO I don't think there are other states that we need to handle, but
// this needs more testing.
default: {
std::stringstream ss;
ss << "thread " << thread_id << " in unexpected state pre alloc "
Expand All @@ -881,6 +899,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
}
}
}
return false;
}

/**
Expand All @@ -889,12 +908,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
* GPU memory. I don't want to mark it as nothrow, because we can throw an
* exception on an internal error, and I would rather see that we got the internal
* error and leak something instead of getting a segfault.
*
* `likely_spill` if this allocation should be treated differently, because
* we detected recursion while handling a prior allocation in this thread.
*/
void post_alloc_success(long thread_id) {
void post_alloc_success(long thread_id, bool likely_spill) {
std::unique_lock<std::mutex> lock(state_mutex);
// pre allocate checks
auto thread = threads.find(thread_id);
if (thread != threads.end()) {
if (!likely_spill && thread != threads.end()) {
switch (thread->second.state) {
case TASK_ALLOC:
// fall through
Expand Down Expand Up @@ -1087,12 +1109,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
* typically happen after this has run, and we loop around to retry the alloc
* if the state says we should.
*/
bool post_alloc_failed(long thread_id, bool is_oom) {
bool post_alloc_failed(long thread_id, bool is_oom, bool likely_spill) {
std::unique_lock<std::mutex> lock(state_mutex);
auto thread = threads.find(thread_id);
// only retry if this was due to an out of memory exception.
bool ret = true;
if (thread != threads.end()) {
if (!likely_spill && thread != threads.end()) {
switch (thread->second.state) {
case TASK_ALLOC_FREE: transition(thread->second, thread_state::TASK_RUNNING); break;
case TASK_ALLOC:
Expand Down Expand Up @@ -1130,19 +1152,18 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
void *do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override {
auto tid = static_cast<long>(pthread_self());
while (true) {
pre_alloc(tid);
bool likely_spill = pre_alloc(tid);
try {
void *ret = resource->allocate(num_bytes, stream);
post_alloc_success(tid);
post_alloc_success(tid, likely_spill);
return ret;
} catch (const std::bad_alloc &e) {
if (!post_alloc_failed(tid, true)) {
if (!post_alloc_failed(tid, true, likely_spill)) {
throw;
}
} catch (const std::exception &e) {
if (!post_alloc_failed(tid, false)) {
throw;
}
post_alloc_failed(tid, false, likely_spill);
throw;
}
}
// we should never reach this point, but just in case
Expand All @@ -1163,10 +1184,25 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {

std::unique_lock<std::mutex> lock(state_mutex);
for (auto thread = threads.begin(); thread != threads.end(); thread++) {
switch (thread->second.state) {
case TASK_ALLOC: transition(thread->second, thread_state::TASK_ALLOC_FREE); break;
case SHUFFLE_ALLOC: transition(thread->second, thread_state::SHUFFLE_ALLOC_FREE); break;
default: break;
// Only update state for _other_ threads. We update only other threads, for the case
// where we are handling a free from the recursive case: when an allocation/free
// happened while handling an allocation failure in onAllocFailed.
//
// If we moved all threads to *_ALLOC_FREE, after we exit the recursive state and
// are back handling the original allocation failure, we are left with a thread
// in a state that won't be retried in `post_alloc_failed`.
//
// By not changing our thread's state to TASK_ALLOC_FREE, we keep the state
// the same, but we still let other threads know that there was a free and they should
// handle accordingly.
if (thread->second.thread_id != tid) {
switch (thread->second.state) {
case TASK_ALLOC:
transition(thread->second, thread_state::TASK_ALLOC_FREE); break;
case SHUFFLE_ALLOC:
transition(thread->second, thread_state::SHUFFLE_ALLOC_FREE); break;
default: break;
}
}
}
wake_next_highest_priority_regular_blocked(lock);
Expand Down Expand Up @@ -1359,7 +1395,6 @@ JNIEXPORT jint JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_get
CATCH_STD(env, 0)
}


JNIEXPORT jint JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_getAndResetRetryThrowInternal(
JNIEnv *env, jclass, jlong ptr, jlong task_id) {
JNI_NULL_CHECK(env, ptr, "resource_adaptor is null", 0);
Expand Down Expand Up @@ -1392,5 +1427,4 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_ge
}
CATCH_STD(env, 0)
}

}
108 changes: 106 additions & 2 deletions src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,8 @@
import ai.rapids.cudf.RmmEventHandler;
import ai.rapids.cudf.RmmLimitingResourceAdaptor;
import ai.rapids.cudf.RmmTrackingResourceAdaptor;
import ai.rapids.cudf.ColumnVector.EventHandler;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -480,6 +482,10 @@ public synchronized void close() {
}

/**
 * Convenience overload: configures RMM for testing with an allocation limit, using a
 * freshly constructed {@code BaseRmmEventHandler} as the event handler.
 *
 * @param maxAllocSize limit forwarded unchanged to the two-argument overload
 */
void setupRmmForTestingWithLimits(long maxAllocSize) {
  RmmEventHandler defaultHandler = new BaseRmmEventHandler();
  setupRmmForTestingWithLimits(maxAllocSize, defaultHandler);
}

void setupRmmForTestingWithLimits(long maxAllocSize, RmmEventHandler eventHandler) {
// Rmm.initialize is not going to limit allocations without a pool, so we
// need to set it up ourselves.
RmmDeviceMemoryResource resource = null;
Expand All @@ -495,7 +501,7 @@ void setupRmmForTestingWithLimits(long maxAllocSize) {
resource.close();
}
}
RmmSpark.setEventHandler(new BaseRmmEventHandler(), "stderr");
RmmSpark.setEventHandler(eventHandler, "stderr");
}

@Test
Expand Down Expand Up @@ -773,6 +779,57 @@ public void retryWatchdog() {
long endTime = System.nanoTime();
System.err.println("Took " + (endTime - startTime) + "ns to retry 500 times...");
}

//
// These next two tests deal with a special case where allocations (and allocation failures)
// could happen during spill handling.
//
// When we spill we may need to invoke cuDF code that allocates memory, specifically to
// pack previously unpacked memory into a single contiguous buffer (cudf::chunked_pack).
// This operation, although it makes use of an auxiliary memory resource, still deals with
// cuDF APIs that could, at any time, allocate small amounts of memory in the default memory
// resource. As such, allocations and allocation failures could happen, which cause us
// to recursively enter the state machine in SparkResourceAdaptorJni.
//
@Test
public void testAllocationDuringSpill() {
  // The handler performs a 1-byte allocation from inside the alloc-failure callback;
  // that nested allocation is expected to succeed.
  final AllocatingRmmEventHandler handler = new AllocatingRmmEventHandler(1);
  // Limit: 10 MiB
  setupRmmForTestingWithLimits(10 * 1024 * 1024, handler);
  final long currentThreadId = RmmSpark.getCurrentThreadId();
  final long arbitraryTaskId = 0; // any task id works here
  RmmSpark.associateThreadWithTask(currentThreadId, arbitraryTaskId);
  assertThrows(GpuOOM.class, () -> {
    // Hold 9 MiB, then request 2 MiB more so the second request exceeds the limit.
    try (DeviceMemoryBuffer filler = Rmm.alloc(9 * 1024 * 1024)) {
      try (DeviceMemoryBuffer shouldFail = Rmm.alloc(2 * 1024 * 1024)) {}
      fail("overallocation should have failed");
    } finally {
      // Always detach the thread from the task, even though the allocation throws.
      RmmSpark.removeThreadAssociation(currentThreadId);
    }
  });
  // The handler keeps asking for retries while retryCount < 10 and allocates once per
  // callback, so 11 successful nested allocations are expected
  // (presumably retryCount runs 0..10 -- confirm against the Rmm retry loop).
  assertEquals(11, handler.getAllocationCount());
}

@Test
public void testAllocationFailedDuringSpill() {
  // The handler requests 2MB from inside the alloc-failure callback; that nested
  // allocation is expected to fail.
  final AllocatingRmmEventHandler handler = new AllocatingRmmEventHandler(2L*1024*1024);
  // Limit: 10 MiB
  setupRmmForTestingWithLimits(10 * 1024 * 1024, handler);
  final long currentThreadId = RmmSpark.getCurrentThreadId();
  final long arbitraryTaskId = 0; // any task id works here
  RmmSpark.associateThreadWithTask(currentThreadId, arbitraryTaskId);
  assertThrows(GpuOOM.class, () -> {
    // Hold 9 MiB, then request 2 MiB more so the second request exceeds the limit.
    try (DeviceMemoryBuffer filler = Rmm.alloc(9 * 1024 * 1024)) {
      try (DeviceMemoryBuffer shouldFail = Rmm.alloc(2 * 1024 * 1024)) {}
      fail("overallocation should have failed");
    } finally {
      // Always detach the thread from the task, even though the allocation throws.
      RmmSpark.removeThreadAssociation(currentThreadId);
    }
  });
  // No nested allocation should have completed, so the counter must remain untouched.
  assertEquals(0, handler.getAllocationCount());
}

private static class BaseRmmEventHandler implements RmmEventHandler {
@Override
Expand All @@ -799,4 +856,51 @@ public boolean onAllocFailure(long sizeRequested, int retryCount) {
return false;
}
}

/**
 * Event handler that itself allocates device memory from within {@code onAllocFailure},
 * so tests can exercise allocations (and allocation failures) that happen while an
 * earlier allocation failure is still being handled.
 */
private static class AllocatingRmmEventHandler extends BaseRmmEventHandler {
  // true while an onAllocFailure call is in progress; seeing it set on entry means the
  // callback was re-entered (recursive call)
  boolean stillHandlingAllocFailure = false;

  // number of allocations made from the callback that succeeded
  int allocationCount;

  // number of bytes requested by each allocation made from the callback
  long allocSize;

  public AllocatingRmmEventHandler(long allocSize) {
    this.allocSize = allocSize;
  }

  public int getAllocationCount() {
    return allocationCount;
  }

  /**
   * Tries to allocate (and immediately free) {@code allocSize} bytes, counting each success.
   *
   * @return true to request a retry while {@code retryCount < 10}; false when the callback
   *         was re-entered, when the allocation throws, or once the retry budget is used up
   */
  @Override
  public boolean onAllocFailure(long sizeRequested, int retryCount) {
    // java.lang.OutOfMemoryError is caught because `Rmm.alloc` may throw it.
    // Every other Throwable is caught as well: callers do not handle them gracefully, so
    // we turn them into an explicit test failure through `fail`.
    try {
      if (stillHandlingAllocFailure) {
        // Re-entered while the outer callback is still running: clear the flag and give up.
        stillHandlingAllocFailure = false;
        return false;
      }

      stillHandlingAllocFailure = true;
      // allocate `allocSize` bytes; try-with-resources frees the buffer right away
      try (DeviceMemoryBuffer dmb = Rmm.alloc(allocSize)) {
        allocationCount++;
        stillHandlingAllocFailure = false;
      }
      // allow retries up to 10 times
      return retryCount < 10;
    } catch (java.lang.OutOfMemoryError e) {
      // The allocation made from this callback failed with java.lang.OutOfMemoryError
      // coming out of `RmmJni`; do not ask for a retry.
      return false;
    } catch (Throwable t) {
      fail("unexpected exception in onAllocFailure", t);
      return false;
    }
  }

}
}

0 comments on commit 7bc86fd

Please sign in to comment.