From e6a0e8a1a5ab83097bab92d362701c729b557057 Mon Sep 17 00:00:00 2001
From: Jason Lowe
Date: Wed, 15 Dec 2021 18:12:03 -0600
Subject: [PATCH 1/5] Java bindings for mixed left, inner and full joins

---
 .../java/ai/rapids/cudf/MixedJoinSize.java    |  51 +++
 java/src/main/java/ai/rapids/cudf/Table.java  | 235 +++++++++++-
 java/src/main/native/src/TableJni.cpp         | 163 +++++++++
 .../test/java/ai/rapids/cudf/TableTest.java   | 340 ++++++++++++++++++
 4 files changed, 788 insertions(+), 1 deletion(-)
 create mode 100644 java/src/main/java/ai/rapids/cudf/MixedJoinSize.java

diff --git a/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java b/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java
new file mode 100644
index 00000000000..8e95f87b3da
--- /dev/null
+++ b/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai.rapids.cudf;
+
+/** This class tracks size information associated with a mixed table join. */
+public final class MixedJoinSize implements AutoCloseable {
+  private final long outputRowCount;
+  // This is in flux, avoid exposing publicly until the dust settles.
+  private final ColumnVector matches;
+
+  /**
+   * Wrap the size information computed for a mixed join. This object keeps a reference to
+   * {@code matches} and closes it in {@link #close()}.
+   * @param outputRowCount number of rows the join would produce
+   * @param matches column produced alongside the row count by the size computation
+   */
+  MixedJoinSize(long outputRowCount, ColumnVector matches) {
+    this.outputRowCount = outputRowCount;
+    this.matches = matches;
+  }
+
+  /** Return the number of output rows that would be generated from the mixed join */
+  public long getOutputRowCount() {
+    return outputRowCount;
+  }
+
+  /** Return the matches column. It is closed when this object is closed. */
+  ColumnVector getMatches() {
+    return matches;
+  }
+
+  /** Close the matches column held by this object. */
+  @Override
+  public synchronized void close() {
+    matches.close();
+  }
+}
diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java
index 00c98c4fef8..4779372c929 100644
--- a/java/src/main/java/ai/rapids/cudf/Table.java
+++ b/java/src/main/java/ai/rapids/cudf/Table.java
@@ -635,6 +635,36 @@ private static native long[] conditionalLeftAntiJoinGatherMapWithCount(long left
                                                                          long condition,
                                                                          long rowCount) throws CudfException;
 
