From 3b37f06d045609c7f888d20fa87ca4e814521d83 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Wed, 17 Nov 2021 17:07:29 +0800 Subject: [PATCH 1/7] Add sample JNI Signed-off-by: Chong Gao --- java/src/main/java/ai/rapids/cudf/Table.java | 15 +++++++++++++++ java/src/main/native/src/TableJni.cpp | 14 ++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 2744728fb44..c605a7bd18f 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -656,6 +656,8 @@ private static native ContiguousTable[] contiguousSplitGroups(long inputTable, boolean[] keysDescending, boolean[] keysNullSmallest); + private static native long[] sample(long tableHandle, long n, boolean replacement, long seed); + ///////////////////////////////////////////////////////////////////////////// // TABLE CREATION APIs ///////////////////////////////////////////////////////////////////////////// @@ -2743,6 +2745,19 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data return result; } + /** + * Gather `n` samples from table randomly + * The output is not same with CPU Sample exec, but this is faster. + * + * @param n + * @param replacement Allow or disallow sampling of the same row more than once. + * @param seed Seed value to initiate random number generator. 
+ * @return + */ + public Table sample(long n, boolean replacement, long seed) { + return new Table(sample(nativeHandle, n, replacement, seed)); + } + ///////////////////////////////////////////////////////////////////////////// // HELPER CLASSES ///////////////////////////////////////////////////////////////////////////// diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 96dd02e5f2a..2718d506ada 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -39,6 +39,7 @@ #include #include #include +#include #include #include "cudf_jni_apis.hpp" @@ -3145,4 +3146,17 @@ JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplitGroups( CATCH_STD(env, NULL); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_sample(JNIEnv *env, jclass, jlong j_input, + jlong n, jboolean replacement, jlong seed) { + JNI_NULL_CHECK(env, j_input, "input table is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::table_view *input = reinterpret_cast<cudf::table_view *>(j_input); + auto sample_with_replacement = + replacement ? 
+ cudf::sample_with_replacement::TRUE : cudf::sample_with_replacement::FALSE; + std::unique_ptr<cudf::table> result = cudf::sample(*input, n, sample_with_replacement, seed); + return cudf::jni::convert_table_for_return(env, result); + } + CATCH_STD(env, 0); +} } // extern "C" From c5ec85843578188007887907379b350343f03014 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Fri, 19 Nov 2021 11:07:21 +0800 Subject: [PATCH 2/7] Add comment Signed-off-by: Chong Gao --- java/src/main/java/ai/rapids/cudf/Table.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index c605a7bd18f..794784f130f 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2749,7 +2749,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data * Gather `n` samples from table randomly * The output is not same with CPU Sample exec, but this is faster. * - * @param n + * @param n non-negative number of samples expected * @param replacement Allow or disallow sampling of the same row more than once. * @param seed Seed value to initiate random number generator. * @return From 37e9b5797c5be2737750319d1baa39c61bc3af65 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Fri, 19 Nov 2021 14:18:01 +0800 Subject: [PATCH 3/7] Update comment --- java/src/main/java/ai/rapids/cudf/Table.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 794784f130f..6ef06667e90 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2747,7 +2747,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data /** * Gather `n` samples from table randomly - * The output is not same with CPU Sample exec, but this is faster. 
+ * The output is not the same as what CPU Sample Exec produces, but this is faster. * * @param n non-negative number of samples expected * @param replacement Allow or disallow sampling of the same row more than once. From 6fe77c215622f1abe7e21beebd8474cea479297f Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Fri, 19 Nov 2021 19:26:13 +0800 Subject: [PATCH 4/7] Add a test case for sample Signed-off-by: Chong Gao --- .../src/test/java/ai/rapids/cudf/TableTest.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index cd1e433d07b..2846579b1ea 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -7291,4 +7291,21 @@ void testExplodeOuterPosition() { } } } + + @Test + void testSample() { + Table t = new Table.TestBuilder().column("s1", "s2", "s3", "s4", "s5").build(); + Table ret = t.sample(3, false, 0); + assertTrue(ret.getRowCount() == 3L); + + assertEquals("s3", ret.getColumn(0).getScalarElement(0).getJavaString()); + assertEquals("s4", ret.getColumn(0).getScalarElement(1).getJavaString()); + assertEquals("s5", ret.getColumn(0).getScalarElement(2).getJavaString()); + + ret = t.sample(100, true, 0); + assertTrue(ret.getRowCount() == 100L); + + ret = t.sample(4, true, 0); + assertTrue(ret.getRowCount() == 4L); + } } From 6205baa976cd05549458587c5a35a9e2b556848f Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 22 Nov 2021 19:45:50 +0800 Subject: [PATCH 5/7] Update comments and test cases Signed-off-by: Chong Gao --- java/src/main/java/ai/rapids/cudf/Table.java | 22 +++++++++++++--- .../test/java/ai/rapids/cudf/TableTest.java | 26 +++++++++++-------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 6ef06667e90..8905097a5b0 100644 --- 
a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2745,14 +2745,30 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data return result; } + /** * Gather `n` samples from table randomly - * The output is not the same as what CPU Sample Exec produces, but this is faster. + * Note: does not preserve the ordering + * Example: + * input: {col1: {1, 2, 3, 4, 5}, col2: {6, 7, 8, 9, 10}} + * n: 3 + * replacement: false + * + * output: {col1: {3, 1, 4}, col2: {8, 6, 9}} + * + * replacement: true * - * @param n non-negative number of samples expected + * output: {col1: {3, 1, 1}, col2: {8, 6, 6}} + * @endcode + * + * @throws "logic_error" if `n` > `input.num_rows()` and `replacement` == FALSE. + * @throws "logic_error" if `n` < 0. + * + * @param n non-negative number of samples expected from table * @param replacement Allow or disallow sampling of the same row more than once. * @param seed Seed value to initiate random number generator. 
- * @return + * + * @return Table containing samples */ public Table sample(long n, boolean replacement, long seed) { return new Table(sample(nativeHandle, n, replacement, seed)); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 2846579b1ea..9987dafbd0e 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -7294,18 +7294,22 @@ void testExplodeOuterPosition() { @Test void testSample() { - Table t = new Table.TestBuilder().column("s1", "s2", "s3", "s4", "s5").build(); - Table ret = t.sample(3, false, 0); - assertTrue(ret.getRowCount() == 3L); - - assertEquals("s3", ret.getColumn(0).getScalarElement(0).getJavaString()); - assertEquals("s4", ret.getColumn(0).getScalarElement(1).getJavaString()); - assertEquals("s5", ret.getColumn(0).getScalarElement(2).getJavaString()); + try (Table t = new Table.TestBuilder().column("s1", "s2", "s3", "s4", "s5").build()) { + try (Table ret = t.sample(3, false, 0); + Table expected = new Table.TestBuilder().column("s3", "s4", "s5").build()) { + assertTablesAreEqual(expected, ret); + } - ret = t.sample(100, true, 0); - assertTrue(ret.getRowCount() == 100L); + try (Table ret = t.sample(5, false, 0); + Table expected = new Table.TestBuilder().column("s3", "s4", "s5", "s2", "s1").build()) { + assertTablesAreEqual(expected, ret); + } - ret = t.sample(4, true, 0); - assertTrue(ret.getRowCount() == 4L); + try (Table ret = t.sample(8, true, 0); + Table expected = new Table.TestBuilder() + .column("s1", "s1", "s4", "s5", "s5", "s1", "s3", "s2").build()) { + assertTablesAreEqual(expected, ret); + } + } } } From 3dbb13951c603b15b6a07afc630dcd0ade869bbd Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 23 Nov 2021 09:54:41 +0800 Subject: [PATCH 6/7] Update comments --- java/src/main/java/ai/rapids/cudf/Table.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 8905097a5b0..2cdaecd6b97 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -2759,10 +2759,9 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data * replacement: true * * output: {col1: {3, 1, 1}, col2: {8, 6, 6}} - * @endcode * - * @throws "logic_error" if `n` > `input.num_rows()` and `replacement` == FALSE. - * @throws "logic_error" if `n` < 0. + * throws "logic_error" if `n` > table rows and `replacement` == FALSE. + * throws "logic_error" if `n` < 0. * * @param n non-negative number of samples expected from table * @param replacement Allow or disallow sampling of the same row more than once. From 6282f4b556e954569c1488fd43b0785025715931 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Thu, 2 Dec 2021 17:58:58 +0800 Subject: [PATCH 7/7] Format Signed-off-by: Chong Gao --- java/src/main/native/src/TableJni.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 2718d506ada..4fc005ef5da 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -39,7 +40,6 @@ #include #include #include -#include #include #include "cudf_jni_apis.hpp" @@ -3147,7 +3147,8 @@ JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplitGroups( } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_sample(JNIEnv *env, jclass, jlong j_input, - jlong n, jboolean replacement, jlong seed) { + jlong n, jboolean replacement, + jlong seed) { JNI_NULL_CHECK(env, j_input, "input table is null", 0); try { cudf::jni::auto_set_device(env);