Skip to content

Commit

Permalink
Enable positive group indices for extractAllRecord on JNI (NVIDIA#11215)
Browse files Browse the repository at this point in the history
Closes rapidsai/cudf#11033

Allows `extractAllRecord` to take an `idx` > 0 and exposes `cudf::strings::extract_all_record` through JNI.

Authors:
  - Anthony Chang (https://github.com/anthony-chang)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - MithunR (https://github.com/mythrocks)
  - Alessandro Bellina (https://github.com/abellina)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: rapidsai/cudf#11215
  • Loading branch information
anthony-chang authored Jul 8, 2022
1 parent 417136d commit 1331994
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
3 changes: 1 addition & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3246,13 +3246,12 @@ public final Table extractRe(String pattern) throws CudfException {
* @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html
* @param pattern The regex pattern
* @param idx The regex group index (only 0 is supported currently).
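* @param idx The regex group index: 0 returns each whole match; any value greater than 0 returns every capture group of every match (the specific index is not used to select a single group)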
* @return A new column vector of extracted matches
*/
public final ColumnVector extractAllRecord(String pattern, int idx) {
assert type.equals(DType.STRING) : "column type must be a String";
assert idx >= 0 : "group index must be at least 0";
assert idx == 0 : "group index > 0 is not supported yet";

return new ColumnVector(extractAllRecord(this.getNativeView(), pattern, idx));
}
Expand Down
9 changes: 4 additions & 5 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1604,17 +1604,16 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord(JNIEnv *
jint idx) {
JNI_NULL_CHECK(env, j_view_handle, "column is null", 0);

if (idx > 0) {
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "group index > 0 is not supported", 0);
}

try {
cudf::jni::auto_set_device(env);
cudf::strings_column_view const strings_column{
*reinterpret_cast<cudf::column_view *>(j_view_handle)};
cudf::jni::native_jstring pattern(env, pattern_obj);

return release_as_jlong(cudf::strings::findall_record(strings_column, pattern.get()));
auto result = (idx == 0) ? cudf::strings::findall_record(strings_column, pattern.get()) :
cudf::strings::extract_all_record(strings_column, pattern.get());

return release_as_jlong(result);
}
CATCH_STD(env, 0);
}
Expand Down
16 changes: 15 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4111,6 +4111,7 @@ void testExtractRe() {

@Test
void testExtractAllRecord() {
String pattern = "([ab])(\\d)";
try (ColumnVector v = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2");
ColumnVector expectedIdx0 = ColumnVector.fromLists(
new HostColumnVector.ListType(true,
Expand All @@ -4120,9 +4121,22 @@ void testExtractAllRecord() {
Arrays.asList(),
null,
Arrays.asList("a1", "b1", "a2"));
ColumnVector resultIdx0 = v.extractAllRecord("([ab])(\\d)", 0)
ColumnVector expectedIdx12 = ColumnVector.fromLists(
new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList("a", "1"),
Arrays.asList("b", "2"),
null,
null,
Arrays.asList("a", "1", "b", "1", "a", "2"));

ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0);
ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1);
ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2);
) {
assertColumnsAreEqual(expectedIdx0, resultIdx0);
assertColumnsAreEqual(expectedIdx12, resultIdx1);
assertColumnsAreEqual(expectedIdx12, resultIdx2);
}
}

Expand Down

0 comments on commit 1331994

Please sign in to comment.