diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index c2110a5f8ff..1dce52f7105 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2323,6 +2323,38 @@ public static ColumnView makeStructView(ColumnView... columns) { return makeStructView(columns[0].rows, columns); } + /** + * Create a column of bool values indicating whether the specified scalar + * is an element of each row of a list column. + * Output `column[i]` is set to null if one or more of the following are true: + * 1. The key is null + * 2. The column vector list value is null + * 3. The list row does not contain the key, and contains at least + * one null. + * @param key the scalar to look up + * @return a Boolean ColumnVector with the result of the lookup + */ + public final ColumnVector listContains(Scalar key) { + assert type.equals(DType.LIST) : "column type must be a LIST"; + return new ColumnVector(listContains(getNativeView(), key.getScalarHandle())); + } + + /** + * Create a column of bool values indicating whether the list rows of the first + * column contain the corresponding values in the second column. + * Output `column[i]` is set to null if one or more of the following are true: + * 1. The key value is null + * 2. The column vector list value is null + * 3. The list row does not contain the key, and contains at least + * one null. 
+ * @param key the ColumnVector with look up values + * @return a Boolean ColumnVector with the result of the lookup + */ + public final ColumnVector listContainsColumn(ColumnView key) { + assert type.equals(DType.LIST) : "column type must be a LIST"; + return new ColumnVector(listContainsColumn(getNativeView(), key.getNativeView())); + } + ///////////////////////////////////////////////////////////////////////////// // INTERNAL/NATIVE ACCESS ///////////////////////////////////////////////////////////////////////////// @@ -2558,6 +2590,22 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long extractListElement(long nativeView, int index); + /** + * Native method for list lookup + * @param nativeView the column view handle of the list + * @param key the scalar key handle + * @return column handle of the resultant + */ + private static native long listContains(long nativeView, long key); + + /** + * Native method for list lookup + * @param nativeView the column view handle of the list + * @param keyColumn the column handle of look up keys + * @return column handle of the resultant + */ + private static native long listContainsColumn(long nativeView, long keyColumn); + private static native long castTo(long nativeHandle, int type, int scale); private static native long logicalCastTo(long nativeHandle, int type, int scale); diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 621344ac38f..82e71b04a2f 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -56,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -329,6 +330,40 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElement(JNIEnv CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContains(JNIEnv *env, jclass, + jlong column_view, + jlong lookup_key) { + JNI_NULL_CHECK(env, 
column_view, "column is null", 0); + JNI_NULL_CHECK(env, lookup_key, "lookup scalar is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *cv = reinterpret_cast(column_view); + cudf::lists_column_view lcv(*cv); + cudf::scalar *lookup_scalar = reinterpret_cast(lookup_key); + + std::unique_ptr ret = cudf::lists::contains(lcv, *lookup_scalar); + return reinterpret_cast(ret.release()); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContainsColumn(JNIEnv *env, jclass, + jlong column_view, + jlong lookup_key_cv) { + JNI_NULL_CHECK(env, column_view, "column is null", 0); + JNI_NULL_CHECK(env, lookup_key_cv, "lookup column is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *cv = reinterpret_cast(column_view); + cudf::lists_column_view lcv(*cv); + cudf::column_view *lookup_cv = reinterpret_cast(lookup_key_cv); + + std::unique_ptr ret = cudf::lists::contains(lcv, *lookup_cv); + return reinterpret_cast(ret.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, jlong column_view, jlong delimiter) { diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 88ff50959f7..582b67b8287 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2899,6 +2899,70 @@ void testExtractListElements() { } } + @Test + void testListContainsString() { + List list1 = Arrays.asList("Héllo there", "thésé"); + List list2 = Arrays.asList("", "ARé some", "test strings"); + List list3 = Arrays.asList(null, "", "ARé some", "test strings", "thésé"); + List list4 = Arrays.asList(null, "", "ARé some", "test strings"); + List list5 = null; + try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), list1, list2, 
list3, list4, list5); + ColumnVector expected = ColumnVector.fromBoxedBooleans(true, false, true, null, null); + Scalar strScalar = Scalar.fromString("thésé"); + ColumnVector result = v.listContains(strScalar)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testListContainsInt() { + List list1 = Arrays.asList(1, 2, 3); + List list2 = Arrays.asList(4, 5, 6); + List list3 = Arrays.asList(7, 8, 9); + List list4 = null; + try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), list1, list2, list3, list4); + ColumnVector expected = ColumnVector.fromBoxedBooleans(false, false, true, null); + Scalar intScalar = Scalar.fromInt(7); + ColumnVector result = v.listContains(intScalar)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testListContainsStringCol() { + List list1 = Arrays.asList("Héllo there", "thésé"); + List list2 = Arrays.asList("", "ARé some", "test strings"); + List list3 = Arrays.asList("FOO", "", "ARé some", "test"); + List list4 = Arrays.asList(null, "FOO", "", "ARé some", "test"); + List list5 = Arrays.asList(null, "FOO", "", "ARé some", "test"); + List list6 = null; + try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), list1, list2, list3, list4, list5, list6); + ColumnVector expected = ColumnVector.fromBoxedBooleans(true, true, true, true, null, null); + ColumnVector keys = ColumnVector.fromStrings("thésé", "", "test", "test", "iotA", null); + ColumnVector result = v.listContainsColumn(keys)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testListContainsIntCol() { + List list1 = Arrays.asList(1, 2, 3); + List list2 = Arrays.asList(4, 5, 6); + List list3 = Arrays.asList(null, 8, 9); + List list4 = Arrays.asList(null, 8, 9); + List list5 = null; + try (ColumnVector v = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new 
HostColumnVector.BasicType(true, DType.INT32)), list1, list2, list3, list4, list5); + 
ColumnVector expected = ColumnVector.fromBoxedBooleans(true, false, true, null, null); + ColumnVector keys = ColumnVector.fromBoxedInts(3, 3, 8, 3, null); + ColumnVector result = v.listContainsColumn(keys)) { + assertColumnsAreEqual(expected, result); + } + } + @Test void testStringSplitRecord() { try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", "null", "", "ARé some", "test strings");