Skip to content

Commit

Permalink
Add retry count and block time metrics (NVIDIA#1031)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Mar 28, 2023
1 parent 31664e5 commit 113f928
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 0 deletions.
118 changes: 118 additions & 0 deletions src/main/cpp/src/SparkResourceAdaptorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,14 @@ class full_thread_state {
int cudf_exception_injected = 0;
// watchdog limit on maximum number of retries to avoid unexpected live lock situations
int num_times_retried = 0;
// Metrics: how many times each kind of retry exception was thrown for this thread,
// plus how long the thread has spent blocked.
// Count of retry OOM throws; incremented where a RetryOOM is thrown for this state.
int num_times_retry_throw = 0;
// Count of split-and-retry OOM throws; incremented where a SplitAndRetryOOM is thrown.
int num_times_split_retry_throw = 0;
// Accumulated blocked time in nanoseconds; after_block() adds to it.
long time_blocked_nanos = 0;

// Steady-clock timestamp written by before_block() and read by after_block().
std::chrono::time_point<std::chrono::steady_clock> block_start;

std::unique_ptr<std::condition_variable> wake_condition =
std::make_unique<std::condition_variable>();

Expand All @@ -196,6 +204,16 @@ class full_thread_state {
state = new_state;
}

// Remember the steady-clock time at which this thread is about to start waiting.
void before_block() {
  using steady = std::chrono::steady_clock;
  block_start = steady::now();
}

// Add the time elapsed since block_start (as set by before_block()) to
// time_blocked_nanos, measured on the steady clock in nanoseconds.
void after_block() {
  using std::chrono::duration_cast;
  using std::chrono::nanoseconds;
  using std::chrono::steady_clock;

  const auto elapsed = duration_cast<nanoseconds>(steady_clock::now() - block_start);
  time_blocked_nanos += elapsed.count();
}

/**
* Get the priority of this thread.
*/
Expand Down Expand Up @@ -414,6 +432,63 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
}
}

/**
* get the number of times a retry was thrown and reset the value to 0.
*/
int get_n_reset_num_retry(long task_id) {
std::unique_lock<std::mutex> lock(state_mutex);
int ret = 0;
auto task_at = task_to_threads.find(task_id);
if (task_at != task_to_threads.end()) {
for (auto thread_id : task_at->second) {
auto threads_at = threads.find(thread_id);
if (threads_at != threads.end()) {
ret += threads_at->second.num_times_retry_throw;
threads_at->second.num_times_retry_throw = 0;
}
}
}
return ret;
}

/**
* get the number of times a split and retry was thrown and reset the value to 0.
*/
int get_n_reset_num_split_retry(long task_id) {
std::unique_lock<std::mutex> lock(state_mutex);
int ret = 0;
auto task_at = task_to_threads.find(task_id);
if (task_at != task_to_threads.end()) {
for (auto thread_id : task_at->second) {
auto threads_at = threads.find(thread_id);
if (threads_at != threads.end()) {
ret += threads_at->second.num_times_split_retry_throw;
threads_at->second.num_times_split_retry_throw = 0;
}
}
}
return ret;
}

/**
* get the time in ns that the task was blocked for.
*/
long get_n_reset_block_time(long task_id) {
std::unique_lock<std::mutex> lock(state_mutex);
long ret = 0;
auto task_at = task_to_threads.find(task_id);
if (task_at != task_to_threads.end()) {
for (auto thread_id : task_at->second) {
auto threads_at = threads.find(thread_id);
if (threads_at != threads.end()) {
ret += threads_at->second.time_blocked_nanos;
threads_at->second.time_blocked_nanos = 0;
}
}
}
return ret;
}

