Skip to content

Commit

Permalink
Use cudf::distinct in Java binding (NVIDIA#11232)
Browse files Browse the repository at this point in the history
Java binding has `dropDuplicates` API to remove duplicate rows from a table. Previously it has been implemented by sorting the table then calling to `cudf::unique`. This PR changes the implementation to use `cudf::distinct` directly, which can significantly improve performance by avoiding sorting the input table.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - MithunR (https://github.com/mythrocks)

URL: rapidsai/cudf#11232
  • Loading branch information
ttnghia authored Jul 8, 2022
1 parent 4b70c37 commit 89a8e70
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 35 deletions.
32 changes: 21 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -708,8 +708,7 @@ private static native long[] mixedLeftAntiJoinGatherMapWithSize(long leftKeysTab
private static native long[] filter(long input, long mask);

private static native long[] dropDuplicates(long nativeHandle, int[] keyColumns,
boolean keepFirst, boolean nullsEqual,
boolean nullsBefore) throws CudfException;
int keepValue, boolean nullsEqual) throws CudfException;

private static native long[] gather(long tableHandle, long gatherView, boolean checkBounds);

Expand Down Expand Up @@ -2029,28 +2028,39 @@ public Table filter(ColumnView mask) {
return new Table(filter(nativeHandle, mask.getNativeView()));
}

/**
* Enum to specify which of duplicate rows/elements will be copied to the output.
*/
public enum DuplicateKeepOption {
KEEP_ANY(0),
KEEP_FIRST(1),
KEEP_LAST(2),
KEEP_NONE(3);

final int keepValue;

DuplicateKeepOption(int keepValue) {
this.keepValue = keepValue;
}
}

/**
* Copy rows of the current table to an output table such that duplicate rows in the key columns
* are ignored (i.e., only one row from the duplicate ones will be copied). These keys columns are
* a subset of the current table columns and their indices are specified by an input array.
*
* Currently, the output table is sorted by key columns, using stable sort. However, this is not
* guaranteed in the future.
* The order of rows in the output table is not specified.
*
* @param keyColumns Array of indices representing key columns from the current table.
* @param keepFirst If it is true, the first row with a duplicated key will be copied. Otherwise,
* copy the last row with a duplicated key.
* @param keep Option specifying to keep any, first, last, or none of the found duplicates.
* @param nullsEqual Flag to denote whether nulls are treated as equal when comparing rows of the
* key columns to check for uniqueness.
* @param nullsBefore Flag to specify whether nulls in the key columns will appear before or
* after non-null elements when sorting the table.
*
* @return Table with unique keys.
*/
public Table dropDuplicates(int[] keyColumns, boolean keepFirst, boolean nullsEqual,
boolean nullsBefore) {
public Table dropDuplicates(int[] keyColumns, DuplicateKeepOption keep, boolean nullsEqual) {
assert keyColumns.length >= 1 : "Input keyColumns must contain indices of at least one column";
return new Table(dropDuplicates(nativeHandle, keyColumns, keepFirst, nullsEqual, nullsBefore));
return new Table(dropDuplicates(nativeHandle, keyColumns, keep.keepValue, nullsEqual));
}

/**
Expand Down
36 changes: 19 additions & 17 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3053,9 +3053,11 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_filter(JNIEnv *env, jclas
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(
JNIEnv *env, jclass, jlong input_jtable, jintArray key_columns, jboolean keep_first,
jboolean nulls_equal, jboolean nulls_before) {
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(JNIEnv *env, jclass,
jlong input_jtable,
jintArray key_columns,
jint keep,
jboolean nulls_equal) {
JNI_NULL_CHECK(env, input_jtable, "input table is null", 0);
JNI_NULL_CHECK(env, key_columns, "input key_columns is null", 0);
try {
Expand All @@ -3066,22 +3068,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(
auto const native_keys_indices = cudf::jni::native_jintArray(env, key_columns);
auto const keys_indices =
std::vector<cudf::size_type>(native_keys_indices.begin(), native_keys_indices.end());

// cudf::unique keeps unique rows in each consecutive group of equivalent rows. To match the
// behavior of pandas.DataFrame.drop_duplicates, users need to stable sort the input first and
// then invoke cudf::unique.
std::vector<cudf::order> order(keys_indices.size(), cudf::order::ASCENDING);
std::vector<cudf::null_order> null_precedence(
keys_indices.size(), nulls_before ? cudf::null_order::BEFORE : cudf::null_order::AFTER);
auto const sorted_input =
cudf::stable_sort_by_key(*input, input->select(keys_indices), order, null_precedence);
auto const keep_option = [&] {
switch (keep) {
case 0: return cudf::duplicate_keep_option::KEEP_ANY;
case 1: return cudf::duplicate_keep_option::KEEP_FIRST;
case 2: return cudf::duplicate_keep_option::KEEP_LAST;
case 3: return cudf::duplicate_keep_option::KEEP_NONE;
default:
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid `keep` option",
cudf::duplicate_keep_option::KEEP_ANY);
}
}();

auto result =
cudf::unique(sorted_input->view(), keys_indices,
keep_first ? cudf::duplicate_keep_option::KEEP_FIRST :
cudf::duplicate_keep_option::KEEP_LAST,
nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL,
rmm::mr::get_current_device_resource());
cudf::distinct(*input, keys_indices, keep_option,
nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL,
cudf::nan_equality::ALL_EQUAL, rmm::mr::get_current_device_resource());
return convert_table_for_return(env, result);
}
CATCH_STD(env, 0);
Expand Down
16 changes: 9 additions & 7 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5347,11 +5347,11 @@ void testWindowingNthElement() {
ColumnVector expect_first = ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, X, X, X, 0, 4);
ColumnVector expect_last = ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, X, 4, 4, 0, 4, X, X);
ColumnVector expect_1th = ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, X, 4, 0, 0, 4, X);
ColumnVector expect_first_skip_null =
ColumnVector expect_first_skip_null =
ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, 4, 0, 0, 0, 4);
ColumnVector expect_last_skip_null =
ColumnVector expect_last_skip_null =
ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, 9, 4, 4, 0, 4, 4, 4);
ColumnVector expect_1th_skip_null =
ColumnVector expect_1th_skip_null =
ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, 4, X, X, 4, 4, X)) {
assertColumnsAreEqual(expect_first, windowAggResults.getColumn(0));
assertColumnsAreEqual(expect_last, windowAggResults.getColumn(1));
Expand Down Expand Up @@ -7510,19 +7510,21 @@ void testDropDuplicates() {
Table input = new Table(col1, col2)) {

// Keep the first duplicate element.
try (Table result = input.dropDuplicates(keyColumns, true, true, true);
try (Table result = input.dropDuplicates(keyColumns, Table.DuplicateKeepOption.KEEP_FIRST, true);
Table resultSorted = result.orderBy(OrderByArg.asc(1, true));
ColumnVector expectedCol1 = ColumnVector.fromBoxedInts(null, 5, 5, 8);
ColumnVector expectedCol2 = ColumnVector.fromBoxedInts(null, 19, 20, 21);
Table expected = new Table(expectedCol1, expectedCol2)) {
assertTablesAreEqual(expected, result);
assertTablesAreEqual(expected, resultSorted);
}

// Keep the last duplicate element.
try (Table result = input.dropDuplicates(keyColumns, false, true, true);
try (Table result = input.dropDuplicates(keyColumns, Table.DuplicateKeepOption.KEEP_LAST, true);
Table resultSorted = result.orderBy(OrderByArg.asc(1, true));
ColumnVector expectedCol1 = ColumnVector.fromBoxedInts(3, 1, 5, 8);
ColumnVector expectedCol2 = ColumnVector.fromBoxedInts(null, 19, 20, 21);
Table expected = new Table(expectedCol1, expectedCol2)) {
assertTablesAreEqual(expected, result);
assertTablesAreEqual(expected, resultSorted);
}
}
}
Expand Down

0 comments on commit 89a8e70

Please sign in to comment.