Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[auto-merge] bot-auto-merge-branch-24.02 to branch-24.04 [skip ci] [bot] #1731

Merged
merged 2 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/main/cpp/src/ParseURIJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,19 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseQueryWith
}
CATCH_STD(env, 0);
}

/**
 * JNI bridge for ParseURI.parseQueryWithColumn.
 *
 * Both jlong arguments are reinterpreted as pointers to cudf::column_view and forwarded to
 * spark_rapids_jni::parse_uri_to_query; ownership of the resulting column is released and
 * handed back to the caller as a jlong.
 *
 * NOTE(review): 0 is the value passed to JNI_NULL_CHECK / CATCH_STD for the failure path;
 * presumably that is what the function returns after a Java exception is raised -- confirm
 * against the macro definitions.
 */
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseQueryWithColumn(
  JNIEnv* env, jclass, jlong input_column, jlong query_column)
{
  JNI_NULL_CHECK(env, input_column, "input column is null", 0);
  JNI_NULL_CHECK(env, query_column, "query column is null", 0);

  try {
    cudf::jni::auto_set_device(env);

    // The Java side passes native view handles; recover the typed views from them.
    auto const* uri_view = reinterpret_cast<cudf::column_view const*>(input_column);
    auto const* key_view = reinterpret_cast<cudf::column_view const*>(query_column);

    auto parsed = spark_rapids_jni::parse_uri_to_query(*uri_view, *key_view);
    return cudf::jni::ptr_as_jlong(parsed.release());
  }
  CATCH_STD(env, 0);
}
}
17 changes: 16 additions & 1 deletion src/main/cpp/src/parse_uri.cu
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,11 @@ uri_parts __device__ validate_uri(const char* str,
// passed as param0, the return would simply be 5.
if (query_match && query_match->size() > 0) {
auto const match_idx = row_idx % query_match->size();
auto in_match = query_match->element<string_view>(match_idx);
if (query_match->is_null(match_idx)) {
ret.valid = 0;
return ret;
}
auto in_match = query_match->element<string_view>(match_idx);

auto const [query, valid] = find_query_part(ret.query, in_match);
if (!valid) {
Expand Down Expand Up @@ -993,4 +997,15 @@ std::unique_ptr<cudf::column> parse_uri_to_query(cudf::strings_column_view const
return detail::parse_uri(input, detail::URI_chunks::QUERY, strings_column_view(*col), stream, mr);
}

// Column-keyed overload: row i of `query_match` supplies the query parameter name to look up
// in row i of `input`. The two columns must have the same number of rows.
std::unique_ptr<cudf::column> parse_uri_to_query(cudf::strings_column_view const& input,
                                                 cudf::strings_column_view const& query_match,
                                                 rmm::cuda_stream_view stream,
                                                 rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();

  // Reject mismatched inputs up front, before any parsing work is launched.
  bool const sizes_agree = (input.size() == query_match.size());
  CUDF_EXPECTS(sizes_agree, "Query column must be the same size as input!");

  auto const chunk = detail::URI_chunks::QUERY;
  return detail::parse_uri(input, chunk, query_match, stream, mr);
}

} // namespace spark_rapids_jni
15 changes: 15 additions & 0 deletions src/main/cpp/src/parse_uri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,19 @@ std::unique_ptr<cudf::column> parse_uri_to_query(
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
 * @brief Parse the query of each URI and extract the value of a per-row query parameter.
 *
 * Row `i` of `query_match` names the parameter to look up in the query of row `i` of `input`.
 * `query_match` must have the same number of rows as `input`; the implementation checks this
 * with CUDF_EXPECTS. A null row in `query_match` marks the corresponding result as not valid.
 *
 * @param input Input string column of URIs to parse.
 * @param query_match String column of parameter names to match in each query, one per row of
 *                    `input`.
 * @param stream Stream on which to operate.
 * @param mr Memory resource for returned column.
 * @return std::unique_ptr<column> String column of queries parsed.
 */
std::unique_ptr<cudf::column> parse_uri_to_query(
cudf::strings_column_view const& input,
cudf::strings_column_view const& query_match,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace spark_rapids_jni
11 changes: 11 additions & 0 deletions src/main/cpp/tests/parse_uri.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,4 +395,15 @@ TEST_F(ParseURIQueryTests, Queries)

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view());
}
{
cudf::test::strings_column_wrapper const query(
{"param0", "q", "a", "invalid", "test", "query", "fakeparam0", "C"});
cudf::test::strings_column_wrapper const expected({"1", "", "b", "param", "", "1", "5", "C"},
{1, 0, 1, 1, 0, 1, 1, 1});

auto const result = spark_rapids_jni::parse_uri_to_query(cudf::strings_column_view{col},
cudf::strings_column_view{query});

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view());
}
}
22 changes: 18 additions & 4 deletions src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,22 @@ public static ColumnVector parseURIQueryWithLiteral(ColumnView uriColumn, String
return new ColumnVector(parseQueryWithLiteral(uriColumn.getNativeView(), query));
}

private static native long parseProtocol(long jsonColumnHandle);
private static native long parseHost(long jsonColumnHandle);
private static native long parseQuery(long jsonColumnHandle);
private static native long parseQueryWithLiteral(long jsonColumnHandle, String query);
/**
 * Parse the query of each URI in the incoming column and return the value of the parameter
 * named by the corresponding row of the query column.
 *
 * Both columns are expected to be of type STRING. The type checks are Java {@code assert}
 * statements, so they are only enforced when assertions are enabled.
 *
 * @param uriColumn The input strings column in which each row contains a URI.
 * @param queryColumn The strings column in which each row holds the parameter to extract
 *                    from the query of the matching URI row.
 * @return A string column with query data extracted.
 */
public static ColumnVector parseURIQueryWithColumn(ColumnView uriColumn, ColumnView queryColumn) {
assert uriColumn.getType().equals(DType.STRING) : "Input type must be String";
assert queryColumn.getType().equals(DType.STRING) : "Query type must be String";
// The native call returns a handle to the result column; wrap it so the caller owns it.
return new ColumnVector(parseQueryWithColumn(uriColumn.getNativeView(), queryColumn.getNativeView()));
}

// Native (JNI) entry points. The visible callers pass handles obtained from
// ColumnView.getNativeView() and wrap the returned long in a new ColumnVector.
// NOTE(review): presumably every method here follows that same handle-in / handle-out
// convention -- confirm against the native implementations.
private static native long parseProtocol(long inputColumnHandle);
private static native long parseHost(long inputColumnHandle);
private static native long parseQuery(long inputColumnHandle);
private static native long parseQueryWithLiteral(long inputColumnHandle, String query);
private static native long parseQueryWithColumn(long inputColumnHandle, long queryColumnHandle);
}
92 changes: 92 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,41 @@ void testQuery(String[] testData, String param) {
}
}

/**
 * Checks ParseURI.parseURIQueryWithColumn against a reference built with java.net.URI.
 *
 * For row i the expected value is the text after the first '=' of the first "&"-separated
 * pair of the raw query whose key equals params[i]; it is null when the URI is null or
 * malformed, has no query, or contains no such key.
 *
 * @param testData URIs to parse, one per row (entries may be null).
 * @param params   parameter name to extract for the URI at the same index.
 */
void testQuery(String[] testData, String[] params) {
  final int rowCount = testData.length;
  String[] expectedQueryStrings = new String[rowCount];

  for (int row = 0; row < rowCount; row++) {
    String rawQuery = null;
    try {
      rawQuery = new URI(testData[row]).getRawQuery();
    } catch (URISyntaxException | NullPointerException ex) {
      // A malformed URI (URISyntaxException) or a null URI (NullPointerException)
      // leaves the query null.
    }

    String matchedValue = null;
    if (rawQuery != null) {
      for (String pair : rawQuery.split("&")) {
        final int eq = pair.indexOf("=");
        if (eq <= 0) {
          // No '=' at all, or an empty key: this pair can never match.
          continue;
        }
        if (pair.substring(0, eq).equals(params[row])) {
          matchedValue = pair.substring(eq + 1);
          break;
        }
      }
    }
    expectedQueryStrings[row] = matchedValue;
  }

  try (ColumnVector v0 = ColumnVector.fromStrings(testData);
       ColumnVector p0 = ColumnVector.fromStrings(params);
       ColumnVector expectedQuery = ColumnVector.fromStrings(expectedQueryStrings);
       ColumnVector queryResult = ParseURI.parseURIQueryWithColumn(v0, p0)) {
    AssertUtils.assertColumnsAreEqual(expectedQuery, queryResult);
  }
}

@Test
void parseURISparkTest() {
String[] testData = {
Expand Down Expand Up @@ -180,11 +215,68 @@ void parseURISparkTest() {
"userinfo@www.nvidia.com/path?query=1#Ref",
"",
null};


String[] queries = {
"a",
"h",
// commented out until https://github.com/NVIDIA/spark-rapids/issues/10036 is fixed
//"object",
"object",
"a",
"h",
"a",
"f",
"g",
"a",
"a",
"f",
"g",
"a",
"a",
"b",
"a",
"",
"a",
"a",
"a",
"a",
"b",
"a",
"q",
"b",
"a",
"query",
"a",
"primekey_in",
"a",
"q",
"ExpertId",
"query",
"solutionId",
"f",
"param",
"",
"q",
"a",
"f",
"mnid=5080",
"f",
"a",
"param4",
"cloth",
"a",
"invalid",
"invalid",
"query",
"a",
"f"};

testProtocol(testData);
testHost(testData);
testQuery(testData);
testQuery(testData, "query");
testQuery(testData, queries);
}

@Test
Expand Down