Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added in JNI support for out of core sort algorithm [skip ci] #7381

Merged
merged 2 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 118 additions & 25 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.rapids.cudf.HostColumnVector.StructType;

import java.io.File;
import java.io.Serializable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -466,6 +467,9 @@ private static native long[] timeRangeRollingWindowAggregate(long inputTable, in
int[] preceding, int[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing,
boolean ignoreNullKeys) throws CudfException;

private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending,
boolean[] areNullsSmallest) throws CudfException;

private static native long[] orderBy(long inputTable, long[] sortKeys, boolean[] isDescending,
boolean[] areNullsSmallest) throws CudfException;

Expand Down Expand Up @@ -1247,7 +1251,8 @@ public Table repeat(ColumnVector counts, boolean checkCount) {
}

/**
* Given a sorted table return the lower bound.
* Find smallest indices in a sorted table where values should be inserted to maintain order.
* <pre>
* Example:
*
* Single column:
Expand All @@ -1265,14 +1270,11 @@ public Table repeat(ColumnVector counts, boolean checkCount) {
* { .7 },
* { 61 }}
* result = { 3 }
* NaNs in column values produce incorrect results.
* </pre>
* The input table and the values table need to be non-empty (row count > 0)
* The column data types of the tables' have to match in order.
* Strings and String categories do not work for this method. If the input table is
* unsorted the results are wrong. Types of columns can be of mixed data types.
* @param areNullsSmallest true if nulls are assumed smallest
* @param valueTable the table of values that need to be inserted
* @param descFlags indicates the ordering of the column(s), true if descending
* @param areNullsSmallest per column, true if nulls are assumed smallest
* @param valueTable the table of values to find insertion locations for
* @param descFlags per column indicates the ordering, true if descending.
* @return ColumnVector with lower bound indices for all rows in valueTable
*/
public ColumnVector lowerBound(boolean[] areNullsSmallest,
Expand All @@ -1283,7 +1285,34 @@ public ColumnVector lowerBound(boolean[] areNullsSmallest,
}

/**
* Find smallest indices in a sorted table where values should be inserted to maintain order.
* This is a convenience method. It pulls out the columns indicated by the args and sets up the
* ordering properly to call `lowerBound`.
* @param valueTable the table of values to find insertion locations for
* @param args the sort order used to sort this table.
* @return ColumnVector with lower bound indices for all rows in valueTable
*/
public ColumnVector lowerBound(Table valueTable, OrderByArg... args) {
boolean[] areNullsSmallest = new boolean[args.length];
boolean[] descFlags = new boolean[args.length];
ColumnVector[] inputColumns = new ColumnVector[args.length];
ColumnVector[] searchColumns = new ColumnVector[args.length];
for (int i = 0; i < args.length; i++) {
areNullsSmallest[i] = args[i].isNullSmallest;
descFlags[i] = args[i].isDescending;
inputColumns[i] = columns[args[i].index];
searchColumns[i] = valueTable.columns[args[i].index];
}
try (Table input = new Table(inputColumns);
Table search = new Table(searchColumns)) {
return input.lowerBound(areNullsSmallest, search, descFlags);
}
}

/**
* Find largest indices in a sorted table where values should be inserted to maintain order.
* Given a sorted table return the upper bound.
* <pre>
* Example:
*
* Single column:
Expand All @@ -1301,14 +1330,11 @@ public ColumnVector lowerBound(boolean[] areNullsSmallest,
* { .7 },
* { 61 }}
* result = { 5 }
* NaNs in column values produce incorrect results.
* </pre>
* The input table and the values table need to be non-empty (row count > 0)
* The column data types of the tables' have to match in order.
* Strings and String categories do not work for this method. If the input table is
* unsorted the results are wrong. Types of columns can be of mixed data types.
* @param areNullsSmallest true if nulls are assumed smallest
* @param valueTable the table of values that need to be inserted
* @param descFlags indicates the ordering of the column(s), true if descending
* @param areNullsSmallest per column, true if nulls are assumed smallest
* @param valueTable the table of values to find insertion locations for
* @param descFlags per column indicates the ordering, true if descending.
* @return ColumnVector with upper bound indices for all rows in valueTable
*/
public ColumnVector upperBound(boolean[] areNullsSmallest,
Expand All @@ -1318,6 +1344,31 @@ public ColumnVector upperBound(boolean[] areNullsSmallest,
descFlags, areNullsSmallest, true));
}

/**
* Find largest indices in a sorted table where values should be inserted to maintain order.
* This is a convenience method. It pulls out the columns indicated by the args and sets up the
* ordering properly to call `upperBound`.
* @param valueTable the table of values to find insertion locations for
* @param args the sort order used to sort this table.
* @return ColumnVector with upper bound indices for all rows in valueTable
*/
public ColumnVector upperBound(Table valueTable, OrderByArg... args) {
boolean[] areNullsSmallest = new boolean[args.length];
boolean[] descFlags = new boolean[args.length];
ColumnVector[] inputColumns = new ColumnVector[args.length];
ColumnVector[] searchColumns = new ColumnVector[args.length];
for (int i = 0; i < args.length; i++) {
areNullsSmallest[i] = args[i].isNullSmallest;
descFlags[i] = args[i].isDescending;
inputColumns[i] = columns[args[i].index];
searchColumns[i] = valueTable.columns[args[i].index];
}
try (Table input = new Table(inputColumns);
Table search = new Table(searchColumns)) {
return input.upperBound(areNullsSmallest, search, descFlags);
}
}

