Skip to content

Commit

Permalink
Use make_device_uvector_async, reverse structured binding, and re-o…
Browse files Browse the repository at this point in the history
…rganize input validity
  • Loading branch information
ttnghia committed Apr 19, 2021
1 parent 281fb61 commit 4ba910a
Showing 1 changed file with 21 additions and 30 deletions.
51 changes: 21 additions & 30 deletions cpp/src/search/search.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/search.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/dictionary/detail/search.hpp>
#include <cudf/dictionary/detail/update_keys.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
Expand Down Expand Up @@ -77,6 +78,13 @@ std::unique_ptr<column> search_ordered(table_view const& t,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(
column_order.empty() or static_cast<std::size_t>(t.num_columns()) == column_order.size(),
"Mismatch between number of columns and column order.");
CUDF_EXPECTS(
null_precedence.empty() or static_cast<std::size_t>(t.num_columns()) == null_precedence.size(),
"Mismatch between number of columns and null precedence.");

// Allocate result column
auto result = make_numeric_column(
data_type{type_to_id<size_type>()}, values.num_rows(), mask_state::UNALLOCATED, stream, mr);
Expand All @@ -88,45 +96,28 @@ std::unique_ptr<column> search_ordered(table_view const& t,
return result;
}

if (not column_order.empty()) {
CUDF_EXPECTS(static_cast<std::size_t>(t.num_columns()) == column_order.size(),
"Mismatch between number of columns and column order.");
}

if (not null_precedence.empty()) {
CUDF_EXPECTS(static_cast<std::size_t>(t.num_columns()) == null_precedence.size(),
"Mismatch between number of columns and null precedence.");
}

// This utility will ensure all corresponding dictionary columns have matching keys.
// It will return any new dictionary columns created as well as updated table_views.
auto const [_, t_vals_matched] = dictionary::detail::match_dictionaries({t, values}, stream);
auto const matched = dictionary::detail::match_dictionaries({t, values}, stream);

auto const [t_flattened, column_order_flattened, null_precedence_flattened, __] =
structs::detail::flatten_nested_columns(t_vals_matched.front(), column_order, null_precedence);
auto const [values_flattened, ___, ____, _____] =
structs::detail::flatten_nested_columns(t_vals_matched.back(), {}, {});
// 0-table_view, 1-column_order, 2-null_precedence, 3-validity_columns
auto const t_flattened =
structs::detail::flatten_nested_columns(matched.second.front(), column_order, null_precedence);
auto const values_flattened =
structs::detail::flatten_nested_columns(matched.second.back(), {}, {});

auto const t_d = table_device_view::create(t_flattened, stream);
auto const values_d = table_device_view::create(values_flattened, stream);
auto const t_d = table_device_view::create(std::get<0>(t_flattened), stream);
auto const values_d = table_device_view::create(std::get<0>(values_flattened), stream);
auto const& lhs = find_first ? *t_d : *values_d;
auto const& rhs = find_first ? *values_d : *t_d;

rmm::device_uvector<order> column_order_dv(column_order_flattened.size(), stream);
rmm::device_uvector<null_order> null_precedence_dv(null_precedence_flattened.size(), stream);
CUDA_TRY(cudaMemcpyAsync(column_order_dv.data(),
column_order_flattened.data(),
sizeof(order) * column_order_flattened.size(),
cudaMemcpyDefault,
stream.value()));
CUDA_TRY(cudaMemcpyAsync(null_precedence_dv.data(),
null_precedence_flattened.data(),
sizeof(null_order) * null_precedence_flattened.size(),
cudaMemcpyDefault,
stream.value()));
auto const& column_order_flattened = std::get<1>(t_flattened);
auto const& null_precedence_flattened = std::get<2>(t_flattened);
auto const column_order_dv = detail::make_device_uvector_async(column_order_flattened, stream);
auto const null_precedence_dv =
detail::make_device_uvector_async(null_precedence_flattened, stream);

auto const count_it = thrust::make_counting_iterator<size_type>(0);

if (has_nulls(t) or has_nulls(values)) {
auto const comp = row_lexicographic_comparator<true>(
lhs, rhs, column_order_dv.data(), null_precedence_dv.data());
Expand Down

0 comments on commit 4ba910a

Please sign in to comment.