diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 24f7d44ed28..db90c09a078 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -708,8 +708,7 @@ private static native long[] mixedLeftAntiJoinGatherMapWithSize(long leftKeysTab private static native long[] filter(long input, long mask); private static native long[] dropDuplicates(long nativeHandle, int[] keyColumns, - boolean keepFirst, boolean nullsEqual, - boolean nullsBefore) throws CudfException; + int keepValue, boolean nullsEqual) throws CudfException; private static native long[] gather(long tableHandle, long gatherView, boolean checkBounds); @@ -2029,28 +2028,39 @@ public Table filter(ColumnView mask) { return new Table(filter(nativeHandle, mask.getNativeView())); } + /** + * Enum to specify which of duplicate rows/elements will be copied to the output. + */ + public enum DuplicateKeepOption { + KEEP_ANY(0), + KEEP_FIRST(1), + KEEP_LAST(2), + KEEP_NONE(3); + + final int keepValue; + + DuplicateKeepOption(int keepValue) { + this.keepValue = keepValue; + } + } + /** * Copy rows of the current table to an output table such that duplicate rows in the key columns * are ignored (i.e., only one row from the duplicate ones will be copied). These keys columns are * a subset of the current table columns and their indices are specified by an input array. * - * Currently, the output table is sorted by key columns, using stable sort. However, this is not - * guaranteed in the future. + * The order of rows in the output table is not specified. * * @param keyColumns Array of indices representing key columns from the current table. - * @param keepFirst If it is true, the first row with a duplicated key will be copied. Otherwise, - * copy the last row with a duplicated key. + * @param keep Option specifying to keep any, first, last, or none of the found duplicates. * @param nullsEqual Flag to denote whether nulls are treated as equal when comparing rows of the * key columns to check for uniqueness. - * @param nullsBefore Flag to specify whether nulls in the key columns will appear before or - * after non-null elements when sorting the table. * * @return Table with unique keys. */ - public Table dropDuplicates(int[] keyColumns, boolean keepFirst, boolean nullsEqual, - boolean nullsBefore) { + public Table dropDuplicates(int[] keyColumns, DuplicateKeepOption keep, boolean nullsEqual) { assert keyColumns.length >= 1 : "Input keyColumns must contain indices of at least one column"; - return new Table(dropDuplicates(nativeHandle, keyColumns, keepFirst, nullsEqual, nullsBefore)); + return new Table(dropDuplicates(nativeHandle, keyColumns, keep.keepValue, nullsEqual)); } /** diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 6c6f73edfe8..4bdd54640d6 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -3053,9 +3053,11 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_filter(JNIEnv *env, jclas CATCH_STD(env, 0); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates( - JNIEnv *env, jclass, jlong input_jtable, jintArray key_columns, jboolean keep_first, - jboolean nulls_equal, jboolean nulls_before) { +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(JNIEnv *env, jclass, + jlong input_jtable, + jintArray key_columns, + jint keep, + jboolean nulls_equal) { JNI_NULL_CHECK(env, input_jtable, "input table is null", 0); JNI_NULL_CHECK(env, key_columns, "input key_columns is null", 0); try { @@ -3066,22 +3068,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates( auto const native_keys_indices = cudf::jni::native_jintArray(env, key_columns); auto const keys_indices = std::vector(native_keys_indices.begin(), native_keys_indices.end()); - - // cudf::unique keeps unique rows in each consecutive group of equivalent rows. To match the - // behavior of pandas.DataFrame.drop_duplicates, users need to stable sort the input first and - // then invoke cudf::unique. - std::vector order(keys_indices.size(), cudf::order::ASCENDING); - std::vector null_precedence( - keys_indices.size(), nulls_before ? cudf::null_order::BEFORE : cudf::null_order::AFTER); - auto const sorted_input = - cudf::stable_sort_by_key(*input, input->select(keys_indices), order, null_precedence); + auto const keep_option = [&] { + switch (keep) { + case 0: return cudf::duplicate_keep_option::KEEP_ANY; + case 1: return cudf::duplicate_keep_option::KEEP_FIRST; + case 2: return cudf::duplicate_keep_option::KEEP_LAST; + case 3: return cudf::duplicate_keep_option::KEEP_NONE; + default: + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid `keep` option", + cudf::duplicate_keep_option::KEEP_ANY); + } + }(); auto result = - cudf::unique(sorted_input->view(), keys_indices, - keep_first ? cudf::duplicate_keep_option::KEEP_FIRST : - cudf::duplicate_keep_option::KEEP_LAST, - nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL, - rmm::mr::get_current_device_resource()); + cudf::distinct(*input, keys_indices, keep_option, + nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL, + cudf::nan_equality::ALL_EQUAL, rmm::mr::get_current_device_resource()); return convert_table_for_return(env, result); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 557bd8f289a..639d498d2f3 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -5347,11 +5347,11 @@ void testWindowingNthElement() { ColumnVector expect_first = ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, X, X, X, 0, 4); ColumnVector expect_last = ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, X, 4, 4, 0, 4, X, X); ColumnVector expect_1th = ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, X, 4, 0, 0, 4, X); - ColumnVector expect_first_skip_null = + ColumnVector expect_first_skip_null = ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, 4, 0, 0, 0, 4); - ColumnVector expect_last_skip_null = + ColumnVector expect_last_skip_null = ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, 9, 4, 4, 0, 4, 4, 4); - ColumnVector expect_1th_skip_null = + ColumnVector expect_1th_skip_null = ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, 4, X, X, 4, 4, X)) { assertColumnsAreEqual(expect_first, windowAggResults.getColumn(0)); assertColumnsAreEqual(expect_last, windowAggResults.getColumn(1)); @@ -7510,19 +7510,21 @@ void testDropDuplicates() { Table input = new Table(col1, col2)) { // Keep the first duplicate element. - try (Table result = input.dropDuplicates(keyColumns, true, true, true); + try (Table result = input.dropDuplicates(keyColumns, Table.DuplicateKeepOption.KEEP_FIRST, true); + Table resultSorted = result.orderBy(OrderByArg.asc(1, true)); ColumnVector expectedCol1 = ColumnVector.fromBoxedInts(null, 5, 5, 8); ColumnVector expectedCol2 = ColumnVector.fromBoxedInts(null, 19, 20, 21); Table expected = new Table(expectedCol1, expectedCol2)) { - assertTablesAreEqual(expected, result); + assertTablesAreEqual(expected, resultSorted); } // Keep the last duplicate element. - try (Table result = input.dropDuplicates(keyColumns, false, true, true); + try (Table result = input.dropDuplicates(keyColumns, Table.DuplicateKeepOption.KEEP_LAST, true); + Table resultSorted = result.orderBy(OrderByArg.asc(1, true)); ColumnVector expectedCol1 = ColumnVector.fromBoxedInts(3, 1, 5, 8); ColumnVector expectedCol2 = ColumnVector.fromBoxedInts(null, 19, 20, 21); Table expected = new Table(expectedCol1, expectedCol2)) { - assertTablesAreEqual(expected, result); + assertTablesAreEqual(expected, resultSorted); } } }