Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added in JNI support for out of core sort algorithm [skip ci] #7381

Merged
merged 2 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 102 additions & 10 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.rapids.cudf.HostColumnVector.StructType;

import java.io.File;
import java.io.Serializable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -466,6 +467,9 @@ private static native long[] timeRangeRollingWindowAggregate(long inputTable, in
int[] preceding, int[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing,
boolean ignoreNullKeys) throws CudfException;

private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending,
boolean[] areNullsSmallest) throws CudfException;

private static native long[] orderBy(long inputTable, long[] sortKeys, boolean[] isDescending,
boolean[] areNullsSmallest) throws CudfException;

Expand Down Expand Up @@ -1282,6 +1286,29 @@ public ColumnVector lowerBound(boolean[] areNullsSmallest,
descFlags, areNullsSmallest, false));
}

/**
* Given a sorted table return the lower bound.
* @param valueTable the table of values that would be inserted.
* @param args the sort order used to sort the table.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
* @return ColumnVector with lower bound indices for all rows in valueTable
*/
public ColumnVector lowerBound(Table valueTable, OrderByArg... args) {
boolean[] areNullsSmallest = new boolean[args.length];
boolean[] descFlags = new boolean[args.length];
ColumnVector[] inputColumns = new ColumnVector[args.length];
ColumnVector[] searchColumns = new ColumnVector[args.length];
for (int i = 0; i < args.length; i++) {
areNullsSmallest[i] = args[i].isNullSmallest;
descFlags[i] = args[i].isDescending;
inputColumns[i] = columns[args[i].index];
searchColumns[i] = valueTable.columns[args[i].index];
}
try (Table input = new Table(inputColumns);
Table search = new Table(searchColumns)) {
return input.lowerBound(areNullsSmallest, search, descFlags);
}
}

/**
* Given a sorted table return the upper bound.
* Example:
Expand Down Expand Up @@ -1318,6 +1345,29 @@ public ColumnVector upperBound(boolean[] areNullsSmallest,
descFlags, areNullsSmallest, true));
}

/**
* Given a sorted table return the upper bound.
* @param valueTable the table of values that would be inserted.
* @param args the sort order used to sort the table.
* @return ColumnVector with upper bound indices for all rows in valueTable
*/
public ColumnVector upperBound(Table valueTable, OrderByArg... args) {
boolean[] areNullsSmallest = new boolean[args.length];
boolean[] descFlags = new boolean[args.length];
ColumnVector[] inputColumns = new ColumnVector[args.length];
ColumnVector[] searchColumns = new ColumnVector[args.length];
for (int i = 0; i < args.length; i++) {
areNullsSmallest[i] = args[i].isNullSmallest;
descFlags[i] = args[i].isDescending;
inputColumns[i] = columns[args[i].index];
searchColumns[i] = valueTable.columns[args[i].index];
}
try (Table input = new Table(inputColumns);
Table search = new Table(searchColumns)) {
return input.upperBound(areNullsSmallest, search, descFlags);
}
}

private void assertForBounds(Table valueTable) {
assert this.getRowCount() != 0 : "Input table cannot be empty";
assert valueTable.getRowCount() != 0 : "Value table cannot be empty";
Expand All @@ -1342,17 +1392,39 @@ public Table crossJoin(Table right) {
// TABLE MANIPULATION APIs
/////////////////////////////////////////////////////////////////////////////

/**
* Get back a gather map that can be used to sort the data. This allows you to sort by data
* that does not appear in the final result and not pay the cost of gathering the data that
* is only needed for sorting.
* @param args what order to sort the data by
* @return a gather map
*/
public ColumnVector sortOrder(OrderByArg... args) {
long[] sortKeys = new long[args.length];
boolean[] isDescending = new boolean[args.length];
boolean[] areNullsSmallest = new boolean[args.length];
for (int i = 0; i < args.length; i++) {
int index = args[i].index;
assert (index >= 0 && index < columns.length) :
"index is out of range 0 <= " + index + " < " + columns.length;
isDescending[i] = args[i].isDescending;
areNullsSmallest[i] = args[i].isNullSmallest;
sortKeys[i] = columns[index].getNativeView();
}

return new ColumnVector(sortOrder(nativeHandle, sortKeys, isDescending, areNullsSmallest));
}

/**
* Orders the table using the sortkeys returning a new allocated table. The caller is
* responsible for cleaning up
* the {@link ColumnVector} returned as part of the output {@link Table}
* <p>
* Example usage: orderBy(true, Table.asc(0), Table.desc(3)...);
* @param args - Suppliers to initialize sortKeys.
* @param args Suppliers to initialize sortKeys.
* @return Sorted Table
*/
public Table orderBy(OrderByArg... args) {
assert args.length <= columns.length;
long[] sortKeys = new long[args.length];
boolean[] isDescending = new boolean[args.length];
boolean[] areNullsSmallest = new boolean[args.length];
Expand All @@ -1377,13 +1449,13 @@ public Table orderBy(OrderByArg... args) {
* initially.
* @return a combined sorted table.
*/
public static Table merge(List<Table> tables, OrderByArg... args) {
assert !tables.isEmpty();
long[] tableHandles = new long[tables.size()];
Table first = tables.get(0);
public static Table merge(Table[] tables, OrderByArg... args) {
assert tables.length > 0;
long[] tableHandles = new long[tables.length];
Table first = tables[0];
assert args.length <= first.columns.length;
for (int i = 0; i < tables.size(); i++) {
Table t = tables.get(i);
for (int i = 0; i < tables.length; i++) {
Table t = tables[i];
assert t != null;
assert t.columns.length == first.columns.length;
tableHandles[i] = t.nativeHandle;
Expand All @@ -1394,7 +1466,7 @@ public static Table merge(List<Table> tables, OrderByArg... args) {
for (int i = 0; i < args.length; i++) {
int index = args[i].index;
assert (index >= 0 && index < first.columns.length) :
"index is out of range 0 <= " + index + " < " + first.columns.length;
"index is out of range 0 <= " + index + " < " + first.columns.length;
isDescending[i] = args[i].isDescending;
areNullsSmallest[i] = args[i].isNullSmallest;
sortKeyIndexes[i] = index;
Expand All @@ -1403,6 +1475,19 @@ public static Table merge(List<Table> tables, OrderByArg... args) {
return new Table(merge(tableHandles, sortKeyIndexes, isDescending, areNullsSmallest));
}

/**
* Merge multiple already sorted tables keeping the sort order the same.
* This is a more efficient version of concatenate followed by orderBy, but requires that
* the input already be sorted.
* @param tables the tables that should be merged.
* @param args the ordering of the tables. Should match how they were sorted
* initially.
* @return a combined sorted table.
*/
public static Table merge(List<Table> tables, OrderByArg... args) {
return merge(tables.toArray(new Table[tables.size()]), args);
}

public static OrderByArg asc(final int index) {
return new OrderByArg(index, false, false);
}
Expand Down Expand Up @@ -1852,7 +1937,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////

public static final class OrderByArg {
public static final class OrderByArg implements Serializable {
final int index;
final boolean isDescending;
final boolean isNullSmallest;
Expand All @@ -1862,6 +1947,13 @@ public static final class OrderByArg {
this.isDescending = isDescending;
this.isNullSmallest = isNullSmallest;
}

@Override
public String toString() {
return "ORDER BY " + index +
(isDescending ? " DESC " : " ASC ") +
(isNullSmallest ? "NULL SMALLEST" : "NULL LARGEST");
}
}

/**
Expand Down
58 changes: 57 additions & 1 deletion java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,62 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_columnViewsFromPacked(JNI
CATCH_STD(env, nullptr);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_sortOrder(JNIEnv *env, jclass j_class_object,
jlowe marked this conversation as resolved.
Show resolved Hide resolved
jlong j_input_table,
jlongArray j_sort_keys_columns,
jbooleanArray j_is_descending,
jbooleanArray j_are_nulls_smallest) {

// input validations & verifications
JNI_NULL_CHECK(env, j_input_table, "input table is null", 0);
JNI_NULL_CHECK(env, j_sort_keys_columns, "sort keys columns is null", 0);
JNI_NULL_CHECK(env, j_is_descending, "sort order array is null", 0);
JNI_NULL_CHECK(env, j_are_nulls_smallest, "null order array is null", 0);

try {
cudf::jni::auto_set_device(env);
cudf::jni::native_jpointerArray<cudf::column_view> n_sort_keys_columns(env,
j_sort_keys_columns);
jsize num_columns = n_sort_keys_columns.size();
const cudf::jni::native_jbooleanArray n_is_descending(env, j_is_descending);
jsize num_columns_is_desc = n_is_descending.size();

if (num_columns_is_desc != num_columns) {
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException",
"columns and is_descending lengths don't match", 0);
}
jlowe marked this conversation as resolved.
Show resolved Hide resolved

const cudf::jni::native_jbooleanArray n_are_nulls_smallest(env, j_are_nulls_smallest);
jsize num_columns_null_smallest = n_are_nulls_smallest.size();

if (num_columns_null_smallest != num_columns) {
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException",
"columns and areNullsSmallest lengths don't match", 0);
}
jlowe marked this conversation as resolved.
Show resolved Hide resolved

std::vector<cudf::order> order(n_is_descending.size());
for (int i = 0; i < n_is_descending.size(); i++) {
order[i] = n_is_descending[i] ? cudf::order::DESCENDING : cudf::order::ASCENDING;
}
std::vector<cudf::null_order> null_order(n_are_nulls_smallest.size());
for (int i = 0; i < n_are_nulls_smallest.size(); i++) {
null_order[i] = n_are_nulls_smallest[i] ? cudf::null_order::BEFORE : cudf::null_order::AFTER;
}

std::vector<cudf::column_view> columns;
columns.reserve(num_columns);
jlowe marked this conversation as resolved.
Show resolved Hide resolved
for (int i = 0; i < num_columns; i++) {
columns.push_back(*n_sort_keys_columns[i]);
}
cudf::table_view keys(columns);

auto sorted_col = cudf::sorted_order(keys, order, null_order);
return reinterpret_cast<jlong>(sorted_col.release());
}
CATCH_STD(env, 0);
}


JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jclass j_class_object,
jlong j_input_table,
jlongArray j_sort_keys_columns,
Expand All @@ -646,7 +702,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jcla

// input validations & verifications
JNI_NULL_CHECK(env, j_input_table, "input table is null", NULL);
JNI_NULL_CHECK(env, j_sort_keys_columns, "input table is null", NULL);
JNI_NULL_CHECK(env, j_sort_keys_columns, "sort keys columns is null", NULL);
JNI_NULL_CHECK(env, j_is_descending, "sort order array is null", NULL);
JNI_NULL_CHECK(env, j_are_nulls_smallest, "null order array is null", NULL);

Expand Down
18 changes: 18 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,24 @@ void testOrderByAD() {
}
}

@Test
void testSortOrderSimple() {
try (Table table = new Table.TestBuilder()
.column(5, 3, 3, 1, 1)
.column(5, 3, 4, 1, 2)
.column(1, 3, 5, 7, 9)
.build();
Table expected = new Table.TestBuilder()
.column(1, 1, 3, 3, 5)
.column(2, 1, 4, 3, 5)
.column(9, 7, 5, 3, 1)
.build();
ColumnVector gatherMap = table.sortOrder(Table.asc(0), Table.desc(1));
Table sortedTable = table.gather(gatherMap)) {
assertTablesAreEqual(expected, sortedTable);
}
}

@Test
void testOrderByDD() {
try (Table table = new Table.TestBuilder()
Expand Down