/**
* Update the internal state so that this thread is known that it is going to enter a
* shuffle stage and could indirectly block on a shuffle thread (UCX).
Expand Down Expand Up @@ -574,12 +649,14 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {

// Bump the per-thread retry-throw metric, run check_before_oom, then call
// throw_java_exception with RETRY_OOM_CLASS.
// NOTE(review): `msg` is not used in this body; the exception text is a fixed literal.
// NOTE(review): `lock` is presumably the caller's held state lock - confirm against callers.
void throw_retry_oom(const char *msg, full_thread_state &state,
const std::unique_lock<std::mutex> &lock) {
// The metric is incremented before check_before_oom runs.
state.num_times_retry_throw++;
check_before_oom(state, lock);
throw_java_exception(RETRY_OOM_CLASS, "task should retry operation");
}

// Bump the per-thread split-and-retry-throw metric, run check_before_oom, then call
// throw_java_exception with SPLIT_AND_RETRY_OOM_CLASS.
// NOTE(review): `msg` is not used in this body; the exception text is a fixed literal.
// NOTE(review): `lock` is presumably the caller's held state lock - confirm against callers.
void throw_split_n_retry_oom(const char *msg, full_thread_state &state,
const std::unique_lock<std::mutex> &lock) {
// The metric is incremented before check_before_oom runs.
state.num_times_split_retry_throw++;
check_before_oom(state, lock);
throw_java_exception(SPLIT_AND_RETRY_OOM_CLASS, "task should split input and retry operation");
}
Expand Down Expand Up @@ -615,10 +692,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
// fall through
case SHUFFLE_BLOCKED:
log_status("WAITING", thread_id, thread->second.task_id, thread->second.state);
thread->second.before_block();
do {
thread->second.wake_condition->wait(lock);
thread = threads.find(thread_id);
} while (thread != threads.end() && is_blocked(thread->second.state));
thread->second.after_block();
task_has_woken_condition.notify_all();
break;
case SHUFFLE_THROW:
Expand All @@ -636,10 +715,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
// check again to see if this was fixed or not.
check_and_update_for_bufn(lock);
log_status("WAITING", thread_id, thread->second.task_id, thread->second.state);
thread->second.before_block();
do {
thread->second.wake_condition->wait(lock);
thread = threads.find(thread_id);
} while (thread != threads.end() && is_blocked(thread->second.state));
thread->second.after_block();
task_has_woken_condition.notify_all();
break;
case TASK_SPLIT_THROW:
Expand Down Expand Up @@ -763,6 +844,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
if (thread != threads.end()) {
if (thread->second.retry_oom_injected > 0) {
thread->second.retry_oom_injected--;
thread->second.num_times_retry_throw++;
throw_java_exception(RETRY_OOM_CLASS, "injected RetryOOM");
}

Expand All @@ -773,6 +855,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {

if (thread->second.split_and_retry_oom_injected > 0) {
thread->second.split_and_retry_oom_injected--;
thread->second.num_times_split_retry_throw++;
throw_java_exception(SPLIT_AND_RETRY_OOM_CLASS, "injected SplitAndRetryOOM");
}

Expand Down Expand Up @@ -1271,4 +1354,39 @@ JNIEXPORT jint JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_get
}
CATCH_STD(env, 0)
}


// JNI entry point for SparkResourceAdaptor.getAndResetRetryThrowInternal: treats `ptr`
// as a spark_resource_adaptor and forwards to get_n_reset_num_retry for `task_id`.
// The trailing 0 given to JNI_NULL_CHECK / CATCH_STD is presumably the value returned
// on those error paths - confirm against the macro definitions.
JNIEXPORT jint JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_getAndResetRetryThrowInternal(
  JNIEnv *env, jclass, jlong ptr, jlong task_id) {
  JNI_NULL_CHECK(env, ptr, "resource_adaptor is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto *const adaptor = reinterpret_cast<spark_resource_adaptor *>(ptr);
    return adaptor->get_n_reset_num_retry(task_id);
  }
  CATCH_STD(env, 0)
}

// JNI entry point for SparkResourceAdaptor.getAndResetSplitRetryThrowInternal: treats
// `ptr` as a spark_resource_adaptor and forwards to get_n_reset_num_split_retry for
// `task_id`. The trailing 0 given to JNI_NULL_CHECK / CATCH_STD is presumably the value
// returned on those error paths - confirm against the macro definitions.
JNIEXPORT jint JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_getAndResetSplitRetryThrowInternal(
  JNIEnv *env, jclass, jlong ptr, jlong task_id) {
  JNI_NULL_CHECK(env, ptr, "resource_adaptor is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto *const adaptor = reinterpret_cast<spark_resource_adaptor *>(ptr);
    return adaptor->get_n_reset_num_split_retry(task_id);
  }
  CATCH_STD(env, 0)
}

// JNI entry point for SparkResourceAdaptor.getAndResetBlockTimeInternal: treats `ptr`
// as a spark_resource_adaptor and forwards to get_n_reset_block_time for `task_id`.
// The trailing 0 given to JNI_NULL_CHECK / CATCH_STD is presumably the value returned
// on those error paths - confirm against the macro definitions.
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_getAndResetBlockTimeInternal(
  JNIEnv *env, jclass, jlong ptr, jlong task_id) {
  JNI_NULL_CHECK(env, ptr, "resource_adaptor is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto *const adaptor = reinterpret_cast<spark_resource_adaptor *>(ptr);
    return adaptor->get_n_reset_block_time(task_id);
  }
  CATCH_STD(env, 0)
}

}
48 changes: 48 additions & 0 deletions src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
Original file line number Diff line number Diff line change
Expand Up @@ -333,4 +333,52 @@ public static RmmSparkThreadState getStateOf(long threadId) {
}
}
}

/**
 * Get the number of retry exceptions that were thrown for a task and reset the metric.
 * @param taskId the id of the task to get the metric for.
 * @return the count reported by the resource adaptor, or 0 when no open adaptor is set.
 */
public static int getAndResetNumRetryThrow(long taskId) {
  synchronized (Rmm.class) {
    // Without an open adaptor the value is by definition 0.
    if (sra == null || !sra.isOpen()) {
      return 0;
    }
    return sra.getAndResetNumRetryThrow(taskId);
  }
}