+  private static native long[] mixedLeftJoinSize(long leftKeysTable, long rightKeysTable,
+                                                 long leftConditionTable, long rightConditionTable,
+                                                 long condition, boolean compareNullsEqual);
+
+  private static native long[] mixedLeftJoinGatherMaps(long leftKeysTable, long rightKeysTable,
+                                                       long leftConditionTable, long rightConditionTable,
+                                                       long condition, boolean compareNullsEqual);
+
+  private static native long[] mixedLeftJoinGatherMapsWithSize(long leftKeysTable, long rightKeysTable,
+                                                               long leftConditionTable, long rightConditionTable,
+                                                               long condition, boolean compareNullsEqual,
+                                                               long outputRowCount, long matchesColumnView);
+
+  private static native long[] mixedInnerJoinSize(long leftKeysTable, long rightKeysTable,
+                                                  long leftConditionTable, long rightConditionTable,
+                                                  long condition, boolean compareNullsEqual);
+
+  private static native long[] mixedInnerJoinGatherMaps(long leftKeysTable, long rightKeysTable,
+                                                        long leftConditionTable, long rightConditionTable,
+                                                        long condition, boolean compareNullsEqual);
+
+  private static native long[] mixedInnerJoinGatherMapsWithSize(long leftKeysTable, long rightKeysTable,
+                                                                long
leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual, + long outputRowCount, long matchesColumnView); + + private static native long[] mixedFullJoinGatherMaps(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual); + private static native long[] crossJoin(long leftTable, long rightTable) throws CudfException; private static native long[] concatenate(long[] cudfTablePointers) throws CudfException; @@ -2121,7 +2151,7 @@ public static Table scatter(Scalar[] source, ColumnView scatterMap, Table target target.getNativeView(), checkBounds)); } - private GatherMap[] buildJoinGatherMaps(long[] gatherMapData) { + private static GatherMap[] buildJoinGatherMaps(long[] gatherMapData) { long bufferSize = gatherMapData[0]; long leftAddr = gatherMapData[1]; long leftHandle = gatherMapData[2]; @@ -2274,6 +2304,94 @@ public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, return buildJoinGatherMaps(gatherMapData); } + /** + * Computes output size information for a left join between two tables using a mix of equality + * and inequality conditions. The entire join condition is assumed to be a logical AND of the + * equality condition and inequality condition. + * NOTE: It is the responsibility of the caller to close the resulting size information object + * or native resources can be leaked! 
+ * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return size information for the join + */ + public static MixedJoinSize mixedLeftJoinSize(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] mixedSizeInfo = mixedLeftJoinSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), nullEquality == NullEquality.EQUAL); + assert mixedSizeInfo.length == 2; + long outputRowCount = mixedSizeInfo[0]; + long matchesColumnHandle = mixedSizeInfo[1]; + return new MixedJoinSize(outputRowCount, new ColumnVector(matchesColumnHandle)); + } + + /** + * Computes the gather maps that can be used to manifest the result of a left join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. 
+ * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return left and right table gather maps + */ + public static GatherMap[] mixedLeftJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] gatherMapData = mixedLeftJoinGatherMaps( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL); + return buildJoinGatherMaps(gatherMapData); + } + + /** + * Computes the gather maps that can be used to manifest the result of a left join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing the size result from + * {@link #mixedLeftJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} + * when the output size was computed previously. 
+ * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @param joinSize mixed join size result + * @return left and right table gather maps + */ + public static GatherMap[] mixedLeftJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality, + MixedJoinSize joinSize) { + long[] gatherMapData = mixedLeftJoinGatherMapsWithSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL, + joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView()); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an inner equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2414,6 +2532,94 @@ public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, return buildJoinGatherMaps(gatherMapData); } + /** + * Computes output size information for an inner join between two tables using a mix of equality + * and inequality conditions. The entire join condition is assumed to be a logical AND of the + * equality condition and inequality condition. + * NOTE: It is the responsibility of the caller to close the resulting size information object + * or native resources can be leaked! 
+ * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return size information for the join + */ + public static MixedJoinSize mixedInnerJoinSize(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] mixedSizeInfo = mixedInnerJoinSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), nullEquality == NullEquality.EQUAL); + assert mixedSizeInfo.length == 2; + long outputRowCount = mixedSizeInfo[0]; + long matchesColumnHandle = mixedSizeInfo[1]; + return new MixedJoinSize(outputRowCount, new ColumnVector(matchesColumnHandle)); + } + + /** + * Computes the gather maps that can be used to manifest the result of an inner join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. 
+ * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return left and right table gather maps + */ + public static GatherMap[] mixedInnerJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] gatherMapData = mixedInnerJoinGatherMaps( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL); + return buildJoinGatherMaps(gatherMapData); + } + + /** + * Computes the gather maps that can be used to manifest the result of an inner join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing the size result from + * {@link #mixedInnerJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} + * when the output size was computed previously. 
+ * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @param joinSize mixed join size result + * @return left and right table gather maps + */ + public static GatherMap[] mixedInnerJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality, + MixedJoinSize joinSize) { + long[] gatherMapData = mixedInnerJoinGatherMapsWithSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL, + joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView()); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an full equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2520,6 +2726,33 @@ public GatherMap[] conditionalFullJoinGatherMaps(Table rightTable, return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a full join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the full join. 
+   * It is the responsibility of the caller to close the resulting gather map instances.
+   * @param leftKeys the left table's key columns for the equality condition
+   * @param rightKeys the right table's key columns for the equality condition
+   * @param leftConditional the left table's columns needed to evaluate the inequality condition
+   * @param rightConditional the right table's columns needed to evaluate the inequality condition
+   * @param condition the inequality condition of the join
+   * @param nullEquality whether nulls should compare as equal
+   * @return left and right table gather maps
+   */
+  public static GatherMap[] mixedFullJoinGatherMaps(Table leftKeys, Table rightKeys,
+                                                    Table leftConditional, Table rightConditional,
+                                                    CompiledExpression condition,
+                                                    NullEquality nullEquality) {
+    long[] gatherMapData = mixedFullJoinGatherMaps(
+        leftKeys.getNativeView(), rightKeys.getNativeView(),
+        leftConditional.getNativeView(), rightConditional.getNativeView(),
+        condition.getNativeHandle(),
+        nullEquality == NullEquality.EQUAL);
+    return buildJoinGatherMaps(gatherMapData);
+  }
+
   private GatherMap buildSemiJoinGatherMap(long[] gatherMapData) {
     long bufferSize = gatherMapData[0];
     long leftAddr = gatherMapData[1];
diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp
index 0e6425ea7a2..ac8c38c9745 100644
--- a/java/src/main/native/src/TableJni.cpp
+++ b/java/src/main/native/src/TableJni.cpp
@@ -885,6 +885,82 @@ jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_
   CATCH_STD(env, NULL);
 }
 
+// Computes size information for a mixed join (equality keys plus an AST condition) and
+// returns it to Java as a two element array: [0] is the output row count and [1] is the
+// pointer released from the second member of join_size_func's result.
+// NOTE(review): template arguments were missing from the reviewed copy of this patch. The
+// casts are reconstructed from usage and join_size_info is declared auto; confirm upstream.
+template <typename T>
+jlongArray mixed_join_size(JNIEnv *env, jlong j_left_keys, jlong j_right_keys,
+                           jlong j_left_condition, jlong j_right_condition, jlong j_condition,
+                           jboolean j_nulls_equal, T join_size_func) {
+  JNI_NULL_CHECK(env, j_left_keys, "left keys table is null", 0);
+  JNI_NULL_CHECK(env, j_right_keys, "right keys table is null", 0);
+  JNI_NULL_CHECK(env, j_left_condition, "left condition table is null", 0);
+  JNI_NULL_CHECK(env, j_right_condition, "right condition table is null", 0);
+  JNI_NULL_CHECK(env, j_condition, "condition is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const left_keys = reinterpret_cast<cudf::table_view const *>(j_left_keys);
+    auto const right_keys = reinterpret_cast<cudf::table_view const *>(j_right_keys);
+    auto const left_condition = reinterpret_cast<cudf::table_view const *>(j_left_condition);
+    auto const right_condition = reinterpret_cast<cudf::table_view const *>(j_right_condition);
+    auto const condition = reinterpret_cast<cudf::jni::ast::compiled_expr const *>(j_condition);
+    auto const nulls_equal =
+        j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
+    auto join_size_info =
+        join_size_func(*left_keys, *right_keys, *left_condition, *right_condition,
+                       condition->get_top_expression(), nulls_equal);
+    cudf::jni::native_jlongArray result(env, 2);
+    result[0] = static_cast<jlong>(join_size_info.first);
+    result[1] = reinterpret_cast<jlong>(join_size_info.second.release());
+    return result.get_jArray();
+  }
+  CATCH_STD(env, NULL);
+}
+
+// Validates the native handles for a mixed join, runs join_func on the key tables, the
+// conditional tables and the condition, and hands the result to gather_maps_to_java.
+// NOTE(review): template arguments were missing from the reviewed copy of this patch and
+// the casts below are reconstructed from usage; confirm upstream.
+template <typename T>
+jlongArray mixed_join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys,
+                                  jlong j_left_condition, jlong j_right_condition,
+                                  jlong j_condition, jboolean j_nulls_equal, T join_func) {
+  JNI_NULL_CHECK(env, j_left_keys, "left keys table is null", 0);
+  JNI_NULL_CHECK(env, j_right_keys, "right keys table is null", 0);
+  JNI_NULL_CHECK(env, j_left_condition, "left condition table is null", 0);
+  JNI_NULL_CHECK(env, j_right_condition, "right condition table is null", 0);
+  JNI_NULL_CHECK(env, j_condition, "condition is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const left_keys = reinterpret_cast<cudf::table_view const *>(j_left_keys);
+    auto const right_keys = reinterpret_cast<cudf::table_view const *>(j_right_keys);
+    auto const left_condition = reinterpret_cast<cudf::table_view const *>(j_left_condition);
+    auto const right_condition = reinterpret_cast<cudf::table_view const *>(j_right_condition);
+    auto const condition = reinterpret_cast<cudf::jni::ast::compiled_expr const *>(j_condition);
+    auto const nulls_equal =
+        j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
+    return gather_maps_to_java(env,
+                               join_func(*left_keys, *right_keys, *left_condition, *right_condition,
+                                         condition->get_top_expression(), nulls_equal));
+  }
+  CATCH_STD(env, NULL);
+}
+
+// Rebuilds the pair of (output row count, matches column view) from the raw values Java
+// passes back after a previous mixed join size computation.
+// j_matches_view is dereferenced without a check here, so callers must reject a null
+// handle first. env is currently unused.
+// NOTE(review): template arguments were missing from the reviewed copy of this patch; the
+// return type and casts are reconstructed from usage, confirm upstream.
+std::pair<std::size_t, cudf::column_view>
+get_mixed_size_info(JNIEnv *env, jlong j_output_row_count, jlong j_matches_view) {
+  auto const row_count = static_cast<std::size_t>(j_output_row_count);
+  auto const matches = reinterpret_cast<cudf::column_view const *>(j_matches_view);
+  return std::pair(row_count, *matches);
+}
+
 // Returns a table view containing only the columns at the specified indices
 cudf::table_view const get_keys_table(cudf::table_view const *t,
                                       native_jintArray const &key_indices) {
@@ -2112,6 +2188,52 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGather
       });
 }
 
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedLeftJoinSize(
+    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition,
+    jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) {
+  return cudf::jni::mixed_join_size(
+      env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition,
+      j_nulls_equal,
+      [](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
+         cudf::table_view const &left_condition, cudf::table_view const &right_condition,
+         cudf::ast::expression const &condition, cudf::null_equality nulls_equal) {
+        return cudf::mixed_left_join_size(left_keys, right_keys, left_condition, right_condition,
+                                          condition, nulls_equal);
+      });
+}
+
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedLeftJoinGatherMaps(
+    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition,
+    jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) {
+  return cudf::jni::mixed_join_gather_maps(
+      env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition,
+      j_nulls_equal,
+      [](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
+         cudf::table_view const &left_condition, cudf::table_view const &right_condition,
+         cudf::ast::expression const &condition, cudf::null_equality nulls_equal) {
+        return cudf::mixed_left_join(left_keys, right_keys, left_condition, right_condition,
+                                     condition, nulls_equal);
+      });
+}
+
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedLeftJoinGatherMapsWithSize(
+    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition,
+    jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal, jlong j_output_row_count,
+    jlong j_matches_view) {
+  // Reject a null handle up front; get_mixed_size_info dereferences it unchecked.
+  JNI_NULL_CHECK(env, j_matches_view, "matches column view is null", 0);
+  auto size_info = cudf::jni::get_mixed_size_info(env, j_output_row_count, j_matches_view);
+  return cudf::jni::mixed_join_gather_maps(
+      env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition,
+      j_nulls_equal,
+      [&size_info](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
+                   cudf::table_view const &left_condition, cudf::table_view const &right_condition,
+                   cudf::ast::expression const &condition, cudf::null_equality nulls_equal) {
+        return cudf::mixed_left_join(left_keys, right_keys, left_condition, right_condition,
+                                     condition, nulls_equal, size_info);
+      });
+}
+
 JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps(
     JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) {
   return cudf::jni::join_gather_maps(
@@ -2201,6 +2323,52 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGathe
       });
 }
 
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedInnerJoinSize(
+    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition,
+    jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) {
+  return cudf::jni::mixed_join_size(
+      env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition,
+      j_nulls_equal,
+      [](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
+         cudf::table_view const &left_condition, cudf::table_view const &right_condition,
+         cudf::ast::expression const &condition, cudf::null_equality nulls_equal) {
+        return cudf::mixed_inner_join_size(left_keys, right_keys, left_condition, right_condition,
+                                           condition, nulls_equal);
+      });
+}
+
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedInnerJoinGatherMaps(
+    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition,
+    jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) {
+  return cudf::jni::mixed_join_gather_maps(
+      env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition,
+      j_nulls_equal,
+      [](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
+         cudf::table_view const &left_condition, cudf::table_view const &right_condition,
+         cudf::ast::expression const &condition, cudf::null_equality nulls_equal) {
+        return cudf::mixed_inner_join(left_keys, right_keys, left_condition, right_condition,
+                                      condition, nulls_equal);
+      });
+}
+
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedInnerJoinGatherMapsWithSize(
+    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition,
+    jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal, jlong j_output_row_count,
+    jlong j_matches_view) {
+  // Reject a null handle up front; get_mixed_size_info dereferences it unchecked.
+  JNI_NULL_CHECK(env, j_matches_view, "matches column view is null", 0);
+  auto size_info = cudf::jni::get_mixed_size_info(env, j_output_row_count, j_matches_view);
+  return cudf::jni::mixed_join_gather_maps(
+      env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition,
+      j_nulls_equal,
+      [&size_info](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
+                   cudf::table_view const &left_condition, cudf::table_view const &right_condition,
+                   cudf::ast::expression const &condition, cudf::null_equality nulls_equal) {
+        return cudf::mixed_inner_join(left_keys, right_keys, left_condition, right_condition,
+                                      condition, nulls_equal, size_info);
+      });
+}
+
 JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps(
     JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) {
   return cudf::jni::join_gather_maps(
@@ -2259,6 +2427,20 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGather
       });
 }
 
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedFullJoinGatherMaps(
+    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition,
+    jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) {
+  return cudf::jni::mixed_join_gather_maps(
+      env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition,
+      j_nulls_equal,
+      [](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
+         cudf::table_view const &left_condition, cudf::table_view const &right_condition,
+         cudf::ast::expression const &condition, cudf::null_equality nulls_equal) {
+        return cudf::mixed_full_join(left_keys, right_keys, left_condition, right_condition,
+                                     condition, nulls_equal);
+      });
+}
+
 JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap(
     JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) {
   return cudf::jni::join_gather_single_map(
diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java
index 86c55e19776..644e50cf9da 100644
--- a/java/src/test/java/ai/rapids/cudf/TableTest.java
+++ b/java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -1467,6 +1467,144 @@ void testConditionalLeftJoinGatherMapsNullsWithCount() {
     }
   }
 
+  @Test
+  void testMixedLeftJoinGatherMaps() {
+    final int inv = Integer.MIN_VALUE;
+    BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER,
+        new ColumnReference(1, TableReference.LEFT),
+        new ColumnReference(1, TableReference.RIGHT));
+    try (CompiledExpression condition = expr.compile();
+         Table left = new Table.TestBuilder()
+             .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8)
+             .column(1, 2, 3, 4, 5, 6, 7,
8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, inv) + .build()) { + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedLeftJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9) + .column(0, inv, inv, inv, inv, inv, inv, 0, 2, 1, inv) + .build()) { + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedLeftJoinGatherMapsWithSize() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, 
TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, inv) + .build(); + MixedJoinSize sizeInfo = Table.mixedLeftJoinSize(leftKeys, rightKeys, left, right, + condition, NullEquality.UNEQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedLeftJoinGatherMapsNullsWithSize() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9) + .column(0, inv, inv, inv, inv, inv, inv, 0, 2, 1, inv) + .build(); + MixedJoinSize sizeInfo = Table.mixedLeftJoinSize(leftKeys, rightKeys, left, right, + condition, 
NullEquality.EQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testInnerJoinGatherMaps() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1737,6 +1875,142 @@ void testConditionalInnerJoinGatherMapsNullsWithCount() { } } + @Test + void testMixedInnerJoinGatherMaps() { + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(2, 7, 8) + .column(2, 0, 1) + .build()) { + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedInnerJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table 
leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 7, 7, 8) + .column(0, 0, 2, 1) + .build()) { + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedInnerJoinGatherMapsWithSize() { + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(2, 7, 8) + .column(2, 0, 1) + .build(); + MixedJoinSize sizeInfo = Table.mixedInnerJoinSize(leftKeys, rightKeys, left, right, + condition, NullEquality.UNEQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedInnerJoinGatherMapsNullsWithSize() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new 
ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 7, 7, 8) + .column(0, 0, 2, 1) + .build(); + MixedJoinSize sizeInfo = Table.mixedInnerJoinSize(leftKeys, rightKeys, left, right, + condition, NullEquality.EQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testFullJoinGatherMaps() { final int inv = Integer.MIN_VALUE; @@ -1931,6 +2205,72 @@ void testConditionalFullJoinGatherMapsNulls() { } } + @Test + void testMixedFullJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(inv, inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column( 3, 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 
inv) + .build()) { + GatherMap[] maps = Table.mixedFullJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedFullJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(inv, inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9) + .column( 3, 4, 5, 0, inv, inv, inv, inv, inv, inv, 0, 2, 1, inv) + .build()) { + GatherMap[] maps = Table.mixedFullJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testLeftSemiJoinGatherMap() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); From 3af69800a1c3bc23e99e1432e831754c803021b6 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 11 Jan 2022 13:34:32 -0600 Subject: [PATCH 2/5] Update to new API that uses device_uvector --- java/src/main/native/src/TableJni.cpp | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index ac8c38c9745..b9ac68218c8 100644 --- a/java/src/main/native/src/TableJni.cpp +++ 
b/java/src/main/native/src/TableJni.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include "cudf_jni_apis.hpp" @@ -903,12 +904,18 @@ jlongArray mixed_join_size(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, auto const condition = reinterpret_cast(j_condition); auto const nulls_equal = j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - std::pair> join_size_info = + std::pair>> join_size_info = join_size_func(*left_keys, *right_keys, *left_condition, *right_condition, condition->get_top_expression(), nulls_equal); + if (join_size_info.second->size() > std::numeric_limits::max()) { + throw std::runtime_error("Too many values in device buffer to convert into a column"); + } + auto col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, join_size_info.second->size(), + join_size_info.second->release(), rmm::device_buffer{}, 0); cudf::jni::native_jlongArray result(env, 2); result[0] = static_cast(join_size_info.first); - result[1] = reinterpret_cast(join_size_info.second.release()); + result[1] = reinterpret_cast(col.release()); return result.get_jArray(); } CATCH_STD(env, NULL); @@ -939,11 +946,13 @@ jlongArray mixed_join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_ CATCH_STD(env, NULL); } -std::pair get_mixed_size_info(JNIEnv *env, jlong j_output_row_count, - jlong j_matches_view) { +std::pair> +get_mixed_size_info(JNIEnv *env, jlong j_output_row_count, jlong j_matches_view) { auto const row_count = static_cast(j_output_row_count); auto const matches = reinterpret_cast(j_matches_view); - return std::pair(row_count, *matches); + return std::pair>( + row_count, cudf::device_span(matches->template data(), + matches->size())); } // Returns a table view containing only the columns at the specified indices From 1ab70c84c4368d6a6ead65ce063d24a3fa875004 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 11 Jan 2022 14:16:25 -0600 Subject: [PATCH 3/5] Update copyright --- 
java/src/main/java/ai/rapids/cudf/MixedJoinSize.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java b/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java index 8e95f87b3da..811f0b9a0b0 100644 --- a/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java +++ b/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 394497c078a5588bbdd39d97336de9e1b3ad1aff Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 12 Jan 2022 10:34:14 -0600 Subject: [PATCH 4/5] Fix construction of mixed join size matches column --- java/src/main/native/src/TableJni.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 46fe647f241..03faf9be021 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -911,9 +911,10 @@ jlongArray mixed_join_size(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, if (join_size_info.second->size() > std::numeric_limits::max()) { throw std::runtime_error("Too many values in device buffer to convert into a column"); } - auto col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, join_size_info.second->size(), - join_size_info.second->release(), rmm::device_buffer{}, 0); + auto col_size = join_size_info.second->size(); + auto col_data = join_size_info.second->release(); + auto col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, col_size, + std::move(col_data), rmm::device_buffer{}, 0); cudf::jni::native_jlongArray result(env, 2); result[0] = static_cast(join_size_info.first); result[1] = reinterpret_cast(col.release()); From 61b0e7d9b86698d46491dd2710fa8d229edfb12d Mon Sep 17 00:00:00 2001 
From: Jason Lowe Date: Wed, 12 Jan 2022 12:11:12 -0600 Subject: [PATCH 5/5] Remove unused variables --- java/src/test/java/ai/rapids/cudf/TableTest.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index f4f3bb58fee..8e074a5e4ff 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2020,7 +2020,6 @@ void testMixedInnerJoinGatherMaps() { @Test void testMixedInnerJoinGatherMapsNulls() { - final int inv = Integer.MIN_VALUE; BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(1, TableReference.LEFT), new ColumnReference(1, TableReference.RIGHT)); @@ -2088,7 +2087,6 @@ void testMixedInnerJoinGatherMapsWithSize() { @Test void testMixedInnerJoinGatherMapsNullsWithSize() { - final int inv = Integer.MIN_VALUE; BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(1, TableReference.LEFT), new ColumnReference(1, TableReference.RIGHT));