diff --git a/src/main/cpp/src/ParseURIJni.cpp b/src/main/cpp/src/ParseURIJni.cpp index c688d10736..354d47c424 100644 --- a/src/main/cpp/src/ParseURIJni.cpp +++ b/src/main/cpp/src/ParseURIJni.cpp @@ -77,4 +77,19 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseQueryWith } CATCH_STD(env, 0); } + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseQueryWithColumn( + JNIEnv* env, jclass, jlong input_column, jlong query_column) +{ + JNI_NULL_CHECK(env, input_column, "input column is null", 0); + JNI_NULL_CHECK(env, query_column, "query column is null", 0); + + try { + cudf::jni::auto_set_device(env); + auto const input = reinterpret_cast(input_column); + auto const query = reinterpret_cast(query_column); + return cudf::jni::ptr_as_jlong(spark_rapids_jni::parse_uri_to_query(*input, *query).release()); + } + CATCH_STD(env, 0); +} } diff --git a/src/main/cpp/src/parse_uri.cu b/src/main/cpp/src/parse_uri.cu index 4d21617fd7..cd64c539ef 100644 --- a/src/main/cpp/src/parse_uri.cu +++ b/src/main/cpp/src/parse_uri.cu @@ -629,7 +629,11 @@ uri_parts __device__ validate_uri(const char* str, // passed as param0, the return would simply be 5. 
if (query_match && query_match->size() > 0) { auto const match_idx = row_idx % query_match->size(); - auto in_match = query_match->element(match_idx); + if (query_match->is_null(match_idx)) { + ret.valid = 0; + return ret; + } + auto in_match = query_match->element(match_idx); auto const [query, valid] = find_query_part(ret.query, in_match); if (!valid) { @@ -993,4 +997,15 @@ std::unique_ptr parse_uri_to_query(cudf::strings_column_view const return detail::parse_uri(input, detail::URI_chunks::QUERY, strings_column_view(*col), stream, mr); } +std::unique_ptr parse_uri_to_query(cudf::strings_column_view const& input, + cudf::strings_column_view const& query_match, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + CUDF_EXPECTS(input.size() == query_match.size(), "Query column must be the same size as input!"); + + return detail::parse_uri(input, detail::URI_chunks::QUERY, query_match, stream, mr); +} + } // namespace spark_rapids_jni diff --git a/src/main/cpp/src/parse_uri.hpp b/src/main/cpp/src/parse_uri.hpp index bb001e3167..004d800ddb 100644 --- a/src/main/cpp/src/parse_uri.hpp +++ b/src/main/cpp/src/parse_uri.hpp @@ -80,4 +80,19 @@ std::unique_ptr parse_uri_to_query( rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +/** + * @brief Parse the query of each input URI and copy the value of the matching parameter to the output string column. + * + * @param input Input string column of URIs to parse. + * @param query_match String column of parameter names to match in the query, one per input row; must be the same size as input. + * @param stream Stream on which to operate. + * @param mr Memory resource for returned column. + * @return std::unique_ptr<cudf::column> String column of queries parsed. 
+ */ +std::unique_ptr parse_uri_to_query( + cudf::strings_column_view const& input, + cudf::strings_column_view const& query_match, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + } // namespace spark_rapids_jni diff --git a/src/main/cpp/tests/parse_uri.cpp b/src/main/cpp/tests/parse_uri.cpp index 234ad380c7..09f238e18c 100644 --- a/src/main/cpp/tests/parse_uri.cpp +++ b/src/main/cpp/tests/parse_uri.cpp @@ -395,4 +395,15 @@ TEST_F(ParseURIQueryTests, Queries) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } + { + cudf::test::strings_column_wrapper const query( + {"param0", "q", "a", "invalid", "test", "query", "fakeparam0", "C"}); + cudf::test::strings_column_wrapper const expected({"1", "", "b", "param", "", "1", "5", "C"}, + {1, 0, 1, 1, 0, 1, 1, 1}); + + auto const result = spark_rapids_jni::parse_uri_to_query(cudf::strings_column_view{col}, + cudf::strings_column_view{query}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java index e9908f9ea5..6de84ea519 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java @@ -72,8 +72,22 @@ public static ColumnVector parseURIQueryWithLiteral(ColumnView uriColumn, String return new ColumnVector(parseQueryWithLiteral(uriColumn.getNativeView(), query)); } - private static native long parseProtocol(long jsonColumnHandle); - private static native long parseHost(long jsonColumnHandle); - private static native long parseQuery(long jsonColumnHandle); - private static native long parseQueryWithLiteral(long jsonColumnHandle, String query); + /** + * Parse query and return a specific parameter for each URI from the incoming column. 
+ * + * @param uriColumn The input strings column in which each row contains a URI. + * @param queryColumn The strings column in which each row contains the parameter to extract from the query of the corresponding URI. + * @return A string column with query data extracted. + */ + public static ColumnVector parseURIQueryWithColumn(ColumnView uriColumn, ColumnView queryColumn) { + assert uriColumn.getType().equals(DType.STRING) : "Input type must be String"; + assert queryColumn.getType().equals(DType.STRING) : "Query type must be String"; + return new ColumnVector(parseQueryWithColumn(uriColumn.getNativeView(), queryColumn.getNativeView())); + } + + private static native long parseProtocol(long inputColumnHandle); + private static native long parseHost(long inputColumnHandle); + private static native long parseQuery(long inputColumnHandle); + private static native long parseQueryWithLiteral(long inputColumnHandle, String query); + private static native long parseQueryWithColumn(long inputColumnHandle, long queryColumnHandle); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java index c79633008c..f8ed45c704 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java @@ -124,6 +124,41 @@ void testQuery(String[] testData, String param) { } } + void testQuery(String[] testData, String[] params) { + String[] expectedQueryStrings = new String[testData.length]; + for (int i=0; i 0 && pair.substring(0, idx).equals(params[i])) { + subquery = pair.substring(idx + 1); + break; + } + } + } + expectedQueryStrings[i] = subquery; + } + try (ColumnVector v0 = ColumnVector.fromStrings(testData); + ColumnVector p0 = ColumnVector.fromStrings(params); + ColumnVector expectedQuery = ColumnVector.fromStrings(expectedQueryStrings); + ColumnVector queryResult = ParseURI.parseURIQueryWithColumn(v0, p0)) { + AssertUtils.assertColumnsAreEqual(expectedQuery, queryResult); + } + } + @Test void 
parseURISparkTest() { String[] testData = { @@ -180,11 +215,68 @@ void parseURISparkTest() { "userinfo@www.nvidia.com/path?query=1#Ref", "", null}; + + + String[] queries = { + "a", + "h", + // commented out until https://github.com/NVIDIA/spark-rapids/issues/10036 is fixed + //"object", + "object", + "a", + "h", + "a", + "f", + "g", + "a", + "a", + "f", + "g", + "a", + "a", + "b", + "a", + "", + "a", + "a", + "a", + "a", + "b", + "a", + "q", + "b", + "a", + "query", + "a", + "primekey_in", + "a", + "q", + "ExpertId", + "query", + "solutionId", + "f", + "param", + "", + "q", + "a", + "f", + "mnid=5080", + "f", + "a", + "param4", + "cloth", + "a", + "invalid", + "invalid", + "query", + "a", + "f"}; testProtocol(testData); testHost(testData); testQuery(testData); testQuery(testData, "query"); + testQuery(testData, queries); } @Test diff --git a/thirdparty/cudf b/thirdparty/cudf index f800f5a2fa..ef3ce4bc8d 160000 --- a/thirdparty/cudf +++ b/thirdparty/cudf @@ -1 +1 @@ -Subproject commit f800f5a2fa9a961699345e6febe740b4b8f4760e +Subproject commit ef3ce4bc8db008f58249241c16c80f7e6e600fa9