Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Struct binary search (lower_bound/upper_bound) #7865

Merged
merged 28 commits into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f60b127
Extract DictionarySearchTest from SearchTest
ttnghia Apr 2, 2021
611252b
Add file for StructSearchTest
ttnghia Apr 2, 2021
aef67cc
Rename files
ttnghia Apr 2, 2021
9e73bbb
Remove test
ttnghia Apr 2, 2021
55bf495
Add one more test for StructSearchTest
ttnghia Apr 2, 2021
a8d0a67
And TrivialInputTests for StructSearchTest
ttnghia Apr 2, 2021
1b063d0
Flatten columns for binary search
ttnghia Apr 2, 2021
a5984d0
Simplify binary search
ttnghia Apr 5, 2021
4fb33f9
Add more tests
ttnghia Apr 5, 2021
aafea01
Rewrite tests
ttnghia Apr 5, 2021
30348cc
Fix SimpleInputWithNullsTests
ttnghia Apr 5, 2021
a842198
Finish ComplexStructTest
ttnghia Apr 5, 2021
02988f1
Rename variables
ttnghia Apr 6, 2021
c7c10aa
Fix typo
ttnghia Apr 6, 2021
dbbe480
Rewrite `search_ordered`, replacing `device_vector` by `device_uvecto…
ttnghia Apr 6, 2021
304cfe8
Reorder variables' declaration
ttnghia Apr 6, 2021
36f478c
Fix copyright year in header
ttnghia Apr 6, 2021
63655d9
Simplify StructSearchTests
ttnghia Apr 7, 2021
d17f09a
Change copied variables into references.
ttnghia Apr 7, 2021
ede552a
Fix test for ComplexStructTest
ttnghia Apr 8, 2021
065ee10
Merge remote-tracking branch 'origin/branch-0.20' into struct_binary_…
ttnghia Apr 8, 2021
07a7ca0
Use structure binding to simplify code
ttnghia Apr 8, 2021
43139b8
Remove redundant comment
ttnghia Apr 8, 2021
eead52f
Ignore variable in structure binding
ttnghia Apr 8, 2021
fc3b252
Add SlicedColumnInputTests
ttnghia Apr 8, 2021
661a5cd
Disable debug printing
ttnghia Apr 9, 2021
281fb61
Merge branch 'branch-0.20' into struct_binary_search
ttnghia Apr 19, 2021
4ba910a
Use `make_device_uvector_async`, reverse structured binding, and re-o…
ttnghia Apr 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 42 additions & 42 deletions cpp/src/search/search.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_device_view.cuh>
#include <cudf/table/table_view.hpp>
#include <structs/utilities.hpp>

#include <hash/unordered_multiset.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/binary_search.h>
Expand Down Expand Up @@ -76,15 +78,13 @@ std::unique_ptr<column> search_ordered(table_view const& t,
rmm::mr::device_memory_resource* mr)
{
// Allocate result column
std::unique_ptr<column> result = make_numeric_column(
auto result = make_numeric_column(
data_type{type_to_id<size_type>()}, values.num_rows(), mask_state::UNALLOCATED, stream, mr);

mutable_column_view result_view = result.get()->mutable_view();
auto const result_out = result->mutable_view().data<size_type>();

// Handle empty inputs
if (t.num_rows() == 0) {
CUDA_TRY(cudaMemsetAsync(
result_view.data<size_type>(), 0, values.num_rows() * sizeof(size_type), stream.value()));
CUDA_TRY(cudaMemsetAsync(result_out, 0, values.num_rows() * sizeof(size_type), stream.value()));
return result;
}

Expand All @@ -100,46 +100,46 @@ std::unique_ptr<column> search_ordered(table_view const& t,

// This utility will ensure all corresponding dictionary columns have matching keys.
// It will return any new dictionary columns created as well as updated table_views.
auto matched = dictionary::detail::match_dictionaries({t, values}, stream);
auto d_t = table_device_view::create(matched.second.front(), stream);
auto d_values = table_device_view::create(matched.second.back(), stream);
auto count_it = thrust::make_counting_iterator<size_type>(0);

rmm::device_vector<order> d_column_order(column_order.begin(), column_order.end());
rmm::device_vector<null_order> d_null_precedence(null_precedence.begin(), null_precedence.end());
auto const matched = dictionary::detail::match_dictionaries({t, values}, stream);

// 0-table_view, 1-column_order, 2-null_precedence, 3-validity_columns
auto const t_flattened =
structs::detail::flatten_nested_columns(matched.second.front(), column_order, null_precedence);
auto const values_flattened =
structs::detail::flatten_nested_columns(matched.second.back(), {}, {});

auto const t_d = table_device_view::create(std::get<0>(t_flattened), stream);
auto const values_d = table_device_view::create(std::get<0>(values_flattened), stream);
auto const& lhs = find_first ? *t_d : *values_d;
auto const& rhs = find_first ? *values_d : *t_d;

auto const& column_order_flattened = std::get<1>(t_flattened);
auto const& null_precedence_flattened = std::get<2>(t_flattened);
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
rmm::device_uvector<order> column_order_dv(column_order_flattened.size(), stream);
rmm::device_uvector<null_order> null_precedence_dv(null_precedence_flattened.size(), stream);
CUDA_TRY(cudaMemcpyAsync(column_order_dv.data(),
column_order_flattened.data(),
sizeof(order) * column_order_flattened.size(),
cudaMemcpyDefault,
stream.value()));
CUDA_TRY(cudaMemcpyAsync(null_precedence_dv.data(),
null_precedence_flattened.data(),
sizeof(null_order) * null_precedence_flattened.size(),
cudaMemcpyDefault,
stream.value()));
ttnghia marked this conversation as resolved.
Show resolved Hide resolved

auto const count_it = thrust::make_counting_iterator<size_type>(0);

if (has_nulls(t) or has_nulls(values)) {
auto ineq_op =
(find_first)
? row_lexicographic_comparator<true>(
*d_t, *d_values, d_column_order.data().get(), d_null_precedence.data().get())
: row_lexicographic_comparator<true>(
*d_values, *d_t, d_column_order.data().get(), d_null_precedence.data().get());

launch_search(count_it,
count_it,
t.num_rows(),
values.num_rows(),
result_view.data<size_type>(),
ineq_op,
find_first,
stream);
auto const comp = row_lexicographic_comparator<true>(
lhs, rhs, column_order_dv.data(), null_precedence_dv.data());
launch_search(
count_it, count_it, t.num_rows(), values.num_rows(), result_out, comp, find_first, stream);
} else {
auto ineq_op =
(find_first)
? row_lexicographic_comparator<false>(
*d_t, *d_values, d_column_order.data().get(), d_null_precedence.data().get())
: row_lexicographic_comparator<false>(
*d_values, *d_t, d_column_order.data().get(), d_null_precedence.data().get());

launch_search(count_it,
count_it,
t.num_rows(),
values.num_rows(),
result_view.data<size_type>(),
ineq_op,
find_first,
stream);
auto const comp = row_lexicographic_comparator<false>(
lhs, rhs, column_order_dv.data(), null_precedence_dv.data());
launch_search(
count_it, count_it, t.num_rows(), values.num_rows(), result_out, comp, find_first, stream);
}
ttnghia marked this conversation as resolved.
Show resolved Hide resolved

return result;
Expand Down
5 changes: 4 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ ConfigureTest(FILLING_TEST

###################################################################################################
# - search test -----------------------------------------------------------------------------------
ConfigureTest(SEARCH_TEST search/search_test.cpp)
ConfigureTest(SEARCH_TEST
search/search_dictionary_test.cpp
search/search_struct_test.cpp
search/search_test.cpp)

###################################################################################################
# - reshape test ----------------------------------------------------------------------------------
Expand Down
107 changes: 107 additions & 0 deletions cpp/tests/search/search_dictionary_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/type_lists.hpp>

#include <cudf/search.hpp>

struct DictionarySearchTest : public cudf::test::BaseFixture {
};

using cudf::numeric_scalar;
using cudf::size_type;
using cudf::string_scalar;
using cudf::test::fixed_width_column_wrapper;

TEST_F(DictionarySearchTest, search_dictionary)
{
cudf::test::dictionary_column_wrapper<std::string> input(
{"", "", "10", "10", "20", "20", "30", "40"}, {0, 0, 1, 1, 1, 1, 1, 1});
cudf::test::dictionary_column_wrapper<std::string> values(
{"", "08", "10", "11", "30", "32", "90"}, {0, 1, 1, 1, 1, 1, 1});

auto result = cudf::upper_bound({cudf::table_view{{input}}},
{cudf::table_view{{values}}},
{cudf::order::ASCENDING},
{cudf::null_order::BEFORE});
fixed_width_column_wrapper<size_type> expect_upper{2, 2, 4, 4, 7, 7, 8};
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*result, expect_upper);

result = cudf::lower_bound({cudf::table_view{{input}}},
{cudf::table_view{{values}}},
{cudf::order::ASCENDING},
{cudf::null_order::BEFORE});
fixed_width_column_wrapper<size_type> expect_lower{0, 2, 2, 4, 6, 7, 8};
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*result, expect_lower);
}

TEST_F(DictionarySearchTest, search_table_dictionary)
{
fixed_width_column_wrapper<int32_t> column_0{{10, 10, 20, 20, 20, 20, 20, 20, 20, 50, 30},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0}};
fixed_width_column_wrapper<float> column_1{{5.0, 6.0, .5, .5, .5, .5, .7, .7, .7, .7, .5},
{1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
cudf::test::dictionary_column_wrapper<int16_t> column_2{
{90, 95, 77, 78, 79, 76, 61, 62, 63, 41, 50}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1}};
cudf::table_view input({column_0, column_1, column_2});

fixed_width_column_wrapper<int32_t> values_0{{10, 40, 20}, {1, 0, 1}};
fixed_width_column_wrapper<float> values_1{{6., .5, .5}, {0, 1, 1}};
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved
cudf::test::dictionary_column_wrapper<int16_t> values_2{{95, 50, 77}, {1, 1, 0}};
cudf::table_view values({values_0, values_1, values_2});

std::vector<cudf::order> order_flags{
{cudf::order::ASCENDING, cudf::order::ASCENDING, cudf::order::DESCENDING}};
std::vector<cudf::null_order> null_order_flags{
{cudf::null_order::AFTER, cudf::null_order::AFTER, cudf::null_order::AFTER}};
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved

auto result = cudf::lower_bound(input, values, order_flags, null_order_flags);
fixed_width_column_wrapper<size_type> expect_lower{1, 10, 2};
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*result, expect_lower);

result = cudf::upper_bound(input, values, order_flags, null_order_flags);
fixed_width_column_wrapper<size_type> expect_upper{2, 11, 6};
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*result, expect_upper);
}

TEST_F(DictionarySearchTest, contains_dictionary)
{
cudf::test::dictionary_column_wrapper<std::string> column(
{"00", "00", "17", "17", "23", "23", "29"});
EXPECT_TRUE(cudf::contains(column, string_scalar{"23"}));
EXPECT_FALSE(cudf::contains(column, string_scalar{"28"}));

cudf::test::dictionary_column_wrapper<std::string> needles({"00", "17", "23", "27"});
fixed_width_column_wrapper<bool> expect{1, 1, 1, 1, 1, 1, 0};
auto result = cudf::contains(column, needles);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*result, expect);
}

TEST_F(DictionarySearchTest, contains_nullable_dictionary)
{
cudf::test::dictionary_column_wrapper<int64_t> column({0, 0, 17, 17, 23, 23, 29},
{1, 0, 1, 1, 1, 1, 1});
EXPECT_TRUE(cudf::contains(column, numeric_scalar<int64_t>{23}));
EXPECT_FALSE(cudf::contains(column, numeric_scalar<int64_t>{28}));

cudf::test::dictionary_column_wrapper<int64_t> needles({0, 17, 23, 27});
fixed_width_column_wrapper<bool> expect({1, 0, 1, 1, 1, 1, 0}, {1, 0, 1, 1, 1, 1, 1});
auto result = cudf::contains(column, needles);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*result, expect);
}
Loading