/**
 * Get the number of split and retry exceptions that were thrown for a task and reset
 * the metric.
 * @param taskId the id of the task to get the metric for.
 * @return the count reported by the resource adaptor, or 0 when no open adaptor is set.
 */
public static int getAndResetNumSplitRetryThrow(long taskId) {
  synchronized (Rmm.class) {
    // Without an open adaptor the value is by definition 0.
    if (sra == null || !sra.isOpen()) {
      return 0;
    }
    return sra.getAndResetNumSplitRetryThrow(taskId);
  }
}

/**
 * Get how long, in nanoseconds, the task was blocked for and reset the metric.
 * @param taskId the id of the task to get the metric for.
 * @return the blocked time reported by the resource adaptor, or 0 when no open
 *         adaptor is set.
 */
public static long getAndResetBlockTimeNs(long taskId) {
  synchronized (Rmm.class) {
    // Without an open adaptor the value is by definition 0.
    if (sra == null || !sra.isOpen()) {
      return 0;
    }
    return sra.getAndResetBlockTime(taskId);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ public RmmSparkThreadState getStateOf(long threadId) {
return RmmSparkThreadState.fromNativeId(getStateOf(getHandle(), threadId));
}

/**
 * Fetch the retry-throw count for a task through the native adaptor.
 * @param taskId the id of the task to query.
 * @return the value returned by the native getAndResetRetryThrowInternal call.
 */
public int getAndResetNumRetryThrow(long taskId) {
  final long handle = getHandle();
  return getAndResetRetryThrowInternal(handle, taskId);
}

/**
 * Fetch the split-and-retry-throw count for a task through the native adaptor.
 * @param taskId the id of the task to query.
 * @return the value returned by the native getAndResetSplitRetryThrowInternal call.
 */
public int getAndResetNumSplitRetryThrow(long taskId) {
  final long handle = getHandle();
  return getAndResetSplitRetryThrowInternal(handle, taskId);
}

/**
 * Fetch the blocked time for a task through the native adaptor.
 * @param taskId the id of the task to query.
 * @return the value returned by the native getAndResetBlockTimeInternal call.
 */
public long getAndResetBlockTime(long taskId) {
  final long handle = getHandle();
  return getAndResetBlockTimeInternal(handle, taskId);
}

/**
* Get the ID of the current thread that can be used with the other SparkResourceAdaptor APIs.
* Don't use the java thread ID. They are not related.
Expand All @@ -181,4 +193,8 @@ public RmmSparkThreadState getStateOf(long threadId) {
private static native void forceCudfException(long handle, long threadId, int numTimes);
private static native void blockThreadUntilReady(long handle);
private static native int getStateOf(long handle, long threadId);
// Native metric accessors; each takes the adaptor handle and a task id.
// Called by getAndResetNumRetryThrow(long).
private static native int getAndResetRetryThrowInternal(long handle, long taskId);
// Called by getAndResetNumSplitRetryThrow(long).
private static native int getAndResetSplitRetryThrowInternal(long handle, long taskId);
// Called by getAndResetBlockTime(long).
private static native long getAndResetBlockTimeInternal(long handle, long taskId);

}
7 changes: 7 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ public void testInsertOOMs() {
long threadId = RmmSpark.getCurrentThreadId();
long taskid = 0; // This is arbitrary
assertEquals(RmmSparkThreadState.UNKNOWN, RmmSpark.getStateOf(threadId));
assertEquals(0, RmmSpark.getAndResetNumRetryThrow(taskid));
assertEquals(0, RmmSpark.getAndResetNumSplitRetryThrow(taskid));
RmmSpark.associateThreadWithTask(threadId, taskid);
assertEquals(RmmSparkThreadState.TASK_RUNNING, RmmSpark.getStateOf(threadId));
try {
Expand All @@ -326,6 +328,8 @@ public void testInsertOOMs() {
// Verify that injecting OOM does not cause the block to actually happen or
// the state to change
assertEquals(RmmSparkThreadState.TASK_RUNNING, RmmSpark.getStateOf(threadId));
assertEquals(1, RmmSpark.getAndResetNumRetryThrow(taskid));
assertEquals(0, RmmSpark.getAndResetNumSplitRetryThrow(taskid));
RmmSpark.blockThreadUntilReady();

// Allocate something small and verify that it works...
Expand All @@ -337,6 +341,9 @@ public void testInsertOOMs() {
// No change in state after force
assertEquals(RmmSparkThreadState.TASK_RUNNING, RmmSpark.getStateOf(threadId));
assertThrows(SplitAndRetryOOM.class, () -> Rmm.alloc(100).close());
assertEquals(0, RmmSpark.getAndResetNumRetryThrow(taskid));
assertEquals(1, RmmSpark.getAndResetNumSplitRetryThrow(taskid));

// Verify that injecting OOM does not cause the block to actually happen
assertEquals(RmmSparkThreadState.TASK_RUNNING, RmmSpark.getStateOf(threadId));
RmmSpark.blockThreadUntilReady();
Expand Down

0 comments on commit 113f928

Please sign in to comment.