Refactor MD5 implementation. #9212

Merged
merged 54 commits into from
Oct 21, 2021
Changes from 41 commits
Commits
9d158e8
Improve comments and naming.
bdice Sep 10, 2021
5630eed
Remove unused seed from MD5Hash constructor.
bdice Sep 10, 2021
f23735d
Expand detail namespace.
bdice Sep 10, 2021
ed56ef0
Remove seed and default constructors.
bdice Sep 10, 2021
03bfe6d
Move MD5 implementations from hash_functions.cuh to md5_hash.cu becau…
bdice Sep 10, 2021
7945efe
Change to unsigned types for md5_chunk_size.
bdice Sep 20, 2021
81bf129
Remove unused parameter names.
bdice Sep 20, 2021
ab3779b
clang-format.
bdice Sep 28, 2021
d2c4a0c
Replace magic number with named constant.
bdice Sep 28, 2021
bd766bf
Use memcpy instead of std::memcpy, add comments.
bdice Oct 1, 2021
e961f7c
Intermediate stage with optional iterator experiments on device.
bdice Oct 1, 2021
65446ed
Revert changes to column_device_view.cuh.
bdice Oct 1, 2021
b6d166b
Clean up duplication in typed element processing.
bdice Oct 1, 2021
60da684
More cleanup.
bdice Oct 1, 2021
3940daa
Simplify message length.
bdice Oct 1, 2021
6af5451
Make normalization pass-through for non-floating fixed-width types.
bdice Oct 1, 2021
b33ea93
Prefer char over uint8_t.
bdice Oct 1, 2021
af42f7e
Improve helper functions.
bdice Oct 1, 2021
f3038e1
Refactor finalize.
bdice Oct 1, 2021
e8c6e3e
Rename message length variable.
bdice Oct 1, 2021
c6867e6
Additional simplifications to finalize.
bdice Oct 1, 2021
4e8817e
Simplify padding.
bdice Oct 1, 2021
f83cfe0
Consolidate MD5 hash constants into .cu file, remove includes.
bdice Oct 11, 2021
37cb283
Implement hash_circular_buffer.
bdice Oct 12, 2021
a3be658
Remove non-const accessor.
bdice Oct 13, 2021
088fa89
Replace SFINAE with constexpr if.
bdice Oct 13, 2021
331107b
Refactor buffer and hash_step callbacks.
bdice Oct 13, 2021
af98bab
Move processing functions around.
bdice Oct 13, 2021
e604f14
MD5Hasher now owns its hash state.
bdice Oct 13, 2021
7ff0e72
Use HasherDispatcher to avoid unexpected memory access error with typ…
bdice Oct 14, 2021
167dab3
Move methods from md5_hash_state to MD5Hasher class.
bdice Oct 14, 2021
a8ac635
Improve naming.
bdice Oct 14, 2021
e8b61dd
Use destructor instead of finalize method.
bdice Oct 15, 2021
5150c5e
Avoid double type dispatch for list columns.
bdice Oct 15, 2021
3715e25
Remove hash_state class.
bdice Oct 15, 2021
677c2fe
Refactoring hash step callback structure.
bdice Oct 15, 2021
720b601
Remove MD5Hasher.put.
bdice Oct 15, 2021
55e4bce
Remove MD5Hasher.pad.
bdice Oct 15, 2021
27a8bf0
Move hash_values to be part of the hash step's state.
bdice Oct 18, 2021
e56e1d6
Construct hash step explicitly, use message_chunk_size for buffer sizes.
bdice Oct 18, 2021
5544d87
Make MD5Hash own the hash_values.
bdice Oct 18, 2021
ff038cd
Move variable.
bdice Oct 18, 2021
cc2bf4b
Declare HasherDispatchers inline.
bdice Oct 20, 2021
9fd47fa
Make HasherDispatcher column_device_view const&.
bdice Oct 20, 2021
f12eaf8
Parenthesize expression for readability.
bdice Oct 20, 2021
2f5c2b6
Add const.
bdice Oct 20, 2021
25c0e90
Rename Key to Element, get_data to get_element_pointer_and_size.
bdice Oct 20, 2021
b6c4c75
Add constructor to HasherDispatcher.
bdice Oct 20, 2021
d7ff965
Simplify hash_circular_buffer.
bdice Oct 21, 2021
8cdf423
Reformat shift constants.
bdice Oct 21, 2021
e7834dd
Add comment.
bdice Oct 21, 2021
92b329b
Remove unnecessary qualifiers.
bdice Oct 21, 2021
410b30d
Add comment about constant sources.
bdice Oct 21, 2021
c7d132f
Clarify flag.
bdice Oct 21, 2021
312 changes: 9 additions & 303 deletions cpp/include/cudf/detail/utilities/hash_functions.cuh
@@ -21,110 +21,27 @@
#include <cudf/fixed_point/fixed_point.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/types.hpp>
#include <hash/hash_constants.hpp>

using hash_value_type = uint32_t;

namespace cudf {
namespace detail {
namespace {
/**
* @brief Core MD5 algorithm implementation. Processes a single 512-bit chunk,
* updating the hash value so far. Does not zero out the buffer contents.
*/
void CUDA_DEVICE_CALLABLE md5_hash_step(md5_intermediate_data* hash_state)
{
uint32_t A = hash_state->hash_value[0];
uint32_t B = hash_state->hash_value[1];
uint32_t C = hash_state->hash_value[2];
uint32_t D = hash_state->hash_value[3];

for (unsigned int j = 0; j < 64; j++) {
uint32_t F;
uint32_t g;
switch (j / 16) {
case 0:
F = (B & C) | ((~B) & D);
g = j;
break;
case 1:
F = (D & B) | ((~D) & C);
g = (5 * j + 1) % 16;
break;
case 2:
F = B ^ C ^ D;
g = (3 * j + 5) % 16;
break;
case 3:
F = C ^ (B | (~D));
g = (7 * j) % 16;
break;
}

uint32_t buffer_element_as_int;
std::memcpy(&buffer_element_as_int, hash_state->buffer + g * 4, 4);
F = F + A + md5_hash_constants[j] + buffer_element_as_int;
A = D;
D = C;
C = B;
B = B + __funnelshift_l(F, F, md5_shift_constants[((j / 16) * 4) + (j % 4)]);
}

hash_state->hash_value[0] += A;
hash_state->hash_value[1] += B;
hash_state->hash_value[2] += C;
hash_state->hash_value[3] += D;

hash_state->buffer_length = 0;
}
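
For readers following the removed `md5_hash_step`, here is a minimal host-only sketch of the same per-chunk update, written so it can be compiled without CUDA. It is an illustration under stated assumptions, not the code this PR moves into `md5_hash.cu`: `md5_step_sketch`, `rotate_left`, `hash_constant`, and `shift_constants` are hypothetical names, the additive constants are computed from the RFC 1321 definition instead of being read from a table, a plain rotate stands in for `__funnelshift_l`, and the host is assumed to be little-endian.

```cpp
#include <array>
#include <cmath>
#include <cstdint>
#include <cstring>

// Left-rotation amounts, four per round, as defined in RFC 1321.
constexpr std::array<uint32_t, 16> shift_constants{
  7, 12, 17, 22, 5, 9, 14, 20, 4, 11, 16, 23, 6, 10, 15, 21};

// Rotate left; stands in for the device intrinsic __funnelshift_l(x, x, n).
// Only called with n in [4, 23], so the shifts are always well defined.
inline uint32_t rotate_left(uint32_t x, uint32_t n) { return (x << n) | (x >> (32 - n)); }

// Additive constant for step j: floor(2^32 * |sin(j + 1)|), per RFC 1321.
inline uint32_t hash_constant(unsigned j)
{
  double const s = std::fabs(std::sin(static_cast<double>(j + 1)));
  return static_cast<uint32_t>(std::floor(s * 4294967296.0));
}

// Processes one 64-byte chunk and folds it into the four-word hash state.
inline void md5_step_sketch(std::array<uint32_t, 4>& state, unsigned char const* chunk)
{
  uint32_t A = state[0];
  uint32_t B = state[1];
  uint32_t C = state[2];
  uint32_t D = state[3];

  for (unsigned j = 0; j < 64; ++j) {
    uint32_t F = 0;
    uint32_t g = 0;
    switch (j / 16) {
      case 0: F = (B & C) | (~B & D); g = j; break;
      case 1: F = (D & B) | (~D & C); g = (5 * j + 1) % 16; break;
      case 2: F = B ^ C ^ D; g = (3 * j + 5) % 16; break;
      case 3: F = C ^ (B | ~D); g = (7 * j) % 16; break;
    }

    uint32_t word;
    std::memcpy(&word, chunk + g * 4, 4);  // assumes a little-endian host
    F = F + A + hash_constant(j) + word;
    A = D;
    D = C;
    C = B;
    B = B + rotate_left(F, shift_constants[(j / 16) * 4 + (j % 4)]);
  }

  state[0] += A;
  state[1] += B;
  state[2] += C;
  state[3] += D;
}
```

Starting from the standard initial state (`0x67452301`, `0xefcdab89`, `0x98badcfe`, `0x10325476`) and feeding the single padded chunk of the empty message (one `0x80` byte followed by zeros) should leave the state holding the well-known digest `d41d8cd98f00b204e9800998ecf8427e` when each word is read low byte first.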

/**
* @brief Core MD5 element processing function
* Normalization of floating point NaNs and zeros, passthrough for all other values.
*/
template <typename TKey>
void CUDA_DEVICE_CALLABLE md5_process(TKey const& key, md5_intermediate_data* hash_state)
template <typename T>
T CUDA_DEVICE_CALLABLE normalize_nans_and_zeros(T const& key)
{
uint32_t const len = sizeof(TKey);
uint8_t const* data = reinterpret_cast<uint8_t const*>(&key);
hash_state->message_length += len;

// 64 bytes for the number of bytes processed in a given step
constexpr int md5_chunk_size = 64;
if (hash_state->buffer_length + len < md5_chunk_size) {
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len);
hash_state->buffer_length += len;
} else {
uint32_t copylen = md5_chunk_size - hash_state->buffer_length;

std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen);
md5_hash_step(hash_state);

while (len > md5_chunk_size + copylen) {
std::memcpy(hash_state->buffer, data + copylen, md5_chunk_size);
md5_hash_step(hash_state);
copylen += md5_chunk_size;
if constexpr (is_floating_point<T>()) {
if (isnan(key)) {
return std::numeric_limits<T>::quiet_NaN();
} else if (key == T{0.0}) {
return T{0.0};
}

std::memcpy(hash_state->buffer, data + copylen, len - copylen);
hash_state->buffer_length = len - copylen;
}
}

/**
* Normalization of floating point NANs and zeros helper
*/
template <typename T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
T CUDA_DEVICE_CALLABLE normalize_nans_and_zeros_helper(T key)
{
if (isnan(key)) {
return std::numeric_limits<T>::quiet_NaN();
} else if (key == T{0.0}) {
return T{0.0};
} else {
return key;
}
return key;
}
} // namespace
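
Because the hunk above interleaves removed and added lines, the resulting helper is hard to read in place. The following is a host-only sketch of what the `if constexpr` version appears to do, assuming the standard-library traits `std::is_floating_point_v` and `std::isnan` in place of the cudf and device equivalents. The name `normalize_nans_and_zeros_sketch` is hypothetical and the `CUDA_DEVICE_CALLABLE` qualifier is dropped so it compiles as plain C++17.

```cpp
#include <cmath>
#include <limits>
#include <type_traits>

// Host-only sketch: canonicalize NaNs and signed zeros for floating-point
// types; every other type is passed through unchanged.
template <typename T>
T normalize_nans_and_zeros_sketch(T const& key)
{
  if constexpr (std::is_floating_point_v<T>) {
    if (std::isnan(key)) {
      // All NaN payloads collapse to one canonical bit pattern.
      return std::numeric_limits<T>::quiet_NaN();
    } else if (key == T{0.0}) {
      // -0.0 compares equal to 0.0, so both map to +0.0.
      return T{0.0};
    }
  }
  return key;
}
```

A single template with `if constexpr` replaces the pair of SFINAE-constrained overloads, so callers no longer need separate floating-point and non-floating-point code paths just to decide whether to normalize.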

/**
* Modified GPU implementation of
@@ -149,217 +66,6 @@ void CUDA_DEVICE_CALLABLE uint32ToLowercaseHexString(uint32_t num, char* destina
std::memcpy(destination, reinterpret_cast<uint8_t*>(&x), 8);
}
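
Only the tail of `uint32ToLowercaseHexString` is visible in this hunk, so as a reference point here is a straightforward byte-at-a-time sketch of the same conversion. It is not the bit-manipulation implementation used in the file: `uint32_to_lowercase_hex_sketch` is a hypothetical name, and it assumes the bytes are emitted lowest byte first, which is the order MD5's digest layout calls for.

```cpp
#include <cstdint>
#include <string>

// Writes the 4 bytes of `num`, lowest byte first, as 8 lowercase hex characters.
inline std::string uint32_to_lowercase_hex_sketch(uint32_t num)
{
  constexpr char digits[] = "0123456789abcdef";
  std::string out(8, '0');
  for (int byte = 0; byte < 4; ++byte) {
    uint32_t const value = (num >> (8 * byte)) & 0xffu;
    out[2 * byte]     = digits[value >> 4];    // high nibble of this byte
    out[2 * byte + 1] = digits[value & 0xfu];  // low nibble of this byte
  }
  return out;
}
```

For example, the state word `0xd98c1dd4` becomes `"d41d8cd9"`, the first eight characters of the MD5 digest of the empty message.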

struct MD5ListHasher {
template <typename T, std::enable_if_t<is_chrono<T>()>* = nullptr>
void __device__ operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported chrono type column");
}

template <typename T, std::enable_if_t<!is_fixed_width<T>()>* = nullptr>
void __device__ operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported non-fixed-width type column");
}

template <typename T, std::enable_if_t<is_floating_point<T>()>* = nullptr>
void __device__ operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
for (int i = offset_begin; i < offset_end; i++) {
if (!data_col.is_null(i)) {
md5_process(normalize_nans_and_zeros_helper<T>(data_col.element<T>(i)), hash_state);
}
}
}

template <
typename T,
std::enable_if_t<is_fixed_width<T>() && !is_floating_point<T>() && !is_chrono<T>()>* = nullptr>
void CUDA_DEVICE_CALLABLE operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
for (int i = offset_begin; i < offset_end; i++) {
if (!data_col.is_null(i)) md5_process(data_col.element<T>(i), hash_state);
}
}
};

template <>
void CUDA_DEVICE_CALLABLE
MD5ListHasher::operator()<string_view>(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
for (int i = offset_begin; i < offset_end; i++) {
if (!data_col.is_null(i)) {
string_view key = data_col.element<string_view>(i);
uint32_t const len = static_cast<uint32_t>(key.size_bytes());
uint8_t const* data = reinterpret_cast<uint8_t const*>(key.data());

hash_state->message_length += len;

if (hash_state->buffer_length + len < 64) {
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len);
hash_state->buffer_length += len;
} else {
uint32_t copylen = 64 - hash_state->buffer_length;
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen);
md5_hash_step(hash_state);

while (len > 64 + copylen) {
std::memcpy(hash_state->buffer, data + copylen, 64);
md5_hash_step(hash_state);
copylen += 64;
}

std::memcpy(hash_state->buffer, data + copylen, len - copylen);
hash_state->buffer_length = len - copylen;
}
}
}
}

struct MD5Hash {
MD5Hash() = default;
constexpr MD5Hash(uint32_t seed) : m_seed(seed) {}

void __device__ finalize(md5_intermediate_data* hash_state, char* result_location) const
{
auto const full_length = (static_cast<uint64_t>(hash_state->message_length)) << 3;
thrust::fill_n(thrust::seq, hash_state->buffer + hash_state->buffer_length, 1, 0x80);

// 64 bytes for the number of bytes processed in a given step
constexpr int md5_chunk_size = 64;
// 8 bytes for the total message length, appended to the end of the last chunk processed
constexpr int message_length_size = 8;
// 1 byte for the end of the message flag
constexpr int end_of_message_size = 1;
if (hash_state->buffer_length + message_length_size + end_of_message_size <= md5_chunk_size) {
thrust::fill_n(
thrust::seq,
hash_state->buffer + hash_state->buffer_length + 1,
(md5_chunk_size - message_length_size - end_of_message_size - hash_state->buffer_length),
0x00);
} else {
thrust::fill_n(thrust::seq,
hash_state->buffer + hash_state->buffer_length + 1,
(md5_chunk_size - hash_state->buffer_length),
0x00);
md5_hash_step(hash_state);

thrust::fill_n(thrust::seq, hash_state->buffer, md5_chunk_size - message_length_size, 0x00);
}

std::memcpy(hash_state->buffer + md5_chunk_size - message_length_size,
reinterpret_cast<uint8_t const*>(&full_length),
message_length_size);
md5_hash_step(hash_state);

#pragma unroll
for (int i = 0; i < 4; ++i)
uint32ToLowercaseHexString(hash_state->hash_value[i], result_location + (8 * i));
}

template <typename T, std::enable_if_t<is_chrono<T>()>* = nullptr>
void __device__ operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported chrono type column");
}

template <typename T, std::enable_if_t<!is_fixed_width<T>()>* = nullptr>
void __device__ operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported non-fixed-width type column");
}

template <typename T, std::enable_if_t<is_floating_point<T>()>* = nullptr>
void __device__ operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
md5_process(normalize_nans_and_zeros_helper<T>(col.element<T>(row_index)), hash_state);
}

template <
typename T,
std::enable_if_t<is_fixed_width<T>() && !is_floating_point<T>() && !is_chrono<T>()>* = nullptr>
void CUDA_DEVICE_CALLABLE operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
md5_process(col.element<T>(row_index), hash_state);
}

private:
uint32_t m_seed{cudf::DEFAULT_HASH_SEED};
};
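
The padding performed by the removed `finalize` method can be summarized with a small host-only sketch: append the `0x80` end-of-message flag, zero-fill until 8 bytes remain in the current 64-byte chunk, then append the message length in bits as a 64-bit little-endian integer. `md5_pad_sketch` is a hypothetical helper that operates on a `std::vector` for clarity. The device code works in a fixed buffer and runs the hash step as chunks fill, which this sketch does not attempt.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns `message` with MD5 padding appended; the result is a whole number
// of 64-byte chunks.
inline std::vector<unsigned char> md5_pad_sketch(std::vector<unsigned char> message)
{
  constexpr std::size_t chunk_size  = 64;  // bytes processed per hash step
  constexpr std::size_t length_size = 8;   // bytes holding the message length

  // Length is recorded in bits, hence the shift by 3.
  uint64_t const bit_length = static_cast<uint64_t>(message.size()) << 3;

  message.push_back(0x80);  // end-of-message flag
  while (message.size() % chunk_size != chunk_size - length_size) {
    message.push_back(0x00);
  }
  for (std::size_t i = 0; i < length_size; ++i) {
    // Little-endian: least significant byte first.
    message.push_back(static_cast<unsigned char>((bit_length >> (8 * i)) & 0xff));
  }
  return message;
}
```

A 55-byte message still fits in one chunk (55 + 1 + 8 = 64), while a 56-byte message spills into a second chunk, which matches the `<= md5_chunk_size` branch in `finalize`.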

template <>
void CUDA_DEVICE_CALLABLE MD5Hash::operator()<string_view>(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
string_view key = col.element<string_view>(row_index);
uint32_t const len = static_cast<uint32_t>(key.size_bytes());
uint8_t const* data = reinterpret_cast<uint8_t const*>(key.data());

hash_state->message_length += len;

if (hash_state->buffer_length + len < 64) {
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len);
hash_state->buffer_length += len;
} else {
uint32_t copylen = 64 - hash_state->buffer_length;
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen);
md5_hash_step(hash_state);

while (len > 64 + copylen) {
std::memcpy(hash_state->buffer, data + copylen, 64);
md5_hash_step(hash_state);
copylen += 64;
}

std::memcpy(hash_state->buffer, data + copylen, len - copylen);
hash_state->buffer_length = len - copylen;
}
}
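
The same buffer-then-step logic appears three times in the removed code (`md5_process` and both `string_view` specializations), which is the duplication this PR consolidates. As a rough illustration of the shared pattern, here is a simplified host-only sketch. `chunk_buffer_sketch` and its members are hypothetical names, a counter stands in for the call to the hash step, and the loop structure is a simplified variant, not the PR's `hash_circular_buffer`.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstring>

// Stages incoming bytes and "processes" a chunk each time 64 bytes are ready.
struct chunk_buffer_sketch {
  static constexpr std::size_t chunk_size = 64;

  std::array<unsigned char, chunk_size> buffer{};
  std::size_t buffer_length    = 0;  // bytes currently staged, always < chunk_size
  std::size_t chunks_processed = 0;  // stand-in for calls to the hash step

  void put(unsigned char const* data, std::size_t len)
  {
    while (len > 0) {
      // Copy as much as fits in the space left in the current chunk.
      std::size_t const n = std::min(len, chunk_size - buffer_length);
      std::memcpy(buffer.data() + buffer_length, data, n);
      buffer_length += n;
      data += n;
      len -= n;

      if (buffer_length == chunk_size) {
        ++chunks_processed;  // a real hasher would run its hash step on `buffer` here
        buffer_length = 0;
      }
    }
  }
};
```

Feeding 100 bytes and then 50 bytes processes two full chunks and leaves 22 bytes staged, regardless of how the input was split across calls. Keeping this logic in one place means fixed-width elements, strings, and list elements can all go through the same path.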

template <>
void CUDA_DEVICE_CALLABLE MD5Hash::operator()<list_view>(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
static constexpr size_type offsets_column_index{0};
static constexpr size_type data_column_index{1};

column_device_view offsets = col.child(offsets_column_index);
column_device_view data = col.child(data_column_index);

if (data.type().id() == type_id::LIST) cudf_assert(false && "Nested list unsupported");

cudf::type_dispatcher(data.type(),
MD5ListHasher{},
data,
offsets.element<size_type>(row_index),
offsets.element<size_type>(row_index + 1),
hash_state);
}
} // namespace detail
} // namespace cudf
