Skip to content

Commit

Permalink
Fix memcheck error found in STRINGS_TEST (#13578)
Browse files Browse the repository at this point in the history
Fixes a memcheck error found in `STRINGS_TEST` where an `atomicOr` was used on a boolean device scalar. The workaround uses a `cub::WarpReduce` to compute the result in the warp-per-string kernel.

Reference #13574

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Nghia Truong (https://github.com/ttnghia)

URL: #13578
  • Loading branch information
davidwendt authored Jun 22, 2023
1 parent 708ee59 commit 7cbef2a
Showing 1 changed file with 35 additions and 26 deletions.
61 changes: 35 additions & 26 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -274,30 +274,39 @@ namespace {
/**
* @brief Check if `d_target` appears in a row in `d_strings`.
*
* This executes as a warp per string/row.
* This executes as a warp per string/row and performs well for longer strings.
* @see AVG_CHAR_BYTES_THRESHOLD
*
* @param d_strings Column of input strings
* @param d_target String to search for in each row of `d_strings`
* @param d_results Indicates which rows contain `d_target`
*/
struct contains_warp_fn {
column_device_view const d_strings;
string_view const d_target;
bool* d_results;
__global__ void contains_warp_parallel_fn(column_device_view const d_strings,
string_view const d_target,
bool* d_results)
{
size_type const idx = static_cast<size_type>(threadIdx.x + blockIdx.x * blockDim.x);
using warp_reduce = cub::WarpReduce<bool>;
__shared__ typename warp_reduce::TempStorage temp_storage;

__device__ void operator()(std::size_t idx)
{
auto const str_idx = static_cast<size_type>(idx / cudf::detail::warp_size);
if (d_strings.is_null(str_idx)) { return; }
// get the string for this warp
auto const d_str = d_strings.element<string_view>(str_idx);
// each thread of the warp will check just part of the string
auto found = false;
for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size);
!found && (i + d_target.size_bytes()) < d_str.size_bytes();
i += cudf::detail::warp_size) {
// check the target matches this part of the d_str data
if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
}
if (found) { atomicOr(d_results + str_idx, true); }
if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }

auto const str_idx = idx / cudf::detail::warp_size;
auto const lane_idx = idx % cudf::detail::warp_size;
if (d_strings.is_null(str_idx)) { return; }
// get the string for this warp
auto const d_str = d_strings.element<string_view>(str_idx);
// each thread of the warp will check just part of the string
auto found = false;
for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size);
!found && (i + d_target.size_bytes()) < d_str.size_bytes();
i += cudf::detail::warp_size) {
// check the target matches this part of the d_str data
if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
}
};
auto const result = warp_reduce(temp_storage).Reduce(found, cub::Max());
if (lane_idx == 0) { d_results[str_idx] = result; }
}

std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
string_scalar const& target,
Expand All @@ -324,11 +333,11 @@ std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,

if (!d_target.empty()) {
// launch warp per string
auto d_strings = column_device_view::create(input.parent(), stream);
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<std::size_t>(0),
static_cast<std::size_t>(input.size()) * cudf::detail::warp_size,
contains_warp_fn{*d_strings, d_target, results_view.data<bool>()});
auto const d_strings = column_device_view::create(input.parent(), stream);
constexpr int block_size = 256;
cudf::detail::grid_1d grid{input.size() * cudf::detail::warp_size, block_size};
contains_warp_parallel_fn<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
*d_strings, d_target, results_view.data<bool>());
}
results->set_null_count(input.null_count());
return results;
Expand Down

0 comments on commit 7cbef2a

Please sign in to comment.