From 11e021d6aad013f9c43b8697dccf10d628f14cb6 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 24 Jun 2021 13:48:09 -0500 Subject: [PATCH] Add in JNI APIs for scan, replace_nulls, group_by.scan, and group_by.replace_nulls (#8503) To be able to do a running window test prototype I added in APIs for `scan`, `group_by.scan`, and `group_by.replace_nulls`. I also added a version of `replace_nulls` that java was missing. It is still not decided exactly how we are going to support running windows, but I thought I should get these in in case we do want to use them. Authors: - Robert (Bobby) Evans (https://github.com/revans2) Approvers: - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/8503 --- .../main/java/ai/rapids/cudf/ColumnView.java | 56 ++++++-- .../main/java/ai/rapids/cudf/NaNEquality.java | 12 +- .../java/ai/rapids/cudf/NullEquality.java | 4 +- .../main/java/ai/rapids/cudf/NullPolicy.java | 4 +- .../java/ai/rapids/cudf/ReplacePolicy.java | 46 +++++++ .../rapids/cudf/ReplacePolicyWithColumn.java | 46 +++++++ .../main/java/ai/rapids/cudf/ScanType.java | 39 ++++++ java/src/main/java/ai/rapids/cudf/Table.java | 110 +++++++++++++-- java/src/main/native/CMakeLists.txt | 1 - java/src/main/native/src/ColumnViewJni.cpp | 51 +++++-- java/src/main/native/src/TableJni.cpp | 130 ++++++++++++++++++ java/src/main/native/src/prefix_sum.cu | 48 ------- java/src/main/native/src/prefix_sum.hpp | 36 ----- .../java/ai/rapids/cudf/ColumnVectorTest.java | 127 +++++++++++++++-- .../test/java/ai/rapids/cudf/TableTest.java | 54 ++++++++ 15 files changed, 629 insertions(+), 135 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/ReplacePolicy.java create mode 100644 java/src/main/java/ai/rapids/cudf/ReplacePolicyWithColumn.java create mode 100644 java/src/main/java/ai/rapids/cudf/ScanType.java delete mode 100644 java/src/main/native/src/prefix_sum.cu delete mode 100644 java/src/main/native/src/prefix_sum.hpp diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 23b0bd1560b..7912a525597 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -412,6 +412,10 @@ public final ColumnVector replaceNulls(ColumnView replacements) { return new ColumnVector(replaceNullsColumn(getNativeView(), replacements.getNativeView())); } + public final ColumnVector replaceNulls(ReplacePolicy policy) { + return new ColumnVector(replaceNullsPolicy(getNativeView(), policy.isPreceding)); + } + /** * For a BOOL8 vector, computes a vector whose rows are selected from two other vectors * based on the boolean value of this vector in the corresponding row. @@ -1384,17 +1388,50 @@ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions opt } /** - * Compute the cumulative sum/prefix sum of the values in this column. - * This is similar to a rolling window SUM with unbounded preceding and none following. - * Input values 1, 2, 3 - * Output values 1, 3, 6 - * This currently only works for long values that are not nullable as this is currently a - * very simple implementation. It may be expanded in the future if needed. + * Compute the prefix sum (aka cumulative sum) of the values in this column. + * This is just a convenience method for an inclusive scan with a SUM aggregation. */ public final ColumnVector prefixSum() { - return new ColumnVector(prefixSum(getNativeView())); + return scan(Aggregation.sum()); + } + + /** + * Computes a scan for a column. This is very similar to a running window on the column. + * @param aggregation the aggregation to perform + * @param scanType should the scan be inclusive, include the current row, or exclusive. + * @param nullPolicy how should nulls be treated. Note that some aggregations also include a + * null policy too. Currently none of those aggregations are supported so + * it is undefined how they would interact with each other. + */ + public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullPolicy nullPolicy) { + long nativeId = aggregation.createNativeInstance(); + try { + return new ColumnVector(scan(getNativeView(), nativeId, + scanType.isInclusive, nullPolicy.includeNulls)); + } finally { + Aggregation.close(nativeId); + } } + /** + * Computes a scan for a column that excludes nulls. + * @param aggregation the aggregation to perform + * @param scanType should the scan be inclusive, include the current row, or exclusive. + */ + public final ColumnVector scan(Aggregation aggregation, ScanType scanType) { + return scan(aggregation, scanType, NullPolicy.EXCLUDE); + } + + /** + * Computes an inclusive scan for a column that excludes nulls. + * @param aggregation the aggregation to perform + */ + public final ColumnVector scan(Aggregation aggregation) { + return scan(aggregation, ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + } + + + ///////////////////////////////////////////////////////////////////////////// // LOGICAL ///////////////////////////////////////////////////////////////////////////// @@ -3217,7 +3254,8 @@ private static native long rollingWindow( long preceding_col, long following_col); - private static native long prefixSum(long viewHandle) throws CudfException; + private static native long scan(long viewHandle, long aggregation, + boolean isInclusive, boolean includeNulls) throws CudfException; private static native long nansToNulls(long viewHandle) throws CudfException; @@ -3227,6 +3265,8 @@ private static native long rollingWindow( private static native long replaceNullsColumn(long viewHandle, long replaceViewHandle) throws CudfException; + private static native long replaceNullsPolicy(long nativeView, boolean isPreceding) throws CudfException; + private static native long ifElseVV(long predVec, long trueVec, long falseVec) throws CudfException; private static native long ifElseVS(long predVec, long trueVec, long falseScalar) throws CudfException; diff --git a/java/src/main/java/ai/rapids/cudf/NaNEquality.java b/java/src/main/java/ai/rapids/cudf/NaNEquality.java index b135bc63007..11f34cd1c18 100644 --- a/java/src/main/java/ai/rapids/cudf/NaNEquality.java +++ b/java/src/main/java/ai/rapids/cudf/NaNEquality.java @@ -18,11 +18,19 @@ package ai.rapids.cudf; -/* - * This is analogous to the native 'nan_equality'. +/** + * How should NaNs be compared in an operation. In floating point there are multiple + * different binary representations for NaN. */ public enum NaNEquality { + /** + * No NaN representation is considered equal to any NaN representation, even for the + * exact same representation. + */ UNEQUAL(false), + /** + * All representations of NaN are considered to be equal. + */ ALL_EQUAL(true); NaNEquality(boolean nansEqual) { diff --git a/java/src/main/java/ai/rapids/cudf/NullEquality.java b/java/src/main/java/ai/rapids/cudf/NullEquality.java index d1e97f2cd32..657600d570b 100644 --- a/java/src/main/java/ai/rapids/cudf/NullEquality.java +++ b/java/src/main/java/ai/rapids/cudf/NullEquality.java @@ -18,8 +18,8 @@ package ai.rapids.cudf; -/* - * This is analogous to the native 'null_equality'. +/** + * How should nulls be compared in an operation. */ public enum NullEquality { UNEQUAL(false), diff --git a/java/src/main/java/ai/rapids/cudf/NullPolicy.java b/java/src/main/java/ai/rapids/cudf/NullPolicy.java index 469fbbdddac..225eb4a142f 100644 --- a/java/src/main/java/ai/rapids/cudf/NullPolicy.java +++ b/java/src/main/java/ai/rapids/cudf/NullPolicy.java @@ -18,8 +18,8 @@ package ai.rapids.cudf; -/* - * This is analogous to the native 'null_policy'. +/** + * Specify whether to include nulls or exclude nulls in an operation. */ public enum NullPolicy { EXCLUDE(false), diff --git a/java/src/main/java/ai/rapids/cudf/ReplacePolicy.java b/java/src/main/java/ai/rapids/cudf/ReplacePolicy.java new file mode 100644 index 00000000000..2f9cba88a5a --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ReplacePolicy.java @@ -0,0 +1,46 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Policy to specify the position of replacement values relative to null rows. + */ +public enum ReplacePolicy { + /** + * The replacement value is the first non-null value preceding the null row. + */ + PRECEDING(true), + /** + * The replacement value is the first non-null value following the null row. + */ + FOLLOWING(false); + + ReplacePolicy(boolean isPreceding) { + this.isPreceding = isPreceding; + } + + final boolean isPreceding; + + /** + * Indicate which column the replacement should happen on. + */ + public ReplacePolicyWithColumn onColumn(int columnNumber) { + return new ReplacePolicyWithColumn(columnNumber, this); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ReplacePolicyWithColumn.java b/java/src/main/java/ai/rapids/cudf/ReplacePolicyWithColumn.java new file mode 100644 index 00000000000..5702f623ee1 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ReplacePolicyWithColumn.java @@ -0,0 +1,46 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * A replacement policy for a specific column + */ +public class ReplacePolicyWithColumn { + final int column; + final ReplacePolicy policy; + + ReplacePolicyWithColumn(int column, ReplacePolicy policy) { + this.column = column; + this.policy = policy; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ReplacePolicyWithColumn)) { + return false; + } + ReplacePolicyWithColumn ro = (ReplacePolicyWithColumn)other; + return this.column == ro.column && this.policy.equals(ro.policy); + } + + @Override + public int hashCode() { + return 31 * column + policy.hashCode(); + } +} \ No newline at end of file diff --git a/java/src/main/java/ai/rapids/cudf/ScanType.java b/java/src/main/java/ai/rapids/cudf/ScanType.java new file mode 100644 index 00000000000..1fb3ff7e52b --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ScanType.java @@ -0,0 +1,39 @@ +/* + * + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Scan operation type. + */ +public enum ScanType { + /** + * Include the current row in the scan. + */ + INCLUSIVE(true), + /** + * Exclude the current row from the scan. + */ + EXCLUSIVE(false); + + ScanType(boolean isInclusive) { + this.isInclusive = isInclusive; + } + + final boolean isInclusive; +} diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index c0515521cc5..ea261410585 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -28,14 +28,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.TreeMap; +import java.util.*; /** * Class to represent a collection of ColumnVectors and operations that can be performed on them @@ -464,6 +457,17 @@ private static native long[] groupByAggregate(long inputTable, int[] keyIndices, boolean keySorted, boolean[] keysDescending, boolean[] keysNullSmallest) throws CudfException; + private static native long[] groupByScan(long inputTable, int[] keyIndices, int[] aggColumnsIndices, + long[] aggInstances, boolean ignoreNullKeys, + boolean keySorted, boolean[] keysDescending, + boolean[] keysNullSmallest) throws CudfException; + + private static native long[] groupByReplaceNulls(long inputTable, int[] keyIndices, + int[] replaceColumnsIndices, + boolean[] isPreceding, boolean ignoreNullKeys, + boolean keySorted, boolean[] keysDescending, + boolean[] keysNullSmallest) throws CudfException; + private static native long[] rollingWindowAggregate( long inputTable, int[] keyIndices, @@ -2663,6 +2667,96 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate } } + public Table scan(AggregationOnColumn... aggregates) { + assert aggregates != null; + + // To improve performance and memory we want to remove duplicate operations + // and also group the operations by column so hopefully cudf can do multiple aggregations + // in a single pass. + + // Use a tree map to make debugging simpler (columns are all in the same order) + TreeMap groupedOps = new TreeMap<>(); + // Total number of operations that will need to be done. + int keysLength = operation.indices.length; + int totalOps = 0; + for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) { + AggregationOnColumn agg = aggregates[outputIndex]; + ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps()); + totalOps += ops.add(agg, outputIndex + keysLength); + } + int[] aggColumnIndexes = new int[totalOps]; + long[] aggOperationInstances = new long[totalOps]; + try { + int opIndex = 0; + for (Map.Entry entry: groupedOps.entrySet()) { + int columnIndex = entry.getKey(); + for (Aggregation operation: entry.getValue().operations()) { + aggColumnIndexes[opIndex] = columnIndex; + aggOperationInstances[opIndex] = operation.createNativeInstance(); + opIndex++; + } + } + assert opIndex == totalOps : opIndex + " == " + totalOps; + + try (Table aggregate = new Table(groupByScan( + operation.table.nativeHandle, + operation.indices, + aggColumnIndexes, + aggOperationInstances, + groupByOptions.getIgnoreNullKeys(), + groupByOptions.getKeySorted(), + groupByOptions.getKeysDescending(), + groupByOptions.getKeysNullSmallest()))) { + // prepare the final table + ColumnVector[] finalCols = new ColumnVector[keysLength + aggregates.length]; + + // get the key columns + for (int aggIndex = 0; aggIndex < keysLength; aggIndex++) { + finalCols[aggIndex] = aggregate.getColumn(aggIndex); + } + + int inputColumn = keysLength; + // Now get the aggregation columns + for (ColumnOps ops: groupedOps.values()) { + for (List indices: ops.outputIndices()) { + for (int outIndex: indices) { + finalCols[outIndex] = aggregate.getColumn(inputColumn); + } + inputColumn++; + } + } + return new Table(finalCols); + } + } finally { + Aggregation.close(aggOperationInstances); + } + } + + public Table replaceNulls(ReplacePolicyWithColumn... replacements) { + assert replacements != null; + + // TODO in the future perhaps to improve performance and memory we want to + // remove duplicate operations. + + boolean[] isPreceding = new boolean[replacements.length]; + int [] columnIndexes = new int[replacements.length]; + + for (int index = 0; index < replacements.length; index++) { + isPreceding[index] = replacements[index].policy.isPreceding; + columnIndexes[index] = replacements[index].column; + } + + return new Table(groupByReplaceNulls( + operation.table.nativeHandle, + operation.indices, + columnIndexes, + isPreceding, + groupByOptions.getIgnoreNullKeys(), + groupByOptions.getKeySorted(), + groupByOptions.getKeysDescending(), + groupByOptions.getKeysNullSmallest())); + } + /** * Splits the groups in a single table into separate tables according to the grouping keys. * Each split table represents a single group. diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index 6c891511a1a..84b44b546a3 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -270,7 +270,6 @@ set(SOURCE_FILES "src/RmmJni.cpp" "src/ScalarJni.cpp" "src/TableJni.cpp" - "src/prefix_sum.cu" "src/map_lookup.cu") add_library(cudfjni SHARED ${SOURCE_FILES}) diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index bb8cc09851d..44ac3a91c77 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -66,7 +66,6 @@ #include #include "cudf/types.hpp" -#include "prefix_sum.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" #include "jni.h" @@ -157,6 +156,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsColumn(JNIEnv CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsPolicy(JNIEnv *env, jclass, + jlong j_col, + jboolean is_preceding) { + JNI_NULL_CHECK(env, j_col, "column is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view col = *reinterpret_cast(j_col); + std::unique_ptr result = cudf::replace_nulls(col, + is_preceding ? cudf::replace_policy::PRECEDING : cudf::replace_policy::FOLLOWING); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_ifElseVV(JNIEnv *env, jclass, jlong j_pred_vec, jlong j_true_vec, @@ -262,6 +275,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reduce(JNIEnv *env, jclas CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_scan(JNIEnv *env, jclass, + jlong j_col_view, + jlong j_agg, + jboolean is_inclusive, + jboolean include_nulls) { + JNI_NULL_CHECK(env, j_col_view, "column view is null", 0); + JNI_NULL_CHECK(env, j_agg, "aggregation is null", 0); + try { + cudf::jni::auto_set_device(env); + auto col = reinterpret_cast(j_col_view); + auto agg = reinterpret_cast(j_agg); + + std::unique_ptr result = cudf::scan(*col, agg->clone(), + is_inclusive ? cudf::scan_type::INCLUSIVE : cudf::scan_type::EXCLUSIVE, + include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + + + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_quantile(JNIEnv *env, jclass clazz, jlong input_column, jint quantile_method, @@ -1779,20 +1814,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_makeStructView(JNIEnv *en CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_prefixSum(JNIEnv *env, jobject j_object, - jlong handle) { - - JNI_NULL_CHECK(env, handle, "native view handle is null", 0) - - try { - cudf::jni::auto_set_device(env); - cudf::column_view *view = reinterpret_cast(handle); - std::unique_ptr result = cudf::jni::prefix_sum(*view); - return reinterpret_cast(result.release()); - } - CATCH_STD(env, 0) -} - JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_nansToNulls(JNIEnv *env, jobject j_object, jlong handle) { diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 4b01745382b..018dd211139 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -614,6 +614,13 @@ jlongArray convert_table_for_return(JNIEnv *env, std::unique_ptr &t return convert_table_for_return(env, table_result, extra); } +jlongArray convert_table_for_return(JNIEnv *env, + std::unique_ptr &first_table, + std::unique_ptr &second_table) { + std::vector> second_tmp = second_table->release(); + return convert_table_for_return(env, first_table, second_tmp); +} + // Convert the JNI boolean array of key column sort order to a vector of cudf::order // for groupby. std::vector resolve_column_order(JNIEnv *env, jbooleanArray jkeys_sort_desc, @@ -2101,6 +2108,129 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_groupByAggregate( CATCH_STD(env, NULL); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_groupByScan( + JNIEnv *env, jclass, jlong input_table, jintArray keys, + jintArray aggregate_column_indices, jlongArray agg_instances, jboolean ignore_null_keys, + jboolean jkey_sorted, jbooleanArray jkeys_sort_desc, jbooleanArray jkeys_null_first) { + JNI_NULL_CHECK(env, input_table, "input table is null", NULL); + JNI_NULL_CHECK(env, keys, "input keys are null", NULL); + JNI_NULL_CHECK(env, aggregate_column_indices, "input aggregate_column_indices are null", NULL); + JNI_NULL_CHECK(env, agg_instances, "agg_instances are null", NULL); + + try { + cudf::jni::auto_set_device(env); + cudf::table_view *n_input_table = reinterpret_cast(input_table); + cudf::jni::native_jintArray n_keys(env, keys); + cudf::jni::native_jintArray n_values(env, aggregate_column_indices); + cudf::jni::native_jpointerArray n_agg_instances(env, agg_instances); + std::vector n_keys_cols; + n_keys_cols.reserve(n_keys.size()); + for (int i = 0; i < n_keys.size(); i++) { + n_keys_cols.push_back(n_input_table->column(n_keys[i])); + } + + cudf::table_view n_keys_table(n_keys_cols); + auto column_order = cudf::jni::resolve_column_order(env, jkeys_sort_desc, + n_keys.size()); + auto null_precedence = cudf::jni::resolve_null_precedence(env, jkeys_null_first, + n_keys.size()); + cudf::groupby::groupby grouper(n_keys_table, + ignore_null_keys ? cudf::null_policy::EXCLUDE + : cudf::null_policy::INCLUDE, + jkey_sorted ? cudf::sorted::YES : cudf::sorted::NO, + column_order, + null_precedence); + + // Aggregates are passed in already grouped by column, so we just need to fill it in + // as we go. + std::vector requests; + + int previous_index = -1; + for (int i = 0; i < n_values.size(); i++) { + cudf::groupby::aggregation_request req; + int col_index = n_values[i]; + if (col_index == previous_index) { + requests.back().aggregations.push_back(n_agg_instances[i]->clone()); + } else { + req.values = n_input_table->column(col_index); + req.aggregations.push_back(n_agg_instances[i]->clone()); + requests.push_back(std::move(req)); + } + previous_index = col_index; + } + + std::pair, std::vector> result = + grouper.scan(requests); + + std::vector> result_columns; + int agg_result_size = result.second.size(); + for (int agg_result_index = 0; agg_result_index < agg_result_size; agg_result_index++) { + int col_agg_size = result.second[agg_result_index].results.size(); + for (int col_agg_index = 0; col_agg_index < col_agg_size; col_agg_index++) { + result_columns.push_back(std::move(result.second[agg_result_index].results[col_agg_index])); + } + } + return cudf::jni::convert_table_for_return(env, result.first, result_columns); + } + CATCH_STD(env, NULL); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_groupByReplaceNulls( + JNIEnv *env, jclass, jlong input_table, jintArray keys, + jintArray replace_column_indices, jbooleanArray is_preceding, jboolean ignore_null_keys, + jboolean jkey_sorted, jbooleanArray jkeys_sort_desc, jbooleanArray jkeys_null_first) { + JNI_NULL_CHECK(env, input_table, "input table is null", NULL); + JNI_NULL_CHECK(env, keys, "input keys are null", NULL); + JNI_NULL_CHECK(env, replace_column_indices, "input replace_column_indices are null", NULL); + JNI_NULL_CHECK(env, is_preceding, "is_preceding are null", NULL); + + try { + cudf::jni::auto_set_device(env); + cudf::table_view *n_input_table = reinterpret_cast(input_table); + cudf::jni::native_jintArray n_keys(env, keys); + cudf::jni::native_jintArray n_values(env, replace_column_indices); + cudf::jni::native_jbooleanArray n_is_preceding(env, is_preceding); + std::vector n_keys_cols; + n_keys_cols.reserve(n_keys.size()); + for (int i = 0; i < n_keys.size(); i++) { + n_keys_cols.push_back(n_input_table->column(n_keys[i])); + } + + cudf::table_view n_keys_table(n_keys_cols); + auto column_order = cudf::jni::resolve_column_order(env, jkeys_sort_desc, + n_keys.size()); + auto null_precedence = cudf::jni::resolve_null_precedence(env, jkeys_null_first, + n_keys.size()); + cudf::groupby::groupby grouper(n_keys_table, + ignore_null_keys ? cudf::null_policy::EXCLUDE + : cudf::null_policy::INCLUDE, + jkey_sorted ? cudf::sorted::YES : cudf::sorted::NO, + column_order, + null_precedence); + + // Aggregates are passed in already grouped by column, so we just need to fill it in + // as we go. + std::vector n_replace_cols; + n_replace_cols.reserve(n_values.size()); + for (int i = 0; i < n_values.size(); i++) { + n_replace_cols.push_back(n_input_table->column(n_values[i])); + } + cudf::table_view n_replace_table(n_replace_cols); + + std::vector policies; + policies.reserve(n_is_preceding.size()); + for (int i = 0; i < n_is_preceding.size(); i++) { + policies.push_back(n_is_preceding[i] ? cudf::replace_policy::PRECEDING : cudf::replace_policy::FOLLOWING); + } + + std::pair, std::unique_ptr> result = + grouper.replace_nulls(n_replace_table, policies); + + return cudf::jni::convert_table_for_return(env, result.first, result.second); + } + CATCH_STD(env, NULL); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_filter(JNIEnv *env, jclass, jlong input_jtable, jlong mask_jcol) { JNI_NULL_CHECK(env, input_jtable, "input table is null", 0); diff --git a/java/src/main/native/src/prefix_sum.cu b/java/src/main/native/src/prefix_sum.cu deleted file mode 100644 index e3c53696185..00000000000 --- a/java/src/main/native/src/prefix_sum.cu +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -#include -#include -#include - - -namespace cudf { -namespace jni { - -std::unique_ptr prefix_sum(column_view const &value_column, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource *mr) { - // Defensive checks. - CUDF_EXPECTS(value_column.type().id() == type_id::INT64, "Only longs are supported."); - CUDF_EXPECTS(!value_column.has_nulls(), "NULLS are not supported"); - - auto result = make_numeric_column(value_column.type(), value_column.size(), - mask_state::ALL_VALID, stream, mr); - - thrust::inclusive_scan(rmm::exec_policy(stream), - value_column.begin(), - value_column.end(), - result->mutable_view().begin()); - - return result; -} -} // namespace jni -} // namespace cudf diff --git a/java/src/main/native/src/prefix_sum.hpp b/java/src/main/native/src/prefix_sum.hpp deleted file mode 100644 index 8f39f9a8c69..00000000000 --- a/java/src/main/native/src/prefix_sum.hpp +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace cudf { - -namespace jni { - -/** - * @brief compute the prefix sum of a column of longs - */ -std::unique_ptr -prefix_sum(column_view const &value_column, - rmm::cuda_stream_view stream = rmm::cuda_stream_default, - rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()); - -} // namespace jni - -} // namespace cudf diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index e3ca880d587..a121309d8aa 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1775,6 +1775,18 @@ void testReplaceNullsWithNullScalar() { } } + @Test + void testReplaceNullsPolicy() { + try (ColumnVector input = ColumnVector.fromBoxedInts(null, 1, 2, null, 4, null); + ColumnVector preceding = input.replaceNulls(ReplacePolicy.PRECEDING); + ColumnVector expectedPre = ColumnVector.fromBoxedInts(null, 1, 2, 2, 4, 4); + ColumnVector following = input.replaceNulls(ReplacePolicy.FOLLOWING); + ColumnVector expectedFol = ColumnVector.fromBoxedInts(1, 1, 2, 4, 4, null)) { + assertColumnsAreEqual(expectedPre, preceding); + assertColumnsAreEqual(expectedFol, following); + } + } + @Test void testReplaceNullsColumnEmptyColumn() { try (ColumnVector input = ColumnVector.fromBoxedBooleans(); @@ -2807,21 +2819,110 @@ void testPrefixSum() { } @Test - void testPrefixSumErrors() { - try (ColumnVector v1 = ColumnVector.fromBoxedLongs(1L, 2L, 3L, 5L, 8L, null)) { - assertThrows(CudfException.class, () -> { - try(ColumnVector ignored = v1.prefixSum()) { - // empty - } - }); + void testScanSum() { + try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { + // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE + // tests have been disabled +// try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, null, null, null, null)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 3, null, 6, 11, 19, 29)) { + assertColumnsAreEqual(expected, result); + } + +// try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, 3, 3, 6, 11, 19)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.sum(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(0, 1, null, 3, 6, 11, 19)) { + assertColumnsAreEqual(expected, result); + } } + } - try (ColumnVector v1 = ColumnVector.fromInts(1, 2, 3, 5, 8, 10)) { - assertThrows(CudfException.class, () -> { - try(ColumnVector ignored = v1.prefixSum()) { - // empty - } - }); + @Test + void testScanMax() { + // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE + // tests have been disabled + try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { +// try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { + assertColumnsAreEqual(expected, result); + } + +// try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, 2, 2, 3, 5, 8)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.max(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MIN_VALUE, 1, null, 2, 3, 5, 8)) { + assertColumnsAreEqual(expected, result); + } + } + } + + @Test + void testScanMin() { + // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE + // tests have been disabled + try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { +// try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, null, null, null, null)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, 1, 1, 1, 1)) { + assertColumnsAreEqual(expected, result); + } + +// try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, 1, 1, 1, 1, 1)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.min(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(Integer.MAX_VALUE, 1, null, 1, 1, 1, 1)) { + assertColumnsAreEqual(expected, result); + } + } + } + + @Test + void testScanProduct() { + // Due to https://github.com/rapidsai/cudf/issues/8462 NullPolicy.INCLUDE + // tests have been disabled + try (ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, null, 3, 5, 8, 10)) { +// try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.INCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, null, null, null, null)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.INCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 6, 30, 240, 2400)) { + assertColumnsAreEqual(expected, result); + } + +// try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.EXCLUSIVE, NullPolicy.INCLUDE); +// ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, 2, 2, 6, 30, 240)) { +// assertColumnsAreEqual(expected, result); +// } + + try (ColumnVector result = v1.scan(Aggregation.product(), ScanType.EXCLUSIVE, NullPolicy.EXCLUDE); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 1, null, 2, 6, 30, 240)) { + assertColumnsAreEqual(expected, result); + } } } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 9fc5a39d0ce..38fffbf5adc 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2658,6 +2658,60 @@ void testValidityFill() { assertEquals(buff[1], 0xFFFFFFFF); } + @Test + void testGroupByScan() { + try (Table t1 = new Table.TestBuilder() + .column( "1", "1", "1", "1", "1", "1", "1", "2", "2", "2", "2") + .column( 0, 1, 3, 3, 5, 5, 5, 5, 5, 5, 5) + .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0, null, null, 11.0, null, 10.0) + .build()) { + try (Table result = t1 + .groupBy(GroupByOptions.builder() + .withKeysSorted(true) + .withKeysDescending(false, false) + .build(), 0, 1) + .scan(Aggregation.sum().onColumn(2), + Aggregation.count(NullPolicy.INCLUDE).onColumn(2), + Aggregation.min().onColumn(2), + Aggregation.max().onColumn(2)); + Table expected = new Table.TestBuilder() + .column( "1", "1", "1", "1", "1", "1", "1", "2", "2", "2", "2") + .column( 0, 1, 3, 3, 5, 5, 5, 5, 5, 5, 5) + .column(12.0, 14.0, 13.0, 30.0, 17.0, 34.0, null, null, 11.0, null, 21.0) + .column( 0, 0, 0, 1, 0, 1, 2, 0, 1, 2, 3) // odd why is this not 1 based? + .column(12.0, 14.0, 13.0, 13.0, 17.0, 17.0, null, null, 11.0, null, 10.0) + .column(12.0, 14.0, 13.0, 17.0, 17.0, 17.0, null, null, 11.0, null, 11.0) + .build()) { + assertTablesAreEqual(expected, result); + } + } + } + + @Test + void testGroupByReplaceNulls() { + try (Table t1 = new Table.TestBuilder() + .column( "1", "1", "1", "1", "1", "1", "1", "2", "2", "2", "2") + .column( 0, 1, 3, 3, 5, 5, 5, 5, 5, 5, 5) + .column(null, 14.0, 13.0, 17.0, 17.0, 17.0, null, null, 11.0, null, null) + .build()) { + try (Table result = t1 + .groupBy(GroupByOptions.builder() + .withKeysSorted(true) + .withKeysDescending(false, false) + .build(), 0, 1) + .replaceNulls(ReplacePolicy.PRECEDING.onColumn(2), + ReplacePolicy.FOLLOWING.onColumn(2)); + Table expected = new Table.TestBuilder() + .column( "1", "1", "1", "1", "1", "1", "1", "2", "2", "2", "2") + .column( 0, 1, 3, 3, 5, 5, 5, 5, 5, 5, 5) + .column(null, 14.0, 13.0, 17.0, 17.0, 17.0, 17.0, null, 11.0, 11.0, 11.0) + .column(null, 14.0, 13.0, 17.0, 17.0, 17.0, null, 11.0, 11.0, null, null) + .build()) { + assertTablesAreEqual(expected, result); + } + } + } + @Test void testGroupByUniqueCount() { try (Table t1 = new Table.TestBuilder()