Skip to content

Commit

Permalink
JNI changes for range-extents in window functions. (#13199)
Browse files Browse the repository at this point in the history
This commit adds back-end JNI changes to support explicit range-extents for ranged window functions. The change is analogous to the addition of `cudf::range_window_bounds::extent_type`, in `libcudf`. This change will allow for `STRING` order-by columns for range window functions. It is required because without it, it would be impossible to differentiate between `UNBOUNDED` and `CURRENT ROW` for a window function over a `STRING` order-by column.

Authors:
  - MithunR (https://github.com/mythrocks)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #13199
  • Loading branch information
mythrocks authored Apr 24, 2023
1 parent bc2fa11 commit 181b946
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 49 deletions.
24 changes: 13 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ private static native long[] rollingWindowAggregate(

private static native long[] rangeRollingWindowAggregate(long inputTable, int[] keyIndices, int[] orderByIndices, boolean[] isOrderByAscending,
int[] aggColumnsIndices, long[] aggInstances, int[] minPeriods,
long[] preceding, long[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing,
long[] preceding, long[] following, int[] precedingRangeExtent, int[] followingRangeExtent,
boolean ignoreNullKeys) throws CudfException;

private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending,
Expand Down Expand Up @@ -3981,10 +3981,11 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
case STRING:
break;
default:
throw new IllegalArgumentException("Expected range-based window orderBy's " +
"type: integral (Boolean-exclusive), decimal, and timestamp");
"type: integral (Boolean-exclusive), decimal, timestamp, and string");
}

ColumnWindowOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnWindowOps());
Expand All @@ -3998,27 +3999,28 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
long[] aggPrecedingWindows = new long[totalOps];
long[] aggFollowingWindows = new long[totalOps];
try {
boolean[] aggPrecedingWindowsUnbounded = new boolean[totalOps];
boolean[] aggFollowingWindowsUnbounded = new boolean[totalOps];
int[] aggPrecedingWindowsExtent = new int[totalOps];
int[] aggFollowingWindowsExtent = new int[totalOps];
int[] aggMinPeriods = new int[totalOps];
int opIndex = 0;
for (Map.Entry<Integer, ColumnWindowOps> entry: groupedOps.entrySet()) {
int columnIndex = entry.getKey();
for (AggregationOverWindow op: entry.getValue().operations()) {
aggColumnIndexes[opIndex] = columnIndex;
aggInstances[opIndex] = op.createNativeInstance();
Scalar p = op.getWindowOptions().getPrecedingScalar();
Scalar f = op.getWindowOptions().getFollowingScalar();
if ((p == null || !p.isValid()) && !op.getWindowOptions().isUnboundedPreceding()) {
WindowOptions windowOptions = op.getWindowOptions();
Scalar p = windowOptions.getPrecedingScalar();
Scalar f = windowOptions.getFollowingScalar();
if ((p == null || !p.isValid()) && !(windowOptions.isUnboundedPreceding() || windowOptions.isCurrentRowPreceding())) {
throw new IllegalArgumentException("Some kind of preceding must be set and a preceding column is not currently supported");
}
if ((f == null || !f.isValid()) && !op.getWindowOptions().isUnboundedFollowing()) {
if ((f == null || !f.isValid()) && !(windowOptions.isUnboundedFollowing() || windowOptions.isCurrentRowFollowing())) {
throw new IllegalArgumentException("some kind of following must be set and a follow column is not currently supported");
}
aggPrecedingWindows[opIndex] = p == null ? 0 : p.getScalarHandle();
aggFollowingWindows[opIndex] = f == null ? 0 : f.getScalarHandle();
aggPrecedingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedPreceding();
aggFollowingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedFollowing();
aggPrecedingWindowsExtent[opIndex] = windowOptions.getPrecedingBoundsExtent().nominalValue;
aggFollowingWindowsExtent[opIndex] = windowOptions.getFollowingBoundsExtent().nominalValue;
aggMinPeriods[opIndex] = op.getWindowOptions().getMinPeriods();
assert (op.getWindowOptions().getFrameType() == WindowOptions.FrameType.RANGE);
orderByColumnIndexes[opIndex] = op.getWindowOptions().getOrderByColumnIndex();
Expand All @@ -4040,7 +4042,7 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
isOrderByOrderAscending,
aggColumnIndexes,
aggInstances, aggMinPeriods, aggPrecedingWindows, aggFollowingWindows,
aggPrecedingWindowsUnbounded, aggFollowingWindowsUnbounded,
aggPrecedingWindowsExtent, aggFollowingWindowsExtent,
groupByOptions.getIgnoreNullKeys()))) {
// prepare the final table
ColumnVector[] finalCols = new ColumnVector[windowAggregates.length];
Expand Down
81 changes: 60 additions & 21 deletions java/src/main/java/ai/rapids/cudf/WindowOptions.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,23 @@ public class WindowOptions implements AutoCloseable {

enum FrameType {ROWS, RANGE}

/**
* Extent of (range) window bounds.
* Analogous to cudf::range_window_bounds::extent_type.
*
* Each constant carries an integer code (nominalValue). That code is what is placed in the
* extent arrays handed to the native range-window aggregation, where 0 is read as
* CURRENT ROW, 1 as BOUNDED and 2 as UNBOUNDED, so the codes here must stay in sync
* with the native side.
*/
enum RangeExtentType {
CURRENT_ROW(0), // Bounds defined as the first/last row that matches the current row.
BOUNDED(1), // Bounds defined as the first/last row that falls within
// a specified range from the current row.
UNBOUNDED(2); // Bounds stretching to the first/last row in the entire group.

// Integer code for this extent, as passed across the JNI boundary.
public final int nominalValue;

// Binds the constant to its integer code.
RangeExtentType(int n) {
this.nominalValue = n;
}
}

private final int minPeriods;
private final Scalar precedingScalar;
private final Scalar followingScalar;
Expand All @@ -33,8 +50,8 @@ enum FrameType {ROWS, RANGE}
private final int orderByColumnIndex;
private final boolean orderByOrderAscending;
private final FrameType frameType;
private final boolean isUnboundedPreceding;
private final boolean isUnboundedFollowing;
private final RangeExtentType precedingBoundsExtent;
private final RangeExtentType followingBoundsExtent;

private WindowOptions(Builder builder) {
this.minPeriods = builder.minPeriods;
Expand All @@ -57,9 +74,8 @@ private WindowOptions(Builder builder) {
this.orderByColumnIndex = builder.orderByColumnIndex;
this.orderByOrderAscending = builder.orderByOrderAscending;
this.frameType = orderByColumnIndex == -1? FrameType.ROWS : FrameType.RANGE;
this.isUnboundedPreceding = builder.isUnboundedPreceding;
this.isUnboundedFollowing = builder.isUnboundedFollowing;

this.precedingBoundsExtent = builder.precedingBoundsExtent;
this.followingBoundsExtent = builder.followingBoundsExtent;
}

@Override
Expand All @@ -72,8 +88,8 @@ public boolean equals(Object other) {
this.orderByColumnIndex == o.orderByColumnIndex &&
this.orderByOrderAscending == o.orderByOrderAscending &&
this.frameType == o.frameType &&
this.isUnboundedPreceding == o.isUnboundedPreceding &&
this.isUnboundedFollowing == o.isUnboundedFollowing;
this.precedingBoundsExtent == o.precedingBoundsExtent &&
this.followingBoundsExtent == o.followingBoundsExtent;
if (precedingCol != null) {
ret = ret && precedingCol.equals(o.precedingCol);
}
Expand Down Expand Up @@ -110,8 +126,8 @@ public int hashCode() {
if (followingScalar != null) {
ret = 31 * ret + followingScalar.hashCode();
}
ret = 31 * ret + Boolean.hashCode(isUnboundedPreceding);
ret = 31 * ret + Boolean.hashCode(isUnboundedFollowing);
ret = 31 * ret + precedingBoundsExtent.hashCode();
ret = 31 * ret + followingBoundsExtent.hashCode();
return ret;
}

Expand Down Expand Up @@ -139,9 +155,16 @@ public static Builder builder(){

boolean isOrderByOrderAscending() { return this.orderByOrderAscending; }

boolean isUnboundedPreceding() { return this.isUnboundedPreceding; }
boolean isUnboundedPreceding() { return this.precedingBoundsExtent == RangeExtentType.UNBOUNDED; }

boolean isUnboundedFollowing() { return this.isUnboundedFollowing; }
boolean isUnboundedFollowing() { return this.followingBoundsExtent == RangeExtentType.UNBOUNDED; }

boolean isCurrentRowPreceding() { return this.precedingBoundsExtent == RangeExtentType.CURRENT_ROW; }

boolean isCurrentRowFollowing() { return this.followingBoundsExtent == RangeExtentType.CURRENT_ROW; }

RangeExtentType getPrecedingBoundsExtent() { return this.precedingBoundsExtent; }
RangeExtentType getFollowingBoundsExtent() { return this.followingBoundsExtent; }

FrameType getFrameType() { return frameType; }

Expand All @@ -154,8 +177,8 @@ public static class Builder {
private ColumnVector followingCol = null;
private int orderByColumnIndex = -1;
private boolean orderByOrderAscending = true;
private boolean isUnboundedPreceding = false;
private boolean isUnboundedFollowing = false;
private RangeExtentType precedingBoundsExtent = RangeExtentType.BOUNDED;
private RangeExtentType followingBoundsExtent = RangeExtentType.BOUNDED;

/**
* Set the minimum number of observation required to evaluate an element. If there are not
Expand All @@ -171,7 +194,7 @@ public Builder minPeriods(int minPeriods) {

/**
* Set the size of the window, one entry per row. This does not take ownership of the
* columns passed in so you have to be sure that the life time of the column outlives
* columns passed in so you have to be sure that the lifetime of the column outlives
* this operation.
* @param precedingCol the number of rows preceding the current row and
* precedingCol will be live outside of WindowOptions.
Expand All @@ -185,10 +208,10 @@ public Builder window(ColumnVector precedingCol, ColumnVector followingCol) {
if (followingCol == null || followingCol.hasNulls()) {
throw new IllegalArgumentException("following cannot be null or have nulls");
}
if (isUnboundedPreceding || precedingScalar != null) {
if (precedingBoundsExtent != RangeExtentType.BOUNDED || precedingScalar != null) {
throw new IllegalStateException("preceding has already been set a different way");
}
if (isUnboundedFollowing || followingScalar != null) {
if (followingBoundsExtent != RangeExtentType.BOUNDED || followingScalar != null) {
throw new IllegalStateException("following has already been set a different way");
}
this.precedingCol = precedingCol;
Expand Down Expand Up @@ -246,19 +269,35 @@ public Builder timestampDescending() {
return orderByDescending();
}

/**
 * Mark the preceding bound of the window as CURRENT ROW.
 * @return this builder, to allow call chaining
 * @throws IllegalStateException if a preceding column or preceding scalar has already been set
 */
public Builder currentRowPreceding() {
  final boolean precedingAlreadyGiven = !(precedingCol == null && precedingScalar == null);
  if (precedingAlreadyGiven) {
    throw new IllegalStateException("preceding has already been set a different way");
  }
  precedingBoundsExtent = RangeExtentType.CURRENT_ROW;
  return this;
}

/**
 * Mark the following bound of the window as CURRENT ROW.
 * @return this builder, to allow call chaining
 * @throws IllegalStateException if a following column or following scalar has already been set
 */
public Builder currentRowFollowing() {
  final boolean followingAlreadyGiven = !(followingCol == null && followingScalar == null);
  if (followingAlreadyGiven) {
    throw new IllegalStateException("following has already been set a different way");
  }
  followingBoundsExtent = RangeExtentType.CURRENT_ROW;
  return this;
}

public Builder unboundedPreceding() {
if (precedingCol != null || precedingScalar != null) {
throw new IllegalStateException("preceding has already been set a different way");
}
this.isUnboundedPreceding = true;
this.precedingBoundsExtent = RangeExtentType.UNBOUNDED;
return this;
}

public Builder unboundedFollowing() {
if (followingCol != null || followingScalar != null) {
throw new IllegalStateException("following has already been set a different way");
}
this.isUnboundedFollowing = true;
this.followingBoundsExtent = RangeExtentType.UNBOUNDED;
return this;
}

Expand All @@ -270,7 +309,7 @@ public Builder preceding(Scalar preceding) {
if (preceding == null || !preceding.isValid()) {
throw new IllegalArgumentException("preceding cannot be null");
}
if (isUnboundedPreceding || precedingCol != null) {
if (precedingBoundsExtent != RangeExtentType.BOUNDED || precedingCol != null) {
throw new IllegalStateException("preceding has already been set a different way");
}
this.precedingScalar = preceding;
Expand All @@ -285,7 +324,7 @@ public Builder following(Scalar following) {
if (following == null || !following.isValid()) {
throw new IllegalArgumentException("following cannot be null");
}
if (isUnboundedFollowing || followingCol != null) {
if (followingBoundsExtent != RangeExtentType.BOUNDED || followingCol != null) {
throw new IllegalStateException("following has already been set a different way");
}
this.followingScalar = following;
Expand Down
49 changes: 32 additions & 17 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3221,8 +3221,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
JNIEnv *env, jclass, jlong j_input_table, jintArray j_keys, jintArray j_orderby_column_indices,
jbooleanArray j_is_orderby_ascending, jintArray j_aggregate_column_indices,
jlongArray j_agg_instances, jintArray j_min_periods, jlongArray j_preceding,
jlongArray j_following, jbooleanArray j_unbounded_preceding,
jbooleanArray j_unbounded_following, jboolean ignore_null_keys) {
jlongArray j_following, jintArray j_preceding_extent, jintArray j_following_extent,
jboolean ignore_null_keys) {

JNI_NULL_CHECK(env, j_input_table, "input table is null", NULL);
JNI_NULL_CHECK(env, j_keys, "input keys are null", NULL);
Expand All @@ -3246,8 +3246,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
cudf::jni::native_jintArray values{env, j_aggregate_column_indices};
cudf::jni::native_jpointerArray<cudf::aggregation> agg_instances(env, j_agg_instances);
cudf::jni::native_jintArray min_periods{env, j_min_periods};
cudf::jni::native_jbooleanArray unbounded_preceding{env, j_unbounded_preceding};
cudf::jni::native_jbooleanArray unbounded_following{env, j_unbounded_following};
cudf::jni::native_jintArray preceding_extent{env, j_preceding_extent};
cudf::jni::native_jintArray following_extent{env, j_following_extent};
cudf::jni::native_jpointerArray<cudf::scalar> preceding(env, j_preceding);
cudf::jni::native_jpointerArray<cudf::scalar> following(env, j_following);

Expand All @@ -3266,24 +3266,32 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
int agg_column_index = values[i];
cudf::column_view const &order_by_column = input_table->column(orderbys[i]);
cudf::data_type order_by_type = order_by_column.type();
cudf::data_type unbounded_type = order_by_type;

if (unbounded_preceding[i] || unbounded_following[i]) {
cudf::data_type duration_type = order_by_type;

// Range extents are defined as:
// a) 0 == CURRENT ROW
// b) 1 == BOUNDED
// c) 2 == UNBOUNDED
// duration_type must be set only when at least one bound is not BOUNDED (i.e. CURRENT ROW or UNBOUNDED).
auto constexpr CURRENT_ROW = 0;
auto constexpr BOUNDED = 1;
auto constexpr UNBOUNDED = 2;
if (preceding_extent[i] != BOUNDED || following_extent[i] != BOUNDED) {
switch (order_by_type.id()) {
case cudf::type_id::TIMESTAMP_DAYS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_DAYS};
duration_type = cudf::data_type{cudf::type_id::DURATION_DAYS};
break;
case cudf::type_id::TIMESTAMP_SECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_SECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_SECONDS};
break;
case cudf::type_id::TIMESTAMP_MILLISECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS};
break;
case cudf::type_id::TIMESTAMP_MICROSECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS};
break;
case cudf::type_id::TIMESTAMP_NANOSECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS};
break;
default: break;
}
Expand All @@ -3293,15 +3301,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
JNI_ARG_CHECK(env, agg != nullptr, "aggregation is not an instance of rolling_aggregation",
nullptr);

auto const make_window_bounds = [&](auto const &range_extent, auto const *p_scalar) {
if (range_extent == CURRENT_ROW) {
return cudf::range_window_bounds::current_row(duration_type);
} else if (range_extent == UNBOUNDED) {
return cudf::range_window_bounds::unbounded(duration_type);
} else {
return cudf::range_window_bounds::get(*p_scalar);
}
};

result_columns.emplace_back(cudf::grouped_range_rolling_window(
groupby_keys, order_by_column,
orderbys_ascending[i] ? cudf::order::ASCENDING : cudf::order::DESCENDING,
input_table->column(agg_column_index),
unbounded_preceding[i] ? cudf::range_window_bounds::unbounded(unbounded_type) :
cudf::range_window_bounds::get(*preceding[i]),
unbounded_following[i] ? cudf::range_window_bounds::unbounded(unbounded_type) :
cudf::range_window_bounds::get(*following[i]),
min_periods[i], *agg));
make_window_bounds(preceding_extent[i], preceding[i]),
make_window_bounds(following_extent[i], following[i]), min_periods[i], *agg));
}

auto result_table = std::make_unique<cudf::table>(std::move(result_columns));
Expand Down
Loading

0 comments on commit 181b946

Please sign in to comment.