Fix memcheck error found in STRINGS_TEST #13578

Merged
Changes from 2 commits
61 changes: 35 additions & 26 deletions cpp/src/strings/search/find.cu
@@ -274,30 +274,39 @@ namespace {
/**
* @brief Check if `d_target` appears in a row in `d_strings`.
*
* This executes as a warp per string/row.
* This executes as a warp per string/row and performs well for longer strings.
* @see AVG_CHAR_BYTES_THRESHOLD
*
* @param d_strings Column of input strings
* @param d_target String to search for in each row of `d_strings`
* @param d_results Indicates which rows contain `d_target`
*/
struct contains_warp_fn {
column_device_view const d_strings;
string_view const d_target;
bool* d_results;
__global__ void contains_warp_parallel_fn(column_device_view const d_strings,
string_view const d_target,
bool* d_results)
{
size_type const idx = static_cast<size_type>(threadIdx.x + blockIdx.x * blockDim.x);
using warp_reduce = cub::WarpReduce<bool>;
__shared__ typename warp_reduce::TempStorage temp_storage;

__device__ void operator()(std::size_t idx)
{
auto const str_idx = static_cast<size_type>(idx / cudf::detail::warp_size);
if (d_strings.is_null(str_idx)) { return; }
// get the string for this warp
auto const d_str = d_strings.element<string_view>(str_idx);
// each thread of the warp will check just part of the string
auto found = false;
for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size);
!found && (i + d_target.size_bytes()) < d_str.size_bytes();
i += cudf::detail::warp_size) {
// check the target matches this part of the d_str data
if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
}
if (found) { atomicOr(d_results + str_idx, true); }
if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }
Contributor

@ttnghia ttnghia Jun 14, 2023

Again, I suspect that this multiplication can overflow, since all the operands are of type int. Maybe we should cast to int64_t?

Suggested change
if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }
if (static_cast<int64_t>(idx) >= static_cast<...>(d_strings.size()) * static_cast<...>(cudf::detail::warp_size)) { return; }

Contributor Author

An overflow cannot technically occur here, since this code path is only taken for long strings, whose average length is always much greater than 32 bytes. This means (number of rows * 32) will never overflow under these conditions.
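
The bound behind this argument can be checked with plain host-side arithmetic. The sketch below is not cudf code: `warp_size`, `total_threads`, `fits_in_int32`, and `max_rows_for_avg` are hypothetical names, and the cap of 2^31 - 1 total character bytes per column is an assumption of this sketch, not something stated in this thread.

```cpp
#include <cstdint>
#include <limits>

// Assumed warp width, matching the "rows * 32" in the comment above.
constexpr std::int32_t warp_size = 32;

// Widen to 64 bits before multiplying so the product itself cannot wrap.
constexpr std::int64_t total_threads(std::int32_t num_rows)
{
  return static_cast<std::int64_t>(num_rows) * warp_size;
}

// True when a 64-bit value is still representable as a 32-bit signed int.
constexpr bool fits_in_int32(std::int64_t value)
{
  return value <= std::numeric_limits<std::int32_t>::max();
}

// Largest row count whose average row length can reach avg_bytes, ASSUMING the
// column holds at most INT32_MAX character bytes in total (hypothetical cap).
constexpr std::int32_t max_rows_for_avg(std::int32_t avg_bytes)
{
  return std::numeric_limits<std::int32_t>::max() / avg_bytes;
}

static_assert(fits_in_int32(total_threads(max_rows_for_avg(32))),
              "rows * 32 stays within 32 bits when rows average 32 bytes or more");
```

Under that assumption, an average of at least 32 bytes per row allows at most 67,108,863 rows, and 67,108,863 * 32 = 2,147,483,616, which is still below 2,147,483,647. One more row would push the product past the 32-bit limit, so the argument depends entirely on the average-length threshold being enforced before this path is chosen.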


auto const str_idx = idx / cudf::detail::warp_size;
auto const lane_idx = idx % cudf::detail::warp_size;
if (d_strings.is_null(str_idx)) { return; }
// get the string for this warp
auto const d_str = d_strings.element<string_view>(str_idx);
// each thread of the warp will check just part of the string
auto found = false;
for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size);
!found && (i + d_target.size_bytes()) < d_str.size_bytes();
i += cudf::detail::warp_size) {
// check the target matches this part of the d_str data
if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
}
};
auto const result = warp_reduce(temp_storage).Reduce(found, cub::Max());
Contributor

Wondering if we could get an early-exit benefit by checking the warp-reduced result before reading the full string. (Discussed offline with @davidwendt.) I don't have a good expectation for the synchronization cost of a single warp sync. It'll probably be slower, but I'd like to learn by how much.
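
To make the early-exit idea concrete, here is a CPU-only simulation of a lane-strided search that checks a warp-wide "any lane found it" flag after every round. It is a sketch, not the kernel in this PR: `contains_with_early_exit` and `search_result` are hypothetical names, the inner loop over lanes stands in for a warp-wide reduction or vote, and the sketch uses an inclusive end bound so a match at the very end of the string counts. It only counts rounds; it says nothing about the synchronization cost on a real GPU.

```cpp
#include <cstddef>
#include <cstring>
#include <string>

constexpr int warp_size = 32;  // assumed warp width

struct search_result {
  bool found;  // whether target occurs in str
  int rounds;  // number of warp-wide rounds executed before stopping
};

// Each round, simulated lane L compares the target at offset base + L.
// After the round, the per-lane results are combined and the search stops
// as soon as any lane reported a match.
inline search_result contains_with_early_exit(std::string const& str, std::string const& target)
{
  int const n = static_cast<int>(str.size());
  int const m = static_cast<int>(target.size());
  int rounds  = 0;
  for (int base = 0; base + m <= n; base += warp_size) {
    ++rounds;
    bool any_lane_found = false;  // stands in for the warp-wide reduction
    for (int lane = 0; lane < warp_size; ++lane) {
      int const i = base + lane;
      if (i + m <= n &&
          std::memcmp(str.data() + i, target.data(), static_cast<std::size_t>(m)) == 0) {
        any_lane_found = true;
      }
    }
    if (any_lane_found) { return search_result{true, rounds}; }
  }
  return search_result{false, rounds};
}
```

With a 103-byte string, a match near the start stops after one round, while a match in the last bytes or no match at all takes four rounds. The benefit therefore depends on where matches tend to fall, and on a GPU every round would also pay for one extra warp-level synchronization.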

if (lane_idx == 0) { d_results[str_idx] = result; }
}

std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
string_scalar const& target,
@@ -324,11 +333,11 @@ std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,

if (!d_target.empty()) {
// launch warp per string
auto d_strings = column_device_view::create(input.parent(), stream);
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<std::size_t>(0),
static_cast<std::size_t>(input.size()) * cudf::detail::warp_size,
contains_warp_fn{*d_strings, d_target, results_view.data<bool>()});
auto const d_strings = column_device_view::create(input.parent(), stream);
constexpr int block_size = 256;
cudf::detail::grid_1d grid{input.size() * cudf::detail::warp_size, block_size};
Contributor

What if input.size() * cudf::detail::warp_size overflows? grid_1d only has int members.

Contributor

Oh, I just found similar code in other files (attributes.cu and find.cu), so this may be a new potential issue.

Contributor Author

@davidwendt davidwendt Jun 14, 2023

You missed this line of code perhaps?
https://github.com/rapidsai/cudf/pull/13578/files#diff-048f86c21559b14f64f86aaeaa57776d366c3a4948a5aba7c0ab1a3801be87bcR292

 if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }

That did not format too well. It is line 292 currently.
This line is in attributes.cu as well.

Contributor

@ttnghia ttnghia Jun 14, 2023

Wait, that line is inside the kernel, while this line is before kernel launch. If we have overflow here, we may still launch a kernel with some (large?) input. We should avoid launching the kernel from here instead.

Contributor Author

An overflow cannot technically occur here, since this code path is only taken for long strings, whose average length is always much greater than 32 bytes. This means the number of rows * 32 will never overflow under these conditions.
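
If a host-side guard were wanted anyway, one way to address the concern about launching with a wrapped thread count is to compute the grid in 64 bits and refuse to launch when it does not fit. This is a minimal sketch under stated assumptions: `launch_config` and `make_launch_config` are hypothetical and are not the `grid_1d` helper used in the diff; they only mirror the fact, noted above, that the launch configuration holds `int` members.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical stand-in for a 1-D launch configuration with int members.
struct launch_config {
  bool ok;  // false when the thread count does not fit in 32 bits
  int num_blocks;
  int threads_per_block;
};

// Compute the grid for one warp (32 threads) per row. The multiplication is
// done in 64 bits so an oversized input is detected instead of wrapping.
inline launch_config make_launch_config(std::int32_t num_rows, int block_size)
{
  constexpr std::int64_t warp_size = 32;
  std::int64_t const total_threads = static_cast<std::int64_t>(num_rows) * warp_size;
  if (num_rows < 0 || block_size <= 0 ||
      total_threads > std::numeric_limits<std::int32_t>::max()) {
    return launch_config{false, 0, 0};
  }
  // Round up so a partially filled final block is still launched.
  std::int64_t const blocks = (total_threads + block_size - 1) / block_size;
  return launch_config{true, static_cast<int>(blocks), block_size};
}
```

A caller would check `ok` before the `<<<...>>>` launch and raise an error or fall back to another code path when it is false. With a block size of 256, 1,000 rows give 32,000 threads in 125 blocks, while 67,108,864 rows are rejected because 67,108,864 * 32 exceeds the 32-bit signed range.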

contains_warp_parallel_fn<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
*d_strings, d_target, results_view.data<bool>());
}
results->set_null_count(input.null_count());
return results;