From 181b946d8446fd1cfecb35ef9b8d8c5ace8b3098 Mon Sep 17 00:00:00 2001 From: MithunR Date: Mon, 24 Apr 2023 09:40:31 -0700 Subject: [PATCH] JNI changes for range-extents in window functions. (#13199) This commit adds back-end JNI changes to support explicit range-extents for ranged window functions. The change is analogous to the addition of `cudf::range_window_bounds::extent_type`, in `libcudf`. This change will allow for `STRING` order-by columns for range window functions. It is required because without it, it would be impossible to differentiate between `UNBOUNDED` and `CURRENT ROW` for a window function over a `STRING` order-by column. Authors: - MithunR (https://github.com/mythrocks) Approvers: - Nghia Truong (https://github.com/ttnghia) - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/13199 --- java/src/main/java/ai/rapids/cudf/Table.java | 24 +++--- .../java/ai/rapids/cudf/WindowOptions.java | 81 ++++++++++++++----- java/src/main/native/src/TableJni.cpp | 49 +++++++---- .../test/java/ai/rapids/cudf/TableTest.java | 52 ++++++++++++ 4 files changed, 157 insertions(+), 49 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 3ccab70ccda..9abc2dbcd7c 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -547,7 +547,7 @@ private static native long[] rollingWindowAggregate( private static native long[] rangeRollingWindowAggregate(long inputTable, int[] keyIndices, int[] orderByIndices, boolean[] isOrderByAscending, int[] aggColumnsIndices, long[] aggInstances, int[] minPeriods, - long[] preceding, long[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing, + long[] preceding, long[] following, int[] precedingRangeExtent, int[] followingRangeExtent, boolean ignoreNullKeys) throws CudfException; private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending, @@ -3981,10 +3981,11 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate case DECIMAL32: case DECIMAL64: case DECIMAL128: + case STRING: break; default: throw new IllegalArgumentException("Expected range-based window orderBy's " + - "type: integral (Boolean-exclusive), decimal, and timestamp"); + "type: integral (Boolean-exclusive), decimal, timestamp, and string"); } ColumnWindowOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnWindowOps()); @@ -3998,8 +3999,8 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate long[] aggPrecedingWindows = new long[totalOps]; long[] aggFollowingWindows = new long[totalOps]; try { - boolean[] aggPrecedingWindowsUnbounded = new boolean[totalOps]; - boolean[] aggFollowingWindowsUnbounded = new boolean[totalOps]; + int[] aggPrecedingWindowsExtent = new int[totalOps]; + int[] aggFollowingWindowsExtent = new int[totalOps]; int[] aggMinPeriods = new int[totalOps]; int opIndex = 0; for (Map.Entry entry: groupedOps.entrySet()) { @@ -4007,18 +4008,19 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate for (AggregationOverWindow op: entry.getValue().operations()) { aggColumnIndexes[opIndex] = columnIndex; aggInstances[opIndex] = op.createNativeInstance(); - Scalar p = op.getWindowOptions().getPrecedingScalar(); - Scalar f = op.getWindowOptions().getFollowingScalar(); - if ((p == null || !p.isValid()) && !op.getWindowOptions().isUnboundedPreceding()) { + WindowOptions windowOptions = op.getWindowOptions(); + Scalar p = windowOptions.getPrecedingScalar(); + Scalar f = windowOptions.getFollowingScalar(); + if ((p == null || !p.isValid()) && !(windowOptions.isUnboundedPreceding() || windowOptions.isCurrentRowPreceding())) { throw new IllegalArgumentException("Some kind of preceding must be set and a preceding column is not currently supported"); } - if ((f == null || !f.isValid()) && !op.getWindowOptions().isUnboundedFollowing()) { + if ((f == null || !f.isValid()) && !(windowOptions.isUnboundedFollowing() || windowOptions.isCurrentRowFollowing())) { throw new IllegalArgumentException("some kind of following must be set and a follow column is not currently supported"); } aggPrecedingWindows[opIndex] = p == null ? 0 : p.getScalarHandle(); aggFollowingWindows[opIndex] = f == null ? 0 : f.getScalarHandle(); - aggPrecedingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedPreceding(); - aggFollowingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedFollowing(); + aggPrecedingWindowsExtent[opIndex] = windowOptions.getPrecedingBoundsExtent().nominalValue; + aggFollowingWindowsExtent[opIndex] = windowOptions.getFollowingBoundsExtent().nominalValue; aggMinPeriods[opIndex] = op.getWindowOptions().getMinPeriods(); assert (op.getWindowOptions().getFrameType() == WindowOptions.FrameType.RANGE); orderByColumnIndexes[opIndex] = op.getWindowOptions().getOrderByColumnIndex(); @@ -4040,7 +4042,7 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate isOrderByOrderAscending, aggColumnIndexes, aggInstances, aggMinPeriods, aggPrecedingWindows, aggFollowingWindows, - aggPrecedingWindowsUnbounded, aggFollowingWindowsUnbounded, + aggPrecedingWindowsExtent, aggFollowingWindowsExtent, groupByOptions.getIgnoreNullKeys()))) { // prepare the final table ColumnVector[] finalCols = new ColumnVector[windowAggregates.length]; diff --git a/java/src/main/java/ai/rapids/cudf/WindowOptions.java b/java/src/main/java/ai/rapids/cudf/WindowOptions.java index 6ab5c0525ca..ef6b3ce70c8 100644 --- a/java/src/main/java/ai/rapids/cudf/WindowOptions.java +++ b/java/src/main/java/ai/rapids/cudf/WindowOptions.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,23 @@ public class WindowOptions implements AutoCloseable { enum FrameType {ROWS, RANGE} + /** + * Extent of (range) window bounds. + * Analogous to cudf::range_window_bounds::extent_type. + */ + enum RangeExtentType { + CURRENT_ROW(0), // Bounds defined as the first/last row that matches the current row. + BOUNDED(1), // Bounds defined as the first/last row that falls within + // a specified range from the current row. + UNBOUNDED(2); // Bounds stretching to the first/last row in the entire group. + + public final int nominalValue; + + RangeExtentType(int n) { + this.nominalValue = n; + } + } + private final int minPeriods; private final Scalar precedingScalar; private final Scalar followingScalar; @@ -33,8 +50,8 @@ enum FrameType {ROWS, RANGE} private final int orderByColumnIndex; private final boolean orderByOrderAscending; private final FrameType frameType; - private final boolean isUnboundedPreceding; - private final boolean isUnboundedFollowing; + private final RangeExtentType precedingBoundsExtent; + private final RangeExtentType followingBoundsExtent; private WindowOptions(Builder builder) { this.minPeriods = builder.minPeriods; @@ -57,9 +74,8 @@ private WindowOptions(Builder builder) { this.orderByColumnIndex = builder.orderByColumnIndex; this.orderByOrderAscending = builder.orderByOrderAscending; this.frameType = orderByColumnIndex == -1? FrameType.ROWS : FrameType.RANGE; - this.isUnboundedPreceding = builder.isUnboundedPreceding; - this.isUnboundedFollowing = builder.isUnboundedFollowing; - + this.precedingBoundsExtent = builder.precedingBoundsExtent; + this.followingBoundsExtent = builder.followingBoundsExtent; } @Override @@ -72,8 +88,8 @@ public boolean equals(Object other) { this.orderByColumnIndex == o.orderByColumnIndex && this.orderByOrderAscending == o.orderByOrderAscending && this.frameType == o.frameType && - this.isUnboundedPreceding == o.isUnboundedPreceding && - this.isUnboundedFollowing == o.isUnboundedFollowing; + this.precedingBoundsExtent == o.precedingBoundsExtent && + this.followingBoundsExtent == o.followingBoundsExtent; if (precedingCol != null) { ret = ret && precedingCol.equals(o.precedingCol); } @@ -110,8 +126,8 @@ public int hashCode() { if (followingScalar != null) { ret = 31 * ret + followingScalar.hashCode(); } - ret = 31 * ret + Boolean.hashCode(isUnboundedPreceding); - ret = 31 * ret + Boolean.hashCode(isUnboundedFollowing); + ret = 31 * ret + precedingBoundsExtent.hashCode(); + ret = 31 * ret + followingBoundsExtent.hashCode(); return ret; } @@ -139,9 +155,16 @@ public static Builder builder(){ boolean isOrderByOrderAscending() { return this.orderByOrderAscending; } - boolean isUnboundedPreceding() { return this.isUnboundedPreceding; } + boolean isUnboundedPreceding() { return this.precedingBoundsExtent == RangeExtentType.UNBOUNDED; } - boolean isUnboundedFollowing() { return this.isUnboundedFollowing; } + boolean isUnboundedFollowing() { return this.followingBoundsExtent == RangeExtentType.UNBOUNDED; } + + boolean isCurrentRowPreceding() { return this.precedingBoundsExtent == RangeExtentType.CURRENT_ROW; } + + boolean isCurrentRowFollowing() { return this.followingBoundsExtent == RangeExtentType.CURRENT_ROW; } + + RangeExtentType getPrecedingBoundsExtent() { return this.precedingBoundsExtent; } + RangeExtentType getFollowingBoundsExtent() { return this.followingBoundsExtent; } FrameType getFrameType() { return frameType; } @@ -154,8 +177,8 @@ public static class Builder { private ColumnVector followingCol = null; private int orderByColumnIndex = -1; private boolean orderByOrderAscending = true; - private boolean isUnboundedPreceding = false; - private boolean isUnboundedFollowing = false; + private RangeExtentType precedingBoundsExtent = RangeExtentType.BOUNDED; + private RangeExtentType followingBoundsExtent = RangeExtentType.BOUNDED; /** * Set the minimum number of observation required to evaluate an element. If there are not @@ -171,7 +194,7 @@ public Builder minPeriods(int minPeriods) { /** * Set the size of the window, one entry per row. This does not take ownership of the - * columns passed in so you have to be sure that the life time of the column outlives + * columns passed in so you have to be sure that the lifetime of the column outlives * this operation. * @param precedingCol the number of rows preceding the current row and * precedingCol will be live outside of WindowOptions. @@ -185,10 +208,10 @@ public Builder window(ColumnVector precedingCol, ColumnVector followingCol) { if (followingCol == null || followingCol.hasNulls()) { throw new IllegalArgumentException("following cannot be null or have nulls"); } - if (isUnboundedPreceding || precedingScalar != null) { + if (precedingBoundsExtent != RangeExtentType.BOUNDED || precedingScalar != null) { throw new IllegalStateException("preceding has already been set a different way"); } - if (isUnboundedFollowing || followingScalar != null) { + if (followingBoundsExtent != RangeExtentType.BOUNDED || followingScalar != null) { throw new IllegalStateException("following has already been set a different way"); } this.precedingCol = precedingCol; @@ -246,11 +269,27 @@ public Builder timestampDescending() { return orderByDescending(); } + public Builder currentRowPreceding() { + if (precedingCol != null || precedingScalar != null) { + throw new IllegalStateException("preceding has already been set a different way"); + } + this.precedingBoundsExtent = RangeExtentType.CURRENT_ROW; + return this; + } + + public Builder currentRowFollowing() { + if (followingCol != null || followingScalar != null) { + throw new IllegalStateException("following has already been set a different way"); + } + this.followingBoundsExtent = RangeExtentType.CURRENT_ROW; + return this; + } + public Builder unboundedPreceding() { if (precedingCol != null || precedingScalar != null) { throw new IllegalStateException("preceding has already been set a different way"); } - this.isUnboundedPreceding = true; + this.precedingBoundsExtent = RangeExtentType.UNBOUNDED; return this; } @@ -258,7 +297,7 @@ public Builder unboundedFollowing() { if (followingCol != null || followingScalar != null) { throw new IllegalStateException("following has already been set a different way"); } - this.isUnboundedFollowing = true; + this.followingBoundsExtent = RangeExtentType.UNBOUNDED; return this; } @@ -270,7 +309,7 @@ public Builder preceding(Scalar preceding) { if (preceding == null || !preceding.isValid()) { throw new IllegalArgumentException("preceding cannot be null"); } - if (isUnboundedPreceding || precedingCol != null) { + if (precedingBoundsExtent != RangeExtentType.BOUNDED || precedingCol != null) { throw new IllegalStateException("preceding has already been set a different way"); } this.precedingScalar = preceding; @@ -285,7 +324,7 @@ public Builder following(Scalar following) { if (following == null || !following.isValid()) { throw new IllegalArgumentException("following cannot be null"); } - if (isUnboundedFollowing || followingCol != null) { + if (followingBoundsExtent != RangeExtentType.BOUNDED || followingCol != null) { throw new IllegalStateException("following has already been set a different way"); } this.followingScalar = following; diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index ddcc8644a9c..22a0e4f5c33 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -3221,8 +3221,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega JNIEnv *env, jclass, jlong j_input_table, jintArray j_keys, jintArray j_orderby_column_indices, jbooleanArray j_is_orderby_ascending, jintArray j_aggregate_column_indices, jlongArray j_agg_instances, jintArray j_min_periods, jlongArray j_preceding, - jlongArray j_following, jbooleanArray j_unbounded_preceding, - jbooleanArray j_unbounded_following, jboolean ignore_null_keys) { + jlongArray j_following, jintArray j_preceding_extent, jintArray j_following_extent, + jboolean ignore_null_keys) { JNI_NULL_CHECK(env, j_input_table, "input table is null", NULL); JNI_NULL_CHECK(env, j_keys, "input keys are null", NULL); @@ -3246,8 +3246,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega cudf::jni::native_jintArray values{env, j_aggregate_column_indices}; cudf::jni::native_jpointerArray agg_instances(env, j_agg_instances); cudf::jni::native_jintArray min_periods{env, j_min_periods}; - cudf::jni::native_jbooleanArray unbounded_preceding{env, j_unbounded_preceding}; - cudf::jni::native_jbooleanArray unbounded_following{env, j_unbounded_following}; + cudf::jni::native_jintArray preceding_extent{env, j_preceding_extent}; + cudf::jni::native_jintArray following_extent{env, j_following_extent}; cudf::jni::native_jpointerArray preceding(env, j_preceding); cudf::jni::native_jpointerArray following(env, j_following); @@ -3266,24 +3266,32 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega int agg_column_index = values[i]; cudf::column_view const &order_by_column = input_table->column(orderbys[i]); cudf::data_type order_by_type = order_by_column.type(); - cudf::data_type unbounded_type = order_by_type; - - if (unbounded_preceding[i] || unbounded_following[i]) { + cudf::data_type duration_type = order_by_type; + + // Range extents are defined as: + // a) 0 == CURRENT ROW + // b) 1 == BOUNDED + // c) 2 == UNBOUNDED + // Must set unbounded_type for only the BOUNDED case. + auto constexpr CURRENT_ROW = 0; + auto constexpr BOUNDED = 1; + auto constexpr UNBOUNDED = 2; + if (preceding_extent[i] != BOUNDED || following_extent[i] != BOUNDED) { switch (order_by_type.id()) { case cudf::type_id::TIMESTAMP_DAYS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_DAYS}; + duration_type = cudf::data_type{cudf::type_id::DURATION_DAYS}; break; case cudf::type_id::TIMESTAMP_SECONDS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_SECONDS}; + duration_type = cudf::data_type{cudf::type_id::DURATION_SECONDS}; break; case cudf::type_id::TIMESTAMP_MILLISECONDS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS}; + duration_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS}; break; case cudf::type_id::TIMESTAMP_MICROSECONDS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS}; + duration_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS}; break; case cudf::type_id::TIMESTAMP_NANOSECONDS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS}; + duration_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS}; break; default: break; } @@ -3293,15 +3301,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega JNI_ARG_CHECK(env, agg != nullptr, "aggregation is not an instance of rolling_aggregation", nullptr); + auto const make_window_bounds = [&](auto const &range_extent, auto const *p_scalar) { + if (range_extent == CURRENT_ROW) { + return cudf::range_window_bounds::current_row(duration_type); + } else if (range_extent == UNBOUNDED) { + return cudf::range_window_bounds::unbounded(duration_type); + } else { + return cudf::range_window_bounds::get(*p_scalar); + } + }; + result_columns.emplace_back(cudf::grouped_range_rolling_window( groupby_keys, order_by_column, orderbys_ascending[i] ? cudf::order::ASCENDING : cudf::order::DESCENDING, input_table->column(agg_column_index), - unbounded_preceding[i] ? cudf::range_window_bounds::unbounded(unbounded_type) : - cudf::range_window_bounds::get(*preceding[i]), - unbounded_following[i] ? cudf::range_window_bounds::unbounded(unbounded_type) : - cudf::range_window_bounds::get(*following[i]), - min_periods[i], *agg)); + make_window_bounds(preceding_extent[i], preceding[i]), + make_window_bounds(following_extent[i], following[i]), min_periods[i], *agg)); } auto result_table = std::make_unique(std::move(result_columns)); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index c31bcf4f78d..6b03079fa81 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -5641,6 +5641,58 @@ void testRangeWindowingCountUnboundedPreceding() { } } + @Test + void testRangeWindowingWithStringOrderByColumn() { + final String X = null; + final int orderIndex = 3; // Index of order-by column. + try (Table unsorted = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column("0", "1", "2", "3", "4", "5", X, X, "1", "2", "4", "5", "7") // String orderBy Key + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(3, true)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + try (WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + .minPeriods(1) + .unboundedPreceding() + .unboundedFollowing() + .orderByColumnIndex(orderIndex) + .build(); + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() + .minPeriods(1) + .unboundedPreceding() + .currentRowFollowing() + .orderByColumnIndex(orderIndex) + .build(); + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + .minPeriods(1) + .currentRowPreceding() + .unboundedFollowing() + .orderByColumnIndex(orderIndex) + .build()) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges( + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + RollingAggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + RollingAggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + } + } + } + } + } + @Test void testRangeWindowingCountUnboundedASCWithNullsFirst() { try (Table unsorted = new Table.TestBuilder()