Skip to content

Commit

Permalink
Add JNI bindings for extractAllRecord (#11196)
Browse files Browse the repository at this point in the history
Part of #11033

Adds `extractAllRecord` to the JNI layer. For now this only supports `idx` 0 using `cudf::strings::findall_record`. 

Corresponding spark plugin PR: NVIDIA/spark-rapids#5947

Authors:
  - Anthony Chang (https://github.com/anthony-chang)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #11196
  • Loading branch information
anthony-chang authored Jul 6, 2022
1 parent 4881af0 commit fd3cf3e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 10 deletions.
30 changes: 30 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3238,6 +3238,26 @@ public final Table extractRe(String pattern) throws CudfException {
return new Table(extractRe(this.getNativeView(), pattern));
}

/**
 * For every row, gathers each substring that matches the given regular expression and
 * belongs to the requested regex group. A null input row produces a null output entry.
 *
 * For supported regex patterns refer to:
 * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html
 * @param pattern The regex pattern
 * @param idx The regex group index; 0 is the only value accepted for now.
 * @return A new column vector holding the extracted matches
 */
public final ColumnVector extractAllRecord(String pattern, int idx) {
  // Preconditions are expressed as assertions, so they only fire when -ea is enabled.
  assert type.equals(DType.STRING) : "column type must be a String";
  assert idx >= 0 : "group index must be at least 0";
  assert idx == 0 : "group index > 0 is not supported yet";

  // The native call hands back a handle that the new ColumnVector wraps.
  final long resultHandle = extractAllRecord(getNativeView(), pattern, idx);
  return new ColumnVector(resultHandle);
}


/**
* Converts all character sequences starting with '%' into character code-points
* interpreting the 2 following characters as hex values to create the code-point.
Expand Down Expand Up @@ -3926,6 +3946,16 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
*/
private static native long[] extractRe(long cudfViewHandle, String pattern) throws CudfException;

/**
 * Native method that extracts, for every row, all matches of a regular expression that
 * correspond to the group {@code idx}.
 *
 * @param nativeHandle Native handle of the cudf::column_view being operated on.
 * @param pattern String regex pattern.
 * @param idx Regex group index. A 0 value means matching the entire regex; the JNI
 *            implementation currently rejects values greater than 0.
 * @return Native handle of the result column holding the matches found for each row.
 */
private static native long extractAllRecord(long nativeHandle, String pattern, int idx);

// Native binding implemented by Java_ai_rapids_cudf_ColumnView_urlDecode in ColumnViewJni.cpp.
// cudfViewHandle is the native column view handle; the native side rejects 0 ("column is null").
private static native long urlDecode(long cudfViewHandle);

// Native binding taking a native column view handle and returning a long.
// NOTE(review): presumably the encoding counterpart of urlDecode - confirm in ColumnViewJni.cpp.
private static native long urlEncode(long cudfViewHandle);
Expand Down
22 changes: 22 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include <cudf/strings/convert/convert_urls.hpp>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/findall.hpp>
#include <cudf/strings/json.hpp>
#include <cudf/strings/padding.hpp>
#include <cudf/strings/repeat_strings.hpp>
Expand Down Expand Up @@ -1576,6 +1577,27 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_extractRe(JNIEnv *en
CATCH_STD(env, 0);
}

// JNI entry point backing ColumnView.extractAllRecord(long, String, int).
//
// j_view_handle: native handle of the cudf::column_view holding the input strings.
// pattern_obj:   Java string with the regex pattern.
// idx:           regex group index; only 0 (the entire match) is supported.
//
// Returns the native handle of the column produced by cudf::strings::findall_record, or 0
// after raising a Java exception (NullPointerException for a null column/pattern,
// IllegalArgumentException for an unsupported group index).
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord(JNIEnv *env, jclass,
                                                                        jlong j_view_handle,
                                                                        jstring pattern_obj,
                                                                        jint idx) {
  JNI_NULL_CHECK(env, j_view_handle, "column is null", 0);
  // Reject a null pattern up front instead of dereferencing it below.
  JNI_NULL_CHECK(env, pattern_obj, "pattern is null", 0);

  // The Java-side checks are assertions and may be disabled, so validate the full range here;
  // previously a negative index slipped through and was silently treated as group 0.
  if (idx < 0) {
    JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "group index must be at least 0", 0);
  }
  if (idx > 0) {
    JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "group index > 0 is not supported", 0);
  }

  try {
    cudf::jni::auto_set_device(env);
    cudf::strings_column_view const strings_column{
        *reinterpret_cast<cudf::column_view *>(j_view_handle)};
    cudf::jni::native_jstring pattern(env, pattern_obj);

    return release_as_jlong(cudf::strings::findall_record(strings_column, pattern.get()));
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_urlDecode(JNIEnv *env, jclass,
jlong j_view_handle) {
JNI_NULL_CHECK(env, j_view_handle, "column is null", 0);
Expand Down
37 changes: 27 additions & 10 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4097,17 +4097,34 @@ void testStringFindOperations() {
}
}

@Test
void testExtractRe() {
try (ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null);
Table expected = new Table.TestBuilder()
.column("a", "b", null, null)
.column("1", "2", null, null)
.build();
Table found = input.extractRe("([ab])(\\d)")) {
assertTablesAreEqual(expected, found);
}
@Test
void testExtractRe() {
  // Each capture group becomes its own column; rows that do not match (or are null) yield nulls.
  try (ColumnVector strings = ColumnVector.fromStrings("a1", "b2", "c3", null);
       Table wanted = new Table.TestBuilder()
           .column("a", "b", null, null)
           .column("1", "2", null, null)
           .build();
       Table actual = strings.extractRe("([ab])(\\d)")) {
    assertTablesAreEqual(wanted, actual);
  }
}

@Test
void testExtractAllRecord() {
  // Result type: a nullable list whose elements are nullable strings.
  HostColumnVector.ListType listOfStrings = new HostColumnVector.ListType(true,
      new HostColumnVector.BasicType(true, DType.STRING));

  // Group index 0 returns every whole-pattern match per row: no match gives an empty list,
  // a null input row gives a null entry.
  try (ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2");
       ColumnVector expected = ColumnVector.fromLists(
           listOfStrings,
           Arrays.asList("a1"),
           Arrays.asList("b2"),
           Arrays.asList(),
           null,
           Arrays.asList("a1", "b1", "a2"));
       ColumnVector actual = input.extractAllRecord("([ab])(\\d)", 0)) {
    assertColumnsAreEqual(expected, actual);
  }
}

@Test
void testMatchesRe() {
Expand Down

0 comments on commit fd3cf3e

Please sign in to comment.