Skip to content

Commit

Permalink
JNI: Add generateListOffsets API (#10683)
Browse files Browse the repository at this point in the history
Add a generateListOffsets API that converts a column of list lengths into a column of list offsets; this is useful in the development of spark-rapids.

For example, support for [array_repeat](NVIDIA/spark-rapids#5226) and [arrays_zip](NVIDIA/spark-rapids#5229) relies on this API.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Liangcai Li (https://github.com/firestarman)

URL: #10683
  • Loading branch information
sperlingxx authored Apr 26, 2022
1 parent 57b9d0b commit cc0bf12
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 2 deletions.
12 changes: 12 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3473,6 +3473,16 @@ public final ColumnVector listSortRows(boolean isDescending, boolean isNullSmall
return new ColumnVector(listSortRows(getNativeView(), isDescending, isNullSmallest));
}

/**
 * Converts this column of list sizes into the corresponding column of list offsets.
 * <p>
 * NOTICE: The column must be of type INT32 and must not contain nulls or negative values;
 * the result for input that violates these requirements is undefined.
 *
 * @return a column of list offsets with N + 1 rows, where N is the number of rows of this column
 */
public final ColumnVector generateListOffsets() {
  // Run the native computation first, then wrap the returned native handle.
  final long offsetsHandle = generateListOffsets(getNativeView());
  return new ColumnVector(offsetsHandle);
}

/**
* Get a single item from the column at the specified index as a Scalar.
*
Expand Down Expand Up @@ -4162,6 +4172,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS

// NOTE(review): presumably copies the native view behind viewHandle into a new native column and
// returns that column's handle -- the native implementation is not shown here; confirm.
static native long copyColumnViewToCV(long viewHandle) throws CudfException;

/**
 * Native entry point backing {@link #generateListOffsets()}.
 *
 * @param handle address of the native column view holding the list sizes; the native side
 *               rejects a null (zero) handle
 * @return address of the native column of list offsets produced by the native side
 * @throws CudfException if the native call fails
 */
static native long generateListOffsets(long handle) throws CudfException;

/**
* A utility class to create column vector like objects without refcounts and other APIs when
* creating the device side vector from host side nested vectors. Eventually this can go away or
Expand Down
11 changes: 11 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listSortRows(JNIEnv *env,
CATCH_STD(env, 0);
}

/**
 * JNI binding for ai.rapids.cudf.ColumnView.generateListOffsets(long).
 *
 * Interprets `handle` as a pointer to a `cudf::column_view` of list sizes, forwards it to
 * `cudf::jni::generate_list_offsets`, and hands the resulting column back to Java as a jlong.
 *
 * @param env    JNI environment, passed to the null check, device setup and error handling macros.
 * @param handle Address of the native column view holding the list sizes.
 * @return The released offsets column as a jlong; 0 is the value supplied to JNI_NULL_CHECK and
 *         CATCH_STD for their failure paths.
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_generateListOffsets(JNIEnv *env, jclass,
                                                                           jlong handle) {
  JNI_NULL_CHECK(env, handle, "handle is null", 0)
  try {
    cudf::jni::auto_set_device(env);
    auto const &list_sizes = *reinterpret_cast<cudf::column_view const *>(handle);
    return release_as_jlong(cudf::jni::generate_list_offsets(list_sizes));
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass,
jlong input_handle,
jstring pattern_obj,
Expand Down
23 changes: 22 additions & 1 deletion java/src/main/native/src/ColumnViewJni.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,11 @@
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/valid_if.cuh>
#include <rmm/exec_policy.hpp>
#include <thrust/scan.h>

#include "ColumnViewJni.hpp"

Expand Down Expand Up @@ -51,4 +54,22 @@ new_column_with_boolean_column_as_validity(cudf::column_view const &exemplar,
return deep_copy;
}

/**
 * Builds a list offsets column from a column of per-list lengths.
 *
 * The result has `list_length.size() + 1` INT32 rows and no validity mask allocated: row 0 is
 * set to zero and rows 1..N receive the inclusive prefix sum of the lengths.
 *
 * @param list_length INT32 column of list lengths; any other type fails the CUDF_EXPECTS check.
 * @param stream CUDA stream on which the allocation, the scan and the memset are issued.
 * @return The newly allocated offsets column.
 */
std::unique_ptr<cudf::column> generate_list_offsets(cudf::column_view const &list_length,
                                                    rmm::cuda_stream_view stream) {
  CUDF_EXPECTS(list_length.type().id() == cudf::type_id::INT32,
               "Input column does not have type INT32.");

  // One more offset than there are lists.
  auto const num_offsets = list_length.size() + 1;
  auto result = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, num_offsets,
                                          cudf::mask_state::UNALLOCATED, stream);
  auto result_view = result->mutable_view();
  auto const offsets_begin = result_view.begin<cudf::size_type>();

  // The running totals land one slot to the right, leaving slot 0 for the leading zero offset.
  thrust::inclusive_scan(rmm::exec_policy(stream), list_length.begin<cudf::size_type>(),
                         list_length.end<cudf::size_type>(), offsets_begin + 1);
  // The scan never touches slot 0, so it is zeroed separately on the same stream.
  CUDF_CUDA_TRY(cudaMemsetAsync(offsets_begin, 0, sizeof(cudf::size_type), stream));

  return result;
}
} // namespace cudf::jni
18 changes: 17 additions & 1 deletion java/src/main/native/src/ColumnViewJni.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*/

#include <cudf/column/column.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace cudf::jni {

Expand All @@ -35,4 +36,19 @@ std::unique_ptr<cudf::column>
new_column_with_boolean_column_as_validity(cudf::column_view const &exemplar,
cudf::column_view const &bool_column);

/**
 * @brief Generates a list offsets column from the lengths of each list.
 *
 * For example,
 *   Given a list column: [[1,2,3], [4,5], [6], [], [7,8]]
 *   The list lengths of it: [3, 2, 1, 0, 2]
 *   The list offsets of it: [0, 3, 5, 6, 6, 8]
 *
 * The input must be of type INT32; the implementation rejects any other type through a
 * CUDF_EXPECTS check.
 *
 * @param list_length The column of list lengths.
 * @param stream CUDA stream on which the offsets column is allocated and computed.
 * @return The column of list offsets, with one more row than `list_length`.
 */
std::unique_ptr<cudf::column>
generate_list_offsets(cudf::column_view const &list_length,
rmm::cuda_stream_view stream = rmm::cuda_stream_default);

} // namespace cudf::jni
15 changes: 15 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6284,4 +6284,19 @@ void testSegmentedGather() {
assertColumnsAreEqual(expected, actual);
}
}

@Test
void testGenerateListOffsets() {
  // Mixed sizes, with empty lists scattered through the column.
  try (ColumnVector sizes = ColumnVector.fromInts(1, 3, 3, 0, 2, 0, 0, 5, 10, 25);
       ColumnVector offsets = sizes.generateListOffsets();
       ColumnVector expectedOffsets = ColumnVector.fromInts(0, 1, 4, 7, 7, 9, 9, 9, 14, 24, 49)) {
    assertColumnsAreEqual(expectedOffsets, offsets);
  }

  // Empty lists at both ends of the column, with a single non-empty list in between.
  try (ColumnVector sizes = ColumnVector.fromInts(0, 0, 1, 0, 0);
       ColumnVector offsets = sizes.generateListOffsets();
       ColumnVector expectedOffsets = ColumnVector.fromInts(0, 0, 0, 1, 1, 1)) {
    assertColumnsAreEqual(expectedOffsets, offsets);
  }
}
}

0 comments on commit cc0bf12

Please sign in to comment.