private void assertForBounds(Table valueTable) {
assert this.getRowCount() != 0 : "Input table cannot be empty";
assert valueTable.getRowCount() != 0 : "Value table cannot be empty";
Expand All @@ -1342,17 +1393,39 @@ public Table crossJoin(Table right) {
// TABLE MANIPULATION APIs
/////////////////////////////////////////////////////////////////////////////

/**
* Get back a gather map that can be used to sort the data. This allows you to sort by data
* that does not appear in the final result and not pay the cost of gathering the data that
* is only needed for sorting.
* @param args what order to sort the data by
* @return a gather map
*/
public ColumnVector sortOrder(OrderByArg... args) {
long[] sortKeys = new long[args.length];
boolean[] isDescending = new boolean[args.length];
boolean[] areNullsSmallest = new boolean[args.length];
for (int i = 0; i < args.length; i++) {
int index = args[i].index;
assert (index >= 0 && index < columns.length) :
"index is out of range 0 <= " + index + " < " + columns.length;
isDescending[i] = args[i].isDescending;
areNullsSmallest[i] = args[i].isNullSmallest;
sortKeys[i] = columns[index].getNativeView();
}

return new ColumnVector(sortOrder(nativeHandle, sortKeys, isDescending, areNullsSmallest));
}

/**
* Orders the table using the sortkeys returning a new allocated table. The caller is
* responsible for cleaning up
* the {@link ColumnVector} returned as part of the output {@link Table}
* <p>
* Example usage: orderBy(true, Table.asc(0), Table.desc(3)...);
* @param args - Suppliers to initialize sortKeys.
* @param args Suppliers to initialize sortKeys.
* @return Sorted Table
*/
public Table orderBy(OrderByArg... args) {
assert args.length <= columns.length;
long[] sortKeys = new long[args.length];
boolean[] isDescending = new boolean[args.length];
boolean[] areNullsSmallest = new boolean[args.length];
Expand All @@ -1377,13 +1450,13 @@ public Table orderBy(OrderByArg... args) {
* initially.
* @return a combined sorted table.
*/
public static Table merge(List<Table> tables, OrderByArg... args) {
assert !tables.isEmpty();
long[] tableHandles = new long[tables.size()];
Table first = tables.get(0);
public static Table merge(Table[] tables, OrderByArg... args) {
assert tables.length > 0;
long[] tableHandles = new long[tables.length];
Table first = tables[0];
assert args.length <= first.columns.length;
for (int i = 0; i < tables.size(); i++) {
Table t = tables.get(i);
for (int i = 0; i < tables.length; i++) {
Table t = tables[i];
assert t != null;
assert t.columns.length == first.columns.length;
tableHandles[i] = t.nativeHandle;
Expand All @@ -1394,7 +1467,7 @@ public static Table merge(List<Table> tables, OrderByArg... args) {
for (int i = 0; i < args.length; i++) {
int index = args[i].index;
assert (index >= 0 && index < first.columns.length) :
"index is out of range 0 <= " + index + " < " + first.columns.length;
"index is out of range 0 <= " + index + " < " + first.columns.length;
isDescending[i] = args[i].isDescending;
areNullsSmallest[i] = args[i].isNullSmallest;
sortKeyIndexes[i] = index;
Expand All @@ -1403,6 +1476,19 @@ public static Table merge(List<Table> tables, OrderByArg... args) {
return new Table(merge(tableHandles, sortKeyIndexes, isDescending, areNullsSmallest));
}

/**
* Merge multiple already sorted tables keeping the sort order the same.
* This is a more efficient version of concatenate followed by orderBy, but requires that
* the input already be sorted.
* @param tables the tables that should be merged.
* @param args the ordering of the tables. Should match how they were sorted
* initially.
* @return a combined sorted table.
*/
public static Table merge(List<Table> tables, OrderByArg... args) {
return merge(tables.toArray(new Table[tables.size()]), args);
}

public static OrderByArg asc(final int index) {
return new OrderByArg(index, false, false);
}
Expand Down Expand Up @@ -1852,7 +1938,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////

public static final class OrderByArg {
public static final class OrderByArg implements Serializable {
final int index;
final boolean isDescending;
final boolean isNullSmallest;
Expand All @@ -1862,6 +1948,13 @@ public static final class OrderByArg {
this.isDescending = isDescending;
this.isNullSmallest = isNullSmallest;
}

@Override
public String toString() {
return "ORDER BY " + index +
(isDescending ? " DESC " : " ASC ") +
(isNullSmallest ? "NULL SMALLEST" : "NULL LARGEST");
}
}

/**
Expand Down
Loading