From fcddd333d16fcfc6ddc9b612e2df6144d628d950 Mon Sep 17 00:00:00 2001
From: Jason Lowe
Date: Fri, 23 Jul 2021 16:52:59 -0500
Subject: [PATCH] Java bindings for regex replace

---
 .../main/java/ai/rapids/cudf/ColumnView.java  | 64 +++++++++++++++++++
 java/src/main/native/src/ColumnViewJni.cpp    | 45 +++++++++++++
 .../java/ai/rapids/cudf/ColumnVectorTest.java | 40 ++++++++++++
 3 files changed, 149 insertions(+)

diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java
index 7299a6a716b..e61462202c2 100644
--- a/java/src/main/java/ai/rapids/cudf/ColumnView.java
+++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java
@@ -2386,6 +2386,48 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) {
         replace.getScalarHandle()));
   }
 
+  /**
+   * For each string, replaces any character sequence matching the given pattern using the
+   * replacement string scalar.
+   *
+   * @param pattern The regular expression pattern to search within each string.
+   * @param repl The string scalar to replace for each pattern match.
+   * @return A new column vector containing the string results.
+   */
+  public final ColumnVector replaceRegex(String pattern, Scalar repl) {
+    return replaceRegex(pattern, repl, -1);
+  }
+
+  /**
+   * For each string, replaces any character sequence matching the given pattern using the
+   * replacement string scalar.
+   *
+   * @param pattern The regular expression pattern to search within each string.
+   * @param repl The string scalar to replace for each pattern match.
+   * @param maxRepl The maximum number of replacements within each string, or -1 for no limit.
+   * @return A new column vector containing the string results.
+   */
+  public final ColumnVector replaceRegex(String pattern, Scalar repl, int maxRepl) {
+    if (!repl.getType().equals(DType.STRING)) {
+      throw new IllegalArgumentException("Replacement must be a string scalar");
+    }
+    return new ColumnVector(replaceRegex(getNativeView(), pattern, repl.getScalarHandle(),
+        maxRepl));
+  }
+
+  /**
+   * For each string, replaces any character sequence matching any of the regular expression
+   * patterns with the corresponding replacement strings.
+   *
+   * @param patterns The regular expression patterns to search within each string.
+   * @param repls The column of replacement strings, one per corresponding pattern.
+   * @return A new column vector containing the string results.
+   */
+  public final ColumnVector replaceMultiRegex(String[] patterns, ColumnView repls) {
+    return new ColumnVector(replaceMultiRegex(getNativeView(), patterns,
+        repls.getNativeView()));
+  }
+
   /**
    * For each string, replaces any character sequence matching the given pattern
    * using the replace template for back-references.
@@ -3117,6 +3159,28 @@ private static native long substringColumn(long columnView, long startColumn, lo
    */
   private static native long stringReplace(long columnView, long target, long repl) throws CudfException;
 
