Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cudf::distinct in Java binding #11232

Merged
merged 2 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -708,8 +708,7 @@ private static native long[] mixedLeftAntiJoinGatherMapWithSize(long leftKeysTab
private static native long[] filter(long input, long mask);

private static native long[] dropDuplicates(long nativeHandle, int[] keyColumns,
boolean keepFirst, boolean nullsEqual,
boolean nullsBefore) throws CudfException;
int keepValue, boolean nullsEqual) throws CudfException;

private static native long[] gather(long tableHandle, long gatherView, boolean checkBounds);

Expand Down Expand Up @@ -2029,28 +2028,39 @@ public Table filter(ColumnView mask) {
return new Table(filter(nativeHandle, mask.getNativeView()));
}

/**
* Enum to specify which of duplicate rows/elements will be copied to the output.
*/
public enum DuplicateKeepOption {
KEEP_ANY(0),
KEEP_FIRST(1),
KEEP_LAST(2),
KEEP_NONE(3);

final int keepValue;

DuplicateKeepOption(int keepValue) {
this.keepValue = keepValue;
}
}

/**
* Copy rows of the current table to an output table such that duplicate rows in the key columns
* are ignored (i.e., only one row from the duplicate ones will be copied). These keys columns are
* a subset of the current table columns and their indices are specified by an input array.
*
* Currently, the output table is sorted by key columns, using stable sort. However, this is not
* guaranteed in the future.
* The order of rows in the output table is not specified.
*
* @param keyColumns Array of indices representing key columns from the current table.
* @param keepFirst If it is true, the first row with a duplicated key will be copied. Otherwise,
* copy the last row with a duplicated key.
* @param keep Option specifying to keep any, first, last, or none of the found duplicates.
* @param nullsEqual Flag to denote whether nulls are treated as equal when comparing rows of the
* key columns to check for uniqueness.
* @param nullsBefore Flag to specify whether nulls in the key columns will appear before or
* after non-null elements when sorting the table.
*
* @return Table with unique keys.
*/
public Table dropDuplicates(int[] keyColumns, boolean keepFirst, boolean nullsEqual,
boolean nullsBefore) {
public Table dropDuplicates(int[] keyColumns, DuplicateKeepOption keep, boolean nullsEqual) {
assert keyColumns.length >= 1 : "Input keyColumns must contain indices of at least one column";
return new Table(dropDuplicates(nativeHandle, keyColumns, keepFirst, nullsEqual, nullsBefore));
return new Table(dropDuplicates(nativeHandle, keyColumns, keep.keepValue, nullsEqual));
}

/**
Expand Down
36 changes: 19 additions & 17 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3053,9 +3053,11 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_filter(JNIEnv *env, jclas
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(
JNIEnv *env, jclass, jlong input_jtable, jintArray key_columns, jboolean keep_first,
jboolean nulls_equal, jboolean nulls_before) {
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(JNIEnv *env, jclass,
jlong input_jtable,
jintArray key_columns,
jint keep,
jboolean nulls_equal) {
JNI_NULL_CHECK(env, input_jtable, "input table is null", 0);
JNI_NULL_CHECK(env, key_columns, "input key_columns is null", 0);
try {
Expand All @@ -3066,22 +3068,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(
auto const native_keys_indices = cudf::jni::native_jintArray(env, key_columns);
auto const keys_indices =
std::vector<cudf::size_type>(native_keys_indices.begin(), native_keys_indices.end());

// cudf::unique keeps unique rows in each consecutive group of equivalent rows. To match the
// behavior of pandas.DataFrame.drop_duplicates, users need to stable sort the input first and
// then invoke cudf::unique.
std::vector<cudf::order> order(keys_indices.size(), cudf::order::ASCENDING);
std::vector<cudf::null_order> null_precedence(
keys_indices.size(), nulls_before ? cudf::null_order::BEFORE : cudf::null_order::AFTER);
auto const sorted_input =
cudf::stable_sort_by_key(*input, input->select(keys_indices), order, null_precedence);
auto const keep_option = [&] {
switch (keep) {
case 0: return cudf::duplicate_keep_option::KEEP_ANY;
case 1: return cudf::duplicate_keep_option::KEEP_FIRST;
case 2: return cudf::duplicate_keep_option::KEEP_LAST;
case 3: return cudf::duplicate_keep_option::KEEP_NONE;
default:
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid `keep` option",
cudf::duplicate_keep_option::KEEP_ANY);
}
}();

auto result =
cudf::unique(sorted_input->view(), keys_indices,
keep_first ? cudf::duplicate_keep_option::KEEP_FIRST :
cudf::duplicate_keep_option::KEEP_LAST,
nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL,
rmm::mr::get_current_device_resource());
cudf::distinct(*input, keys_indices, keep_option,
nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL,
cudf::nan_equality::ALL_EQUAL, rmm::mr::get_current_device_resource());
return convert_table_for_return(env, result);
}
CATCH_STD(env, 0);
Expand Down
16 changes: 9 additions & 7 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5347,11 +5347,11 @@ void testWindowingNthElement() {
ColumnVector expect_first = ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, X, X, X, 0, 4);
ColumnVector expect_last = ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, X, 4, 4, 0, 4, X, X);
ColumnVector expect_1th = ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, X, 4, 0, 0, 4, X);
ColumnVector expect_first_skip_null =
ColumnVector expect_first_skip_null =
ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, 4, 0, 0, 0, 4);
ColumnVector expect_last_skip_null =
ColumnVector expect_last_skip_null =
ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, 9, 4, 4, 0, 4, 4, 4);
ColumnVector expect_1th_skip_null =
ColumnVector expect_1th_skip_null =
ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, 4, X, X, 4, 4, X)) {
assertColumnsAreEqual(expect_first, windowAggResults.getColumn(0));
assertColumnsAreEqual(expect_last, windowAggResults.getColumn(1));
Expand Down Expand Up @@ -7510,19 +7510,21 @@ void testDropDuplicates() {
Table input = new Table(col1, col2)) {

// Keep the first duplicate element.
try (Table result = input.dropDuplicates(keyColumns, true, true, true);
try (Table result = input.dropDuplicates(keyColumns, Table.DuplicateKeepOption.KEEP_FIRST, true);
Table resultSorted = result.orderBy(OrderByArg.asc(1, true));
ColumnVector expectedCol1 = ColumnVector.fromBoxedInts(null, 5, 5, 8);
ColumnVector expectedCol2 = ColumnVector.fromBoxedInts(null, 19, 20, 21);
Table expected = new Table(expectedCol1, expectedCol2)) {
assertTablesAreEqual(expected, result);
assertTablesAreEqual(expected, resultSorted);
}

// Keep the last duplicate element.
try (Table result = input.dropDuplicates(keyColumns, false, true, true);
try (Table result = input.dropDuplicates(keyColumns, Table.DuplicateKeepOption.KEEP_LAST, true);
Table resultSorted = result.orderBy(OrderByArg.asc(1, true));
ColumnVector expectedCol1 = ColumnVector.fromBoxedInts(3, 1, 5, 8);
ColumnVector expectedCol2 = ColumnVector.fromBoxedInts(null, 19, 20, 21);
Table expected = new Table(expectedCol1, expectedCol2)) {
assertTablesAreEqual(expected, result);
assertTablesAreEqual(expected, resultSorted);
}
}
}
Expand Down