Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize string gather performance for large strings #7980

Merged
3 changes: 3 additions & 0 deletions cpp/benchmarks/string/copy_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ static void generate_bench_args(benchmark::internal::Benchmark* b)
int const max_rowlen = 1 << 13;
int const len_mult = 4;
generate_string_bench_args(b, min_rows, max_rows, row_mult, min_rowlen, max_rowlen, len_mult);

// Benchmark for very small strings
b->Args({67108864, 2});
}

#define COPY_BENCHMARK_DEFINE(name) \
Expand Down
161 changes: 148 additions & 13 deletions cpp/include/cudf/strings/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,139 @@ namespace cudf {
namespace strings {
namespace detail {

// Strategy 1: String-parallel
// This strategy assigns strings to warps so that each warp can cooperatively copy from the input
// location of the string to the corresponding output location. A wide datatype (uint4, 16 bytes)
// is used for the stores. This strategy is best suited for large strings.

// Helper function for loading 16B from a potentially unaligned memory location to registers.
//
// The 16 bytes starting at `ptr` are assembled from 32-bit loads issued at the 4-byte boundary at
// or immediately before `ptr`, so every individual load is naturally aligned. Consequences for the
// caller:
//  - up to 3 bytes *before* `ptr` are read when `ptr` is not 4-byte aligned;
//  - a fifth word is read only when `ptr` is not 4-byte aligned, in which case up to 3 bytes past
//    `ptr + 16` are touched. When `ptr` is already aligned exactly 16 bytes are read.
// All of those bytes must be readable memory.
__forceinline__ __device__ uint4 load_uint4(const char* ptr)
{
  // Distance in bytes (0-3) from the preceding 4-byte boundary to `ptr`.
  auto const misalignment  = reinterpret_cast<size_t>(ptr) & 3;
  auto const* aligned_ptr  = reinterpret_cast<unsigned int const*>(ptr - misalignment);
  unsigned int const shift = static_cast<unsigned int>(misalignment) * 8;

  uint4 regs = {aligned_ptr[0], aligned_ptr[1], aligned_ptr[2], aligned_ptr[3]};

  // The extra word only contributes bits when a shift is applied. Skipping the load for aligned
  // pointers avoids reading 4 bytes beyond the requested 16, which could otherwise fall outside
  // the source allocation.
  unsigned int const tail = (shift != 0) ? aligned_ptr[4] : 0u;

  // Each output word is the 64-bit concatenation of two neighbouring words shifted right by
  // `shift` bits, i.e. the 4 bytes that start `misalignment` bytes into the lower word.
  regs.x = __funnelshift_r(regs.x, regs.y, shift);
  regs.y = __funnelshift_r(regs.y, regs.z, shift);
  regs.z = __funnelshift_r(regs.z, regs.w, shift);
  regs.w = __funnelshift_r(regs.w, tail, shift);

  return regs;
}

/**
 * @brief Gather kernel that copies one output string per warp (string-parallel strategy).
 *
 * Warps stride over the output strings. For each string the bulk of the bytes is written with
 * 16-byte (uint4) stores to 16-byte aligned output addresses, fed by `load_uint4`; the bytes
 * outside that aligned range are copied one byte per lane.
 *
 * Expects a 1-D launch whose total thread count is a multiple of 32. No shared memory is used.
 *
 * @tparam StringIterator Iterator whose elements expose `.data()` returning a `const char*`.
 * @tparam MapIterator Iterator yielding, for each output string, the index into `strings_begin`.
 *
 * @param strings_begin Start of the input strings.
 * @param out_chars Output character buffer; byte `out_offsets[i]` is the start of output string i.
 * @param out_offsets Offsets of the output strings; must hold `total_out_strings + 1` entries.
 * @param string_indices Gather map from output string index to input string index.
 * @param total_out_strings Number of output strings.
 */
template <typename StringIterator, typename MapIterator>
__global__ void gather_chars_fn_string_parallel(StringIterator strings_begin,
                                                char* out_chars,
                                                cudf::device_span<int32_t const> const out_offsets,
                                                MapIterator string_indices,
                                                size_type total_out_strings)
{
  constexpr int lanes_per_warp        = 32;
  constexpr int64_t out_datatype_size = sizeof(uint4);         // bytes per vector store
  constexpr int64_t in_datatype_size  = sizeof(unsigned int);  // bytes per load in load_uint4

  int const global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  int const global_warp_id   = global_thread_id / lanes_per_warp;
  int const warp_lane        = global_thread_id % lanes_per_warp;
  int const nwarps           = gridDim.x * blockDim.x / lanes_per_warp;

  // `out_chars` itself need not be 16B aligned; index everything relative to the 16B boundary at
  // or before it so that vector stores always hit aligned addresses.
  auto const alignment_offset =
    static_cast<int64_t>(reinterpret_cast<size_t>(out_chars) % sizeof(uint4));
  uint4* out_chars_aligned = reinterpret_cast<uint4*>(out_chars - alignment_offset);

  for (size_type istring = global_warp_id; istring < total_out_strings; istring += nwarps) {
    // 64-bit copies of the offsets so that adding lane/stride values below cannot overflow.
    int64_t const out_start = out_offsets[istring];
    int64_t const out_end   = out_offsets[istring + 1];

    // This check is necessary because string_indices[istring] may be out of bound.
    if (out_start == out_end) continue;

    const char* in_start = strings_begin[string_indices[istring]].data();

    // Both values are indices into `out_chars`.
    // `out_start_aligned` is the first 16B-aligned location at or after `out_start + 4`;
    // `out_end_aligned` is the last 16B-aligned location at or before `out_end - 4`.
    // Bytes in `[out_start_aligned, out_end_aligned)` are copied with uint4 stores. The 4-byte
    // margins keep `load_uint4`, which rounds its address down to a 4-byte boundary and may read a
    // fifth word, from touching memory before the first or after the last byte of the source
    // string.
    int64_t const out_start_aligned =
      (out_start + in_datatype_size + alignment_offset + out_datatype_size - 1) /
        out_datatype_size * out_datatype_size -
      alignment_offset;
    // For strings ending within the first few bytes the numerator can be slightly negative;
    // truncating division then yields 0, which still gives an empty vector range below.
    int64_t const out_end_aligned =
      (out_end - in_datatype_size + alignment_offset) / out_datatype_size * out_datatype_size -
      alignment_offset;

    for (int64_t ichar = out_start_aligned + warp_lane * out_datatype_size;
         ichar < out_end_aligned;
         ichar += lanes_per_warp * out_datatype_size) {
      *(out_chars_aligned + (ichar + alignment_offset) / out_datatype_size) =
        load_uint4(in_start + (ichar - out_start));
    }

    // Copy the bytes of the current string outside `[out_start_aligned, out_end_aligned)`.
    if (out_end_aligned <= out_start_aligned) {
      // The vector range is empty: copy the whole string bytewise. With the safety margins the
      // string may be longer than one warp, hence the loop.
      for (int64_t ichar = out_start + warp_lane; ichar < out_end; ichar += lanes_per_warp) {
        out_chars[ichar] = in_start[ichar - out_start];
      }
    } else {
      // Head `[out_start, out_start_aligned)`: at most 4 + 15 bytes, one pass suffices.
      if (out_start + warp_lane < out_start_aligned) {
        out_chars[out_start + warp_lane] = in_start[warp_lane];
      }

      // Tail `[out_end_aligned, out_end)`: at most 4 + 15 bytes, one pass suffices.
      int64_t const ichar = out_end_aligned + warp_lane;
      if (ichar < out_end) { out_chars[ichar] = in_start[ichar - out_start]; }
    }
  }
}

// Strategy 2: Char-parallel
// This strategy assigns characters to threads, and uses binary search for getting the string
// index. To improve the binary search performance, a fixed number of strings per threadblock is
// used. This strategy is best suited for small strings.
constexpr static int strings_per_threadblock = 32;

// The unrolled binary search over a threadblock's offsets halves its step starting from
// `strings_per_threadblock / 2`, which only covers every index when the count is a power of 2.
static_assert(strings_per_threadblock > 0 &&
                (strings_per_threadblock & (strings_per_threadblock - 1)) == 0,
              "strings_per_threadblock must be a power of 2");

// Binary search `value` in `offsets` of length `nelements`.
//
// Returns the largest index `idx` in `[0, nelements)` with `offsets[idx] <= value`, or 0 when no
// element satisfies that (including `nelements <= 1`). `offsets` must be sorted in non-decreasing
// order. Requires `nelements` to be less than or equal to `strings_per_threadblock`, and
// `strings_per_threadblock` to be a power of 2, so that the halving steps below can reach every
// index.
__forceinline__ __device__ size_type binary_search(int32_t const* offsets,
                                                   int32_t value,
                                                   size_type nelements)
{
  size_type idx = 0;
  // Fixed trip count (log2 of strings_per_threadblock) so the loop can be fully unrolled.
#pragma unroll
  for (size_type step = strings_per_threadblock / 2; step > 0; step /= 2) {
    // Advance only while staying in range and not overshooting `value`.
    if (idx + step < nelements && offsets[idx + step] <= value) idx += step;
  }
  return idx;
}

/**
 * @brief Gather kernel that copies one output byte per thread (char-parallel strategy).
 *
 * Each threadblock handles `strings_per_threadblock` consecutive output strings: it stages their
 * offsets in shared memory, then its threads stride over the output bytes of those strings and
 * use a binary search over the staged offsets to find which string each byte belongs to.
 *
 * Expects a 1-D launch with one block per `strings_per_threadblock` output strings. Uses
 * `(strings_per_threadblock + 1) * sizeof(int32_t)` bytes of static shared memory.
 *
 * @tparam StringIterator Iterator whose elements expose `.data()` returning a `const char*`.
 * @tparam MapIterator Iterator yielding, for each output string, the index into `strings_begin`.
 *
 * @param strings_begin Start of the input strings.
 * @param out_chars Output character buffer; byte `out_offsets[i]` is the start of output string i.
 * @param out_offsets Offsets of the output strings; must hold `total_out_strings + 1` entries.
 * @param string_indices Gather map from output string index to input string index.
 * @param total_out_strings Number of output strings.
 */
template <typename StringIterator, typename MapIterator>
__global__ void gather_chars_fn_char_parallel(StringIterator strings_begin,
                                              char* out_chars,
                                              cudf::device_span<int32_t const> const out_offsets,
                                              MapIterator string_indices,
                                              size_type total_out_strings)
{
  __shared__ int32_t out_offsets_threadblock[strings_per_threadblock + 1];

  // Current thread block will process output strings starting at `begin_out_string_idx`.
  size_type const begin_out_string_idx = blockIdx.x * strings_per_threadblock;

  // Number of strings to be processed by the current threadblock.
  size_type const strings_current_threadblock =
    min(strings_per_threadblock, total_out_strings - begin_out_string_idx);

  // Uniform across the block, so returning before the barrier below is safe.
  if (strings_current_threadblock <= 0) return;

  // Collectively load offsets of strings processed by the current thread block. One extra entry
  // is loaded so the end of the last string is available.
  for (size_type idx = threadIdx.x; idx <= strings_current_threadblock; idx += blockDim.x) {
    out_offsets_threadblock[idx] = out_offsets[idx + begin_out_string_idx];
  }
  __syncthreads();

  // The byte counter is 64-bit: with a 32-bit counter, `out_ibyte += blockDim.x` wraps to a
  // negative value when the block's last offset is within `blockDim.x` of INT32_MAX, which keeps
  // the loop running and writes `out_chars` at a negative index.
  int64_t const block_begin_byte = out_offsets_threadblock[0];
  int64_t const block_end_byte   = out_offsets_threadblock[strings_current_threadblock];

  for (int64_t out_ibyte = block_begin_byte + threadIdx.x; out_ibyte < block_end_byte;
       out_ibyte += blockDim.x) {
    // Inside the loop `out_ibyte` is below a 32-bit offset, so the narrowing is lossless.
    auto const out_ibyte32 = static_cast<int32_t>(out_ibyte);

    // binary search for the string index corresponding to out_ibyte
    size_type const string_idx =
      binary_search(out_offsets_threadblock, out_ibyte32, strings_current_threadblock);

    // calculate which character to load within the string
    int32_t const icharacter = out_ibyte32 - out_offsets_threadblock[string_idx];

    size_type const in_string_idx = string_indices[begin_out_string_idx + string_idx];
    out_chars[out_ibyte]          = strings_begin[in_string_idx].data()[icharacter];
  }
}

/**
* @brief Returns a new chars column using the specified indices to select
* strings from the input iterator.
Expand Down Expand Up @@ -68,20 +201,22 @@ std::unique_ptr<cudf::column> gather_chars(StringIterator strings_begin,
auto chars_column = create_chars_child_column(output_count, chars_bytes, stream, mr);
auto const d_chars = chars_column->mutable_view().template data<char>();

auto gather_chars_fn = [strings_begin, map_begin, offsets] __device__(size_type out_idx) -> char {
auto const out_row =
thrust::prev(thrust::upper_bound(thrust::seq, offsets.begin(), offsets.end(), out_idx));
auto const row_idx = map_begin[thrust::distance(offsets.begin(), out_row)]; // get row index
auto const d_str = strings_begin[row_idx]; // get row's string
auto const offset = out_idx - *out_row; // get string's char
return d_str.data()[offset];
};
size_type average_string_length = chars_bytes / output_count;

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(chars_bytes),
d_chars,
gather_chars_fn);
if (average_string_length > 32) {
gather_chars_fn_string_parallel<<<min((static_cast<int>(output_count) + 3) / 4, 65536),
gaohao95 marked this conversation as resolved.
Show resolved Hide resolved
128,
0,
stream.value()>>>(
strings_begin, d_chars, offsets, map_begin, output_count);
} else {
gather_chars_fn_char_parallel<<<(output_count + strings_per_threadblock - 1) /
strings_per_threadblock,
128,
0,
stream.value()>>>(
strings_begin, d_chars, offsets, map_begin, output_count);
}

return chars_column;
}
Expand Down