+  /**
+   * Native method for replacing each regular expression pattern match with the specified
+   * replacement string.
+   * @param columnView native handle of the cudf::column_view being operated on.
+   * @param pattern The regular expression pattern to search within each string.
+   * @param repl native handle of the cudf::scalar containing the replacement string.
+   * @param maxRepl maximum number of times to replace the pattern within a string
+   * @return native handle of the resulting cudf column containing the string results.
+   */
+  private static native long replaceRegex(long columnView, String pattern,
+                                          long repl, long maxRepl) throws CudfException;
+
+  /**
+   * Native method for multiple instance regular expression replacement.
+   * @param columnView native handle of the cudf::column_view being operated on.
+   * @param patterns array of regular expression patterns to search within each string.
+   * @param repls native handle of the cudf::column_view containing the replacement strings.
+   * @return native handle of the resulting cudf column containing the string results.
+   */
+  private static native long replaceMultiRegex(long columnView, String[] patterns,
+                                               long repls) throws CudfException;
+
   /**
    * Native method for replacing any character sequence matching the given pattern
    * using the replace template for back-references.
diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp
index 83ba4d56d68..82e07577915 100644
--- a/java/src/main/native/src/ColumnViewJni.cpp
+++ b/java/src/main/native/src/ColumnViewJni.cpp
@@ -1213,6 +1213,51 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env,
   CATCH_STD(env, 0);
 }
 
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(JNIEnv *env, jclass,
+                                                                    jlong j_column_view,
+                                                                    jstring j_pattern, jlong j_repl,
+                                                                    jlong j_maxrepl) {
+
+  JNI_NULL_CHECK(env, j_column_view, "column is null", 0);
+  JNI_NULL_CHECK(env, j_pattern, "pattern string is null", 0);
+  JNI_NULL_CHECK(env, j_repl, "replace scalar is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto cv = reinterpret_cast<cudf::column_view const *>(j_column_view);
+    cudf::strings_column_view scv(*cv);
+    cudf::jni::native_jstring pattern(env, j_pattern);
+    auto repl = reinterpret_cast<cudf::string_scalar const *>(j_repl);
+
+    std::unique_ptr<cudf::column> result =
+        cudf::strings::replace_re(scv, pattern.get(), *repl, j_maxrepl);
+    return reinterpret_cast<jlong>(result.release());
+  }
+  CATCH_STD(env, 0);
+}
+
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceMultiRegex(JNIEnv *env, jclass,
+                                                                         jlong j_column_view,
+                                                                         jobjectArray j_patterns,
+                                                                         jlong j_repls) {
+
+  JNI_NULL_CHECK(env, j_column_view, "column is null", 0);
+  JNI_NULL_CHECK(env, j_patterns, "patterns is null", 0);
+  JNI_NULL_CHECK(env, j_repls, "repls is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto cv = reinterpret_cast<cudf::column_view const *>(j_column_view);
+    cudf::strings_column_view scv(*cv);
+    cudf::jni::native_jstringArray patterns(env, j_patterns);
+    auto repl_cv = reinterpret_cast<cudf::column_view const *>(j_repls);
+    cudf::strings_column_view repl_scv(*repl_cv);
+
+    std::unique_ptr<cudf::column> result =
+        cudf::strings::replace_re(scv, patterns.as_cpp_vector(), repl_scv);
+    return reinterpret_cast<jlong>(result.release());
+  }
+  CATCH_STD(env, 0);
+}
+
 JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs(
     JNIEnv *env, jclass, jlong column_view, jstring patternObj, jstring replaceObj) {
 
diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
index a63fd408dad..ff84a2a7c69 100644
--- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
+++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -4401,6 +4401,46 @@ void teststringReplaceThrowsException() {
     });
   }
 
+  @Test
+  void testReplaceRegex() {
+    try (ColumnVector v =
+             ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
+         Scalar repl = Scalar.fromString("Repl");
+         ColumnVector actual = v.replaceRegex("[tT]itle", repl);
+         ColumnVector expected =
+             ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
+      assertColumnsAreEqual(expected, actual);
+    }
+
+    try (ColumnVector v =
+             ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
+         Scalar repl = Scalar.fromString("Repl");
+         ColumnVector actual = v.replaceRegex("[tT]itle", repl, 0)) {
+      assertColumnsAreEqual(v, actual);
+    }
+
+    try (ColumnVector v =
+             ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
+         Scalar repl = Scalar.fromString("Repl");
+         ColumnVector actual = v.replaceRegex("[tT]itle", repl, 1);
+         ColumnVector expected =
+             ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void testReplaceMultiRegex() {
+    try (ColumnVector v =
+             ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
+         ColumnVector repls = ColumnVector.fromStrings("Repl", "**");
+         ColumnVector actual = v.replaceMultiRegex(new String[] { "[tT]itle", "and|th" }, repls);
+         ColumnVector expected =
+             ColumnVector.fromStrings("Repl ** Repl wi** Repl", "no**ing", null, "Repl")) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
   @Test
   void testStringReplaceWithBackrefs() {