Skip to content

Commit

Permalink
Add JNI and Java bindings for list_contains (#7125)
Browse files Browse the repository at this point in the history
Adds JNI and Java-side bindings for `list_contains`, which is being added as part of #7039.

Authors:
  - Kuhu Shukla (@kuhushukla)

Approvers:
  - Robert (Bobby) Evans (@revans2)
  - MithunR (@mythrocks)

URL: #7125
  • Loading branch information
Kuhu Shukla authored Jan 27, 2021
1 parent fc40c52 commit dd1efe1
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 0 deletions.
47 changes: 47 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2323,6 +2323,37 @@ public static ColumnView makeStructView(ColumnView... columns) {
return makeStructView(columns[0].rows, columns);
}

/**
 * Look up a scalar key in every row of this list column and report, per row,
 * whether the key is one of that row's elements.
 * Output `column[i]` is set to null if one or more of the following are true:
 * 1. The key is null
 * 2. The column vector list value is null
 * 3. The list row does not contain the key, and contains at least
 *    one null.
 * @param key the scalar to look up
 * @return a Boolean ColumnVector with the result of the lookup
 */
public final ColumnVector listContains(Scalar key) {
  assert type.equals(DType.LIST) : "column type must be a LIST";
  // Resolve both native handles first, then wrap the handle returned by the native call.
  long listHandle = getNativeView();
  long keyHandle = key.getScalarHandle();
  long resultHandle = listContains(listHandle, keyHandle);
  return new ColumnVector(resultHandle);
}

/**
 * Check, row by row, whether each list in this column contains the value found
 * in the corresponding row of the key column.
 * Output `column[i]` is set to null if one or more of the following are true:
 * 1. The key value is null
 * 2. The column vector list value is null
 * 3. The list row does not contain the key, and contains at least
 *    one null.
 * @param key the ColumnVector with look up values
 * @return a Boolean ColumnVector with the result of the lookup
 */
public final ColumnVector listContainsColumn(ColumnView key) {
  assert type.equals(DType.LIST) : "column type must be a LIST";
  // Resolve both native handles first, then wrap the handle returned by the native call.
  long listHandle = getNativeView();
  long keysHandle = key.getNativeView();
  long resultHandle = listContainsColumn(listHandle, keysHandle);
  return new ColumnVector(resultHandle);
}

/////////////////////////////////////////////////////////////////////////////
// INTERNAL/NATIVE ACCESS
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -2558,6 +2589,22 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

/**
 * Native method for list element extraction.
 * NOTE(review): presumably returns a handle to a column holding the element at
 * `index` of each list row; confirm against the JNI implementation.
 * @param nativeView the column view handle of the list
 * @param index the element index passed to the native side
 * @return native handle produced by the native call
 */
private static native long extractListElement(long nativeView, int index);

/**
 * Native method for list lookup with a single scalar key.
 * Called by {@code listContains(Scalar)}, which wraps the returned handle in a ColumnVector.
 * @param nativeView the column view handle of the list
 * @param key the scalar key handle
 * @return column handle of the resultant
 */
private static native long listContains(long nativeView, long key);

/**
 * Native method for list lookup with a column of keys.
 * Called by {@code listContainsColumn(ColumnView)}, which wraps the returned handle in a
 * ColumnVector.
 * @param nativeView the column view handle of the list
 * @param keyColumn the column view handle of the look up keys
 * @return column handle of the resultant
 */
private static native long listContainsColumn(long nativeView, long keyColumn);

/**
 * Native method for casting a column.
 * NOTE(review): `type` and `scale` look like the native id and scale of the target
 * DType; confirm against the callers and the JNI implementation.
 * @param nativeHandle native handle of the column to cast
 * @param type integer type id passed to the native side
 * @param scale integer scale passed to the native side
 * @return native handle produced by the native call
 */
private static native long castTo(long nativeHandle, int type, int scale);

/**
 * Native method for a logical cast of a column.
 * NOTE(review): how this differs from `castTo` is not visible here; presumably the
 * data is reinterpreted rather than converted. Confirm against the JNI implementation.
 * @param nativeHandle native handle of the column to cast
 * @param type integer type id passed to the native side
 * @param scale integer scale passed to the native side
 * @return native handle produced by the native call
 */
private static native long logicalCastTo(long nativeHandle, int type, int scale);
Expand Down
35 changes: 35 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include <cudf/transform.hpp>
#include <cudf/unary.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/lists/contains.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/structs/structs_column_view.hpp>
#include <map_lookup.hpp>
Expand Down Expand Up @@ -329,6 +330,40 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElement(JNIEnv
CATCH_STD(env, 0);
}

// JNI backing for ColumnView.listContains(long, long): looks up one scalar key in
// every row of a list column via cudf::lists::contains.
// Both arguments are native pointers passed from Java as jlong handles. The result
// column is released from its unique_ptr and returned as a jlong handle.
// NOTE(review): JNI_NULL_CHECK / CATCH_STD are defined elsewhere; presumably they raise
// a Java exception and return the supplied 0 on a null handle or a C++ exception.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContains(JNIEnv *env, jclass,
                                                                    jlong column_view,
                                                                    jlong lookup_key) {
  JNI_NULL_CHECK(env, column_view, "column is null", 0);
  JNI_NULL_CHECK(env, lookup_key, "lookup scalar is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    // Reinterpret the Java-side handle as the native view and wrap it as a lists view.
    auto const *lists_input = reinterpret_cast<cudf::column_view const *>(column_view);
    cudf::lists_column_view const lists_view(*lists_input);
    auto const *search_key = reinterpret_cast<cudf::scalar const *>(lookup_key);

    auto result = cudf::lists::contains(lists_view, *search_key);
    return reinterpret_cast<jlong>(result.release());
  }
  CATCH_STD(env, 0);
}

// JNI backing for ColumnView.listContainsColumn(long, long): looks up a column of
// keys in a list column via cudf::lists::contains.
// Both arguments are native column_view pointers passed from Java as jlong handles. The
// result column is released from its unique_ptr and returned as a jlong handle.
// NOTE(review): JNI_NULL_CHECK / CATCH_STD are defined elsewhere; presumably they raise
// a Java exception and return the supplied 0 on a null handle or a C++ exception.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContainsColumn(JNIEnv *env, jclass,
                                                                          jlong column_view,
                                                                          jlong lookup_key_cv) {
  JNI_NULL_CHECK(env, column_view, "column is null", 0);
  JNI_NULL_CHECK(env, lookup_key_cv, "lookup column is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    // Reinterpret the Java-side handle as the native view and wrap it as a lists view.
    auto const *lists_input = reinterpret_cast<cudf::column_view const *>(column_view);
    cudf::lists_column_view const lists_view(*lists_input);
    auto const *search_keys = reinterpret_cast<cudf::column_view const *>(lookup_key_cv);

    auto result = cudf::lists::contains(lists_view, *search_keys);
    return reinterpret_cast<jlong>(result.release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass,
jlong column_view,
jlong delimiter) {
Expand Down
61 changes: 61 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2899,6 +2899,67 @@ void testExtractListElements() {
}
}

/**
 * Verifies listContains with a string scalar key: true when the row contains the key,
 * false when it does not and has no nulls, and null when the row is null or lacks the
 * key but contains a null element.
 */
@Test
void testListContainsString() {
  List<String> list1 = Arrays.asList("Héllo there", "thésé");
  List<String> list2 = Arrays.asList("", "ARé some", "test strings");
  List<String> list3 = Arrays.asList(null, "", "ARé some", "test strings", "thésé");
  List<String> list4 = Arrays.asList(null, "", "ARé some", "test strings");
  List<String> list5 = null;
  // The scalar key is declared as a resource so it is closed with the columns
  // instead of being leaked.
  try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
           new HostColumnVector.BasicType(true, DType.STRING)), list1, list2, list3, list4, list5);
       ColumnVector expected = ColumnVector.fromBoxedBooleans(true, false, true, null, null);
       Scalar strKey = Scalar.fromString("thésé");
       ColumnVector result = v.listContains(strKey)) {
    assertColumnsAreEqual(expected, result);
  }
}

/**
 * Verifies listContains with an int scalar key: true or false for non-null rows
 * depending on membership, and null for a null list row.
 */
@Test
void testListContainsInt() {
  List<Integer> list1 = Arrays.asList(1, 2, 3);
  List<Integer> list2 = Arrays.asList(4, 5, 6);
  List<Integer> list3 = Arrays.asList(7, 8, 9);
  List<Integer> list4 = null;
  // The scalar key is declared as a resource so it is closed with the columns
  // instead of being leaked.
  try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
           new HostColumnVector.BasicType(true, DType.INT32)), list1, list2, list3, list4);
       ColumnVector expected = ColumnVector.fromBoxedBooleans(false, false, true, null);
       Scalar intKey = Scalar.fromInt(7);
       ColumnVector result = v.listContains(intKey)) {
    assertColumnsAreEqual(expected, result);
  }
}

/**
 * Verifies listContainsColumn with a string key column: each list row is checked
 * against the key in the same row. The result is null when the key is null, the list
 * row is null, or the row lacks the key but contains a null element.
 */
@Test
void testListContainsStringCol() {
  List<String> list1 = Arrays.asList("Héllo there", "thésé");
  List<String> list2 = Arrays.asList("", "ARé some", "test strings");
  List<String> list3 = Arrays.asList("FOO", "", "ARé some", "test");
  List<String> list4 = Arrays.asList(null, "FOO", "", "ARé some", "test");
  List<String> list5 = Arrays.asList(null, "FOO", "", "ARé some", "test");
  List<String> list6 = null;
  // The key column is declared as a resource so it is closed with the other columns
  // instead of being leaked.
  try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
           new HostColumnVector.BasicType(true, DType.STRING)), list1, list2, list3, list4, list5, list6);
       ColumnVector expected = ColumnVector.fromBoxedBooleans(true, true, true, true, null, null);
       ColumnVector keys = ColumnVector.fromStrings("thésé", "", "test", "test", "iotA", null);
       ColumnVector result = v.listContainsColumn(keys)) {
    assertColumnsAreEqual(expected, result);
  }
}

/**
 * Verifies listContainsColumn with an int key column: each list row is checked against
 * the key in the same row. The result is null when the key is null, the list row is
 * null, or the row lacks the key but contains a null element.
 */
@Test
void testListContainsIntCol() {
  List<Integer> list1 = Arrays.asList(1, 2, 3);
  List<Integer> list2 = Arrays.asList(4, 5, 6);
  List<Integer> list3 = Arrays.asList(null, 8, 9);
  List<Integer> list4 = Arrays.asList(null, 8, 9);
  List<Integer> list5 = null;
  // The key column is declared as a resource so it is closed with the other columns
  // instead of being leaked.
  try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true,
           new HostColumnVector.BasicType(true, DType.INT32)), list1, list2, list3, list4, list5);
       ColumnVector expected = ColumnVector.fromBoxedBooleans(true, false, true, null, null);
       ColumnVector keys = ColumnVector.fromBoxedInts(3, 3, 8, 3, null);
       ColumnVector result = v.listContainsColumn(keys)) {
    assertColumnsAreEqual(expected, result);
  }
}

@Test
void testStringSplitRecord() {
try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", "null", "", "ARé some", "test strings");
Expand Down

0 comments on commit dd1efe1

Please sign in to comment.