diff --git a/cpp/src/io/comp/nvcomp_adapter.hpp b/cpp/src/io/comp/nvcomp_adapter.hpp index 1393b70f058..69a278757ce 100644 --- a/cpp/src/io/comp/nvcomp_adapter.hpp +++ b/cpp/src/io/comp/nvcomp_adapter.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,8 +99,8 @@ inline bool operator==(feature_status_parameters const& lhs, feature_status_para * @param[in] inputs List of input buffers * @param[out] outputs List of output buffers * @param[out] results List of output status structures - * @param[in] max_uncomp_chunk_size maximum size of uncompressed chunk - * @param[in] max_total_uncomp_size maximum total size of uncompressed data + * @param[in] max_uncomp_chunk_size Maximum size of any single uncompressed chunk + * @param[in] max_total_uncomp_size Maximum total size of uncompressed data * @param[in] stream CUDA stream to use */ void batched_decompress(compression_type compression, @@ -111,6 +111,24 @@ void batched_decompress(compression_type compression, size_t max_total_uncomp_size, rmm::cuda_stream_view stream); +/** + * @brief Return the amount of temporary space required in bytes for a given decompression + * operation. + * + * The size returned reflects the size of the scratch buffer to be passed to + * `batched_decompress_async` + * + * @param[in] compression Compression type + * @param[in] num_chunks The number of decompression chunks to be processed + * @param[in] max_uncomp_chunk_size Maximum size of any single uncompressed chunk + * @param[in] max_total_uncomp_size Maximum total size of uncompressed data + * @returns The total required size in bytes + */ +size_t batched_decompress_temp_size(compression_type compression, + size_t num_chunks, + size_t max_uncomp_chunk_size, + size_t max_total_uncomp_size); + /** * @brief Gets the maximum size any chunk could compress to in the batch. * diff --git a/cpp/src/io/parquet/page_decode.cuh b/cpp/src/io/parquet/page_decode.cuh index 8f256cd1f97..409b1464cd1 100644 --- a/cpp/src/io/parquet/page_decode.cuh +++ b/cpp/src/io/parquet/page_decode.cuh @@ -1301,16 +1301,15 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s, if (((s->col.data_type & 7) == BYTE_ARRAY) && (s->col.str_dict_index)) { // String dictionary: use index s->dict_base = reinterpret_cast(s->col.str_dict_index); - s->dict_size = s->col.page_info[0].num_input_values * sizeof(string_index_pair); + s->dict_size = s->col.dict_page->num_input_values * sizeof(string_index_pair); } else { - s->dict_base = - s->col.page_info[0].page_data; // dictionary is always stored in the first page - s->dict_size = s->col.page_info[0].uncompressed_page_size; + s->dict_base = s->col.dict_page->page_data; + s->dict_size = s->col.dict_page->uncompressed_page_size; } s->dict_run = 0; s->dict_val = 0; s->dict_bits = (cur < end) ? 
*cur++ : 0; - if (s->dict_bits > 32 || !s->dict_base) { + if (s->dict_bits > 32 || (!s->dict_base && s->col.dict_page->num_input_values > 0)) { s->set_error_code(decode_error::INVALID_DICT_WIDTH); } break; diff --git a/cpp/src/io/parquet/page_hdr.cu b/cpp/src/io/parquet/page_hdr.cu index 4be4f45497d..888d9452612 100644 --- a/cpp/src/io/parquet/page_hdr.cu +++ b/cpp/src/io/parquet/page_hdr.cu @@ -348,9 +348,11 @@ struct gpuParsePageHeader { * @param[in] num_chunks Number of column chunks */ // blockDim {128,1,1} -CUDF_KERNEL void __launch_bounds__(128) gpuDecodePageHeaders(ColumnChunkDesc* chunks, - int32_t num_chunks, - kernel_error::pointer error_code) +CUDF_KERNEL +void __launch_bounds__(128) gpuDecodePageHeaders(ColumnChunkDesc* chunks, + chunk_page_info* chunk_pages, + int32_t num_chunks, + kernel_error::pointer error_code) { using cudf::detail::warp_size; gpuParsePageHeader parse_page_header; @@ -392,11 +394,10 @@ CUDF_KERNEL void __launch_bounds__(128) gpuDecodePageHeaders(ColumnChunkDesc* ch bs->page.temp_string_buf = nullptr; bs->page.kernel_mask = decode_kernel_mask::NONE; } - num_values = bs->ck.num_values; - page_info = bs->ck.page_info; - num_dict_pages = bs->ck.num_dict_pages; - max_num_pages = (page_info) ? bs->ck.max_num_pages : 0; - values_found = 0; + num_values = bs->ck.num_values; + page_info = chunk_pages ? chunk_pages[chunk].pages : nullptr; + max_num_pages = page_info ? bs->ck.max_num_pages : 0; + values_found = 0; __syncwarp(); while (values_found < num_values && bs->cur < bs->end) { int index_out = -1; @@ -495,9 +496,9 @@ CUDF_KERNEL void __launch_bounds__(128) if (!lane_id && ck->num_dict_pages > 0 && ck->str_dict_index) { // Data type to describe a string string_index_pair* dict_index = ck->str_dict_index; - uint8_t const* dict = ck->page_info[0].page_data; - int dict_size = ck->page_info[0].uncompressed_page_size; - int num_entries = ck->page_info[0].num_input_values; + uint8_t const* dict = ck->dict_page->page_data; + int dict_size = ck->dict_page->uncompressed_page_size; + int num_entries = ck->dict_page->num_input_values; int pos = 0, cur = 0; for (int i = 0; i < num_entries; i++) { int len = 0; @@ -518,13 +519,15 @@ CUDF_KERNEL void __launch_bounds__(128) } void __host__ DecodePageHeaders(ColumnChunkDesc* chunks, + chunk_page_info* chunk_pages, int32_t num_chunks, kernel_error::pointer error_code, rmm::cuda_stream_view stream) { dim3 dim_block(128, 1); dim3 dim_grid((num_chunks + 3) >> 2, 1); // 1 chunk per warp, 4 warps per block - gpuDecodePageHeaders<<>>(chunks, num_chunks, error_code); + gpuDecodePageHeaders<<>>( + chunks, chunk_pages, num_chunks, error_code); } void __host__ BuildStringDictionaryIndex(ColumnChunkDesc* chunks, diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index 37a8cabc182..d652a43d097 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -868,14 +868,16 @@ CUDF_KERNEL void __launch_bounds__(preprocess_block_size) gpuComputePageStringSi if (col.str_dict_index) { // String dictionary: use index dict_base = reinterpret_cast(col.str_dict_index); - dict_size = col.page_info[0].num_input_values * sizeof(string_index_pair); + dict_size = col.dict_page->num_input_values * sizeof(string_index_pair); } else { - dict_base = col.page_info[0].page_data; // dictionary is always stored in the first page - dict_size = col.page_info[0].uncompressed_page_size; + dict_base = col.dict_page->page_data; + dict_size = col.dict_page->uncompressed_page_size; } // 
FIXME: need to return an error condition...this won't actually do anything - if (s->dict_bits > 32 || !dict_base) { CUDF_UNREACHABLE("invalid dictionary bit size"); } + if (s->dict_bits > 32 || (!dict_base && col.dict_page->num_input_values > 0)) { + CUDF_UNREACHABLE("invalid dictionary bit size"); + } str_bytes = totalDictEntriesSize( data, dict_base, s->dict_bits, dict_size, (end - data), start_value, end_value); diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 18d282be855..d58c7f95389 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -339,6 +339,21 @@ struct PageInfo { decode_kernel_mask kernel_mask; }; +/** + * @brief Return the column schema id as the key for a PageInfo struct. + */ +struct get_page_key { + __device__ int32_t operator()(PageInfo const& page) const { return page.src_col_schema; } +}; + +/** + * @brief Return an iterator that returns they keys for a vector of pages. + */ +inline auto make_page_key_iterator(device_span pages) +{ + return thrust::make_transform_iterator(pages.begin(), get_page_key{}); +} + /** * @brief Struct describing a particular chunk of column data */ @@ -362,7 +377,8 @@ struct ColumnChunkDesc { int8_t decimal_precision_, int32_t ts_clock_rate_, int32_t src_col_index_, - int32_t src_col_schema_) + int32_t src_col_schema_, + float list_bytes_per_row_est_) : compressed_data(compressed_data_), compressed_size(compressed_size_), num_values(num_values_), @@ -375,7 +391,7 @@ struct ColumnChunkDesc { num_data_pages(0), num_dict_pages(0), max_num_pages(0), - page_info(nullptr), + dict_page(nullptr), str_dict_index(nullptr), valid_map_base{nullptr}, column_data_base{nullptr}, @@ -386,26 +402,25 @@ struct ColumnChunkDesc { decimal_precision(decimal_precision_), ts_clock_rate(ts_clock_rate_), src_col_index(src_col_index_), - src_col_schema(src_col_schema_) + src_col_schema(src_col_schema_), + list_bytes_per_row_est(list_bytes_per_row_est_) { } - uint8_t const* compressed_data{}; // pointer to compressed column chunk data - size_t compressed_size{}; // total compressed data size for this chunk - size_t num_values{}; // total number of values in this column - size_t start_row{}; // starting row of this chunk - uint32_t num_rows{}; // number of rows in this chunk + uint8_t const* compressed_data{}; // pointer to compressed column chunk data + size_t compressed_size{}; // total compressed data size for this chunk + size_t num_values{}; // total number of values in this column + size_t start_row{}; // file-wide, absolute starting row of this chunk + uint32_t num_rows{}; // number of rows in this chunk int16_t max_level[level_type::NUM_LEVEL_TYPES]{}; // max definition/repetition level int16_t max_nesting_depth{}; // max nesting depth of the output - uint16_t data_type{}; // basic column data type, ((type_length << 3) | - // parquet::Type) + uint16_t data_type{}; // basic column data type, ((type_length << 3) | // parquet::Type) uint8_t - level_bits[level_type::NUM_LEVEL_TYPES]{}; // bits to encode max definition/repetition levels - int32_t num_data_pages{}; // number of data pages - int32_t num_dict_pages{}; // number of dictionary pages - int32_t max_num_pages{}; // size of page_info array - PageInfo* page_info{}; // output page info 
for up to num_dict_pages + - // num_data_pages (dictionary pages first) + level_bits[level_type::NUM_LEVEL_TYPES]{}; // bits to encode max definition/repetition levels + int32_t num_data_pages{}; // number of data pages + int32_t num_dict_pages{}; // number of dictionary pages + int32_t max_num_pages{}; // size of page_info array + PageInfo const* dict_page{}; string_index_pair* str_dict_index{}; // index for string dictionary bitmask_type** valid_map_base{}; // base pointers of valid bit map for this column void** column_data_base{}; // base pointers of column data @@ -418,6 +433,15 @@ struct ColumnChunkDesc { int32_t src_col_index{}; // my input column index int32_t src_col_schema{}; // my schema index in the file + + float list_bytes_per_row_est{}; // for LIST columns, an estimate on number of bytes per row +}; + +/** + * @brief A utility structure for use in decoding page headers. + */ +struct chunk_page_info { + PageInfo* pages; }; /** @@ -578,11 +602,13 @@ constexpr bool is_string_col(ColumnChunkDesc const& chunk) * @brief Launches kernel for parsing the page headers in the column chunks * * @param[in] chunks List of column chunks + * @param[in] chunk_pages List of pages associated with the chunks, in chunk-sorted order * @param[in] num_chunks Number of column chunks * @param[out] error_code Error code for kernel failures * @param[in] stream CUDA stream to use */ void DecodePageHeaders(ColumnChunkDesc* chunks, + chunk_page_info* chunk_pages, int32_t num_chunks, kernel_error::pointer error_code, rmm::cuda_stream_view stream); diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index c1082c0305a..24d46d91dbb 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,26 +29,28 @@ namespace cudf::io::parquet::detail { void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) { - auto& chunks = _pass_itm_data->chunks; - auto& pages = _pass_itm_data->pages_info; - auto& page_nesting = _pass_itm_data->page_nesting_info; - auto& page_nesting_decode = _pass_itm_data->page_nesting_decode_info; - auto const level_type_size = _pass_itm_data->level_type_size; + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; + + auto& page_nesting = subpass.page_nesting_info; + auto& page_nesting_decode = subpass.page_nesting_decode_info; + + auto const level_type_size = pass.level_type_size; // temporary space for DELTA_BYTE_ARRAY decoding. this only needs to live until // gpu::DecodeDeltaByteArray returns. rmm::device_uvector delta_temp_buf(0, _stream); // Should not reach here if there is no page data. 
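
The reworked gpuDecodePageHeaders above no longer pulls page storage out of ColumnChunkDesc; a separate chunk_page_info array now hands each chunk a pointer to its first PageInfo inside one flat, chunk-sorted page array ("List of pages associated with the chunks, in chunk-sorted order"). A minimal host-side sketch of that indirection, using hypothetical stand-in types rather than the real cudf structs, might look like this:

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-ins for PageInfo / chunk_page_info, only to show the per-chunk indirection.
struct page_info { int chunk_idx; };
struct chunk_page_index { page_info* pages; };  // analogous to chunk_page_info

int main() {
  // Three chunks with 2, 1 and 3 pages, already stored in chunk-sorted order.
  std::vector<page_info> flat{{0}, {0}, {1}, {2}, {2}, {2}};
  std::vector<int> pages_per_chunk{2, 1, 3};

  std::vector<chunk_page_index> index(pages_per_chunk.size());
  std::size_t offset = 0;
  for (std::size_t c = 0; c < pages_per_chunk.size(); ++c) {
    index[c].pages = flat.data() + offset;  // first page belonging to chunk c
    offset += static_cast<std::size_t>(pages_per_chunk[c]);
  }

  for (std::size_t c = 0; c < index.size(); ++c) {
    std::printf("chunk %zu: first page at flat index %td\n", c, index[c].pages - flat.data());
  }
  return 0;
}

Decoupling the page array from ColumnChunkDesc is presumably what lets the later subpass logic hand kernels only the pages selected for the current batch rather than everything a chunk owns.
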
- CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); + CUDF_EXPECTS(subpass.pages.size() > 0, "There are no pages to decode"); size_t const sum_max_depths = std::accumulate( - chunks.begin(), chunks.end(), 0, [&](size_t cursum, ColumnChunkDesc const& chunk) { + pass.chunks.begin(), pass.chunks.end(), 0, [&](size_t cursum, ColumnChunkDesc const& chunk) { return cursum + _metadata->get_output_nesting_depth(chunk.src_col_schema); }); // figure out which kernels to run - auto const kernel_mask = GetAggregatedDecodeKernelMask(pages, _stream); + auto const kernel_mask = GetAggregatedDecodeKernelMask(subpass.pages, _stream); // Check to see if there are any string columns present. If so, then we need to get size info // for each string page. This size info will be used to pre-allocate memory for the column, @@ -59,8 +61,14 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) auto const has_strings = (kernel_mask & STRINGS_MASK) != 0; std::vector col_sizes(_input_columns.size(), 0L); if (has_strings) { - ComputePageStringSizes( - pages, chunks, delta_temp_buf, skip_rows, num_rows, level_type_size, kernel_mask, _stream); + ComputePageStringSizes(subpass.pages, + pass.chunks, + delta_temp_buf, + skip_rows, + num_rows, + level_type_size, + kernel_mask, + _stream); col_sizes = calculate_page_string_offsets(); @@ -83,26 +91,26 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) cudf::detail::hostdevice_vector(has_strings ? sum_max_depths : 0, _stream); // Update chunks with pointers to column data. - for (size_t c = 0, page_count = 0, chunk_off = 0; c < chunks.size(); c++) { - input_column_info const& input_col = _input_columns[chunks[c].src_col_index]; - CUDF_EXPECTS(input_col.schema_idx == chunks[c].src_col_schema, + for (size_t c = 0, chunk_off = 0; c < pass.chunks.size(); c++) { + input_column_info const& input_col = _input_columns[pass.chunks[c].src_col_index]; + CUDF_EXPECTS(input_col.schema_idx == pass.chunks[c].src_col_schema, "Column/page schema index mismatch"); - size_t max_depth = _metadata->get_output_nesting_depth(chunks[c].src_col_schema); + size_t max_depth = _metadata->get_output_nesting_depth(pass.chunks[c].src_col_schema); chunk_offsets.push_back(chunk_off); // get a slice of size `nesting depth` from `chunk_nested_valids` to store an array of pointers // to validity data - auto valids = chunk_nested_valids.host_ptr(chunk_off); - chunks[c].valid_map_base = chunk_nested_valids.device_ptr(chunk_off); + auto valids = chunk_nested_valids.host_ptr(chunk_off); + pass.chunks[c].valid_map_base = chunk_nested_valids.device_ptr(chunk_off); // get a slice of size `nesting depth` from `chunk_nested_data` to store an array of pointers to // out data - auto data = chunk_nested_data.host_ptr(chunk_off); - chunks[c].column_data_base = chunk_nested_data.device_ptr(chunk_off); + auto data = chunk_nested_data.host_ptr(chunk_off); + pass.chunks[c].column_data_base = chunk_nested_data.device_ptr(chunk_off); auto str_data = has_strings ? chunk_nested_str_data.host_ptr(chunk_off) : nullptr; - chunks[c].column_string_base = + pass.chunks[c].column_string_base = has_strings ? 
chunk_nested_str_data.device_ptr(chunk_off) : nullptr; chunk_off += max_depth; @@ -148,8 +156,8 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) valids[idx] = out_buf.null_mask(); data[idx] = out_buf.data(); // only do string buffer for leaf - if (out_buf.string_size() == 0 && col_sizes[chunks[c].src_col_index] > 0) { - out_buf.create_string_data(col_sizes[chunks[c].src_col_index], _stream); + if (out_buf.string_size() == 0 && col_sizes[pass.chunks[c].src_col_index] > 0) { + out_buf.create_string_data(col_sizes[pass.chunks[c].src_col_index], _stream); } if (has_strings) { str_data[idx] = out_buf.string_data(); } out_buf.user_data |= @@ -159,12 +167,9 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) data[idx] = nullptr; } } - - // column_data_base will always point to leaf data, even for nested types. - page_count += chunks[c].max_num_pages; } - chunks.host_to_device_async(_stream); + pass.chunks.host_to_device_async(_stream); chunk_nested_valids.host_to_device_async(_stream); chunk_nested_data.host_to_device_async(_stream); if (has_strings) { chunk_nested_str_data.host_to_device_async(_stream); } @@ -179,44 +184,71 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) // launch string decoder int s_idx = 0; if (BitAnd(kernel_mask, decode_kernel_mask::STRING) != 0) { - DecodeStringPageData( - pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]); + DecodeStringPageData(subpass.pages, + pass.chunks, + num_rows, + skip_rows, + level_type_size, + error_code.data(), + streams[s_idx++]); } // launch delta byte array decoder if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_BYTE_ARRAY) != 0) { - DecodeDeltaByteArray( - pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]); + DecodeDeltaByteArray(subpass.pages, + pass.chunks, + num_rows, + skip_rows, + level_type_size, + error_code.data(), + streams[s_idx++]); } // launch delta length byte array decoder if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_LENGTH_BA) != 0) { - DecodeDeltaLengthByteArray( - pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]); + DecodeDeltaLengthByteArray(subpass.pages, + pass.chunks, + num_rows, + skip_rows, + level_type_size, + error_code.data(), + streams[s_idx++]); } // launch delta binary decoder if (BitAnd(kernel_mask, decode_kernel_mask::DELTA_BINARY) != 0) { - DecodeDeltaBinary( - pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]); + DecodeDeltaBinary(subpass.pages, + pass.chunks, + num_rows, + skip_rows, + level_type_size, + error_code.data(), + streams[s_idx++]); } // launch the catch-all page decoder if (BitAnd(kernel_mask, decode_kernel_mask::GENERAL) != 0) { - DecodePageData( - pages, chunks, num_rows, skip_rows, level_type_size, error_code.data(), streams[s_idx++]); + DecodePageData(subpass.pages, + pass.chunks, + num_rows, + skip_rows, + level_type_size, + error_code.data(), + streams[s_idx++]); } // synchronize the streams cudf::detail::join_streams(streams, _stream); - pages.device_to_host_async(_stream); + subpass.pages.device_to_host_async(_stream); page_nesting.device_to_host_async(_stream); page_nesting_decode.device_to_host_async(_stream); if (error_code.value() != 0) { CUDF_FAIL("Parquet data decode failed with code(s) " + error_code.str()); } + // error_code.value() is synchronous; explicitly sync here for better visibility + _stream.synchronize(); // for list columns, add the 
final offset to every offset buffer. // TODO : make this happen in more efficiently. Maybe use thrust::for_each @@ -259,10 +291,10 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) } // update null counts in the final column buffers - for (size_t idx = 0; idx < pages.size(); idx++) { - PageInfo* pi = &pages[idx]; + for (size_t idx = 0; idx < subpass.pages.size(); idx++) { + PageInfo* pi = &subpass.pages[idx]; if (pi->flags & PAGEINFO_FLAGS_DICTIONARY) { continue; } - ColumnChunkDesc* col = &chunks[pi->chunk_idx]; + ColumnChunkDesc* col = &pass.chunks[pi->chunk_idx]; input_column_info const& input_col = _input_columns[col->src_col_index]; int index = pi->nesting_decode - page_nesting_decode.device_ptr(); @@ -344,60 +376,16 @@ void reader::impl::prepare_data(int64_t skip_rows, { // if we have not preprocessed at the whole-file level, do that now if (!_file_preprocessed) { - // if filter is not empty, then create output types as vector and pass for filtering. - std::vector output_types; - if (filter.has_value()) { - std::transform(_output_buffers.cbegin(), - _output_buffers.cend(), - std::back_inserter(output_types), - [](auto const& col) { return col.type; }); - } - std::tie( - _file_itm_data.global_skip_rows, _file_itm_data.global_num_rows, _file_itm_data.row_groups) = - _metadata->select_row_groups( - row_group_indices, skip_rows, num_rows, output_types, filter, _stream); - - if (_file_itm_data.global_num_rows > 0 && not _file_itm_data.row_groups.empty() && - not _input_columns.empty()) { - // fills in chunk information without physically loading or decompressing - // the associated data - create_global_chunk_info(); - - // compute schedule of input reads. Each rowgroup contains 1 chunk per column. For now - // we will read an entire row group at a time. However, it is possible to do - // sub-rowgroup reads if we made some estimates on individual chunk sizes (tricky) and - // changed the high level structure such that we weren't always reading an entire table's - // worth of columns at once. - compute_input_passes(); - } - - _file_preprocessed = true; + // setup file level information + // - read row group information + // - setup information on (parquet) chunks + // - compute schedule of input passes + preprocess_file(skip_rows, num_rows, row_group_indices, filter); } - // if we have to start a new pass, do that now - if (!_pass_preprocessed) { - auto const num_passes = _file_itm_data.input_pass_row_group_offsets.size() - 1; - - // always create the pass struct, even if we end up with no passes. - // this will also cause the previous pass information to be deleted - _pass_itm_data = std::make_unique(); - - if (_file_itm_data.global_num_rows > 0 && not _file_itm_data.row_groups.empty() && - not _input_columns.empty() && _current_input_pass < num_passes) { - // setup the pass_intermediate_info for this pass. 
- setup_next_pass(); - - load_and_decompress_data(); - preprocess_pages(uses_custom_row_bounds, _output_chunk_read_limit); - - if (_output_chunk_read_limit == 0) { // read the whole file at once - CUDF_EXPECTS(_pass_itm_data->output_chunk_read_info.size() == 1, - "Reading the whole file should yield only one chunk."); - } - } - - _pass_preprocessed = true; - } + // handle any chunking work (ratcheting through the subpasses and chunks within + // our current pass) + if (_file_itm_data.num_passes() > 0) { handle_chunking(uses_custom_row_bounds); } } void reader::impl::populate_metadata(table_metadata& out_metadata) @@ -427,12 +415,12 @@ table_with_metadata reader::impl::read_chunk_internal( auto out_columns = std::vector>{}; out_columns.reserve(_output_buffers.size()); - if (!has_next() || _pass_itm_data->output_chunk_read_info.empty()) { - return finalize_output(out_metadata, out_columns, filter); - } + // no work to do (this can happen on the first pass if we have no rows to read) + if (!has_more_work()) { return finalize_output(out_metadata, out_columns, filter); } - auto const& read_info = - _pass_itm_data->output_chunk_read_info[_pass_itm_data->current_output_chunk]; + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; + auto const& read_info = subpass.output_chunk_read_info[subpass.current_output_chunk]; // Allocate memory buffers for the output columns. allocate_columns(read_info.skip_rows, read_info.num_rows, uses_custom_row_bounds); @@ -485,15 +473,12 @@ table_with_metadata reader::impl::finalize_output( _output_metadata = std::make_unique(out_metadata); } - // advance chunks/passes as necessary - _pass_itm_data->current_output_chunk++; - _chunk_count++; - if (_pass_itm_data->current_output_chunk >= _pass_itm_data->output_chunk_read_info.size()) { - _pass_itm_data->current_output_chunk = 0; - _pass_itm_data->output_chunk_read_info.clear(); - - _current_input_pass++; - _pass_preprocessed = false; + // advance output chunk/subpass/pass info + if (_file_itm_data.num_passes() > 0) { + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; + subpass.current_output_chunk++; + _file_itm_data._output_chunk_count++; } if (filter.has_value()) { @@ -530,7 +515,7 @@ table_with_metadata reader::impl::read_chunk() { // Reset the output buffers to their original states (right after reader construction). // Don't need to do it if we read the file all at once. - if (_chunk_count > 0) { + if (_file_itm_data._output_chunk_count > 0) { _output_buffers.resize(0); for (auto const& buff : _output_buffers_template) { _output_buffers.emplace_back(cudf::io::detail::inline_column_buffer::empty_like(buff)); @@ -553,10 +538,9 @@ bool reader::impl::has_next() {} /*row_group_indices, empty means read all row groups*/, std::nullopt /*filter*/); - size_t const num_input_passes = std::max( - int64_t{0}, static_cast(_file_itm_data.input_pass_row_group_offsets.size()) - 1); - return (_pass_itm_data->current_output_chunk < _pass_itm_data->output_chunk_read_info.size()) || - (_current_input_pass < num_input_passes); + // current_input_pass will only be incremented to be == num_passes after + // the last chunk in the last subpass in the last pass has been returned + return has_more_work(); } namespace { diff --git a/cpp/src/io/parquet/reader_impl.hpp b/cpp/src/io/parquet/reader_impl.hpp index cea4ba35606..67c56c9c2d7 100644 --- a/cpp/src/io/parquet/reader_impl.hpp +++ b/cpp/src/io/parquet/reader_impl.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. 
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -120,6 +120,8 @@ class reader::impl { */ table_with_metadata read_chunk(); + // top level functions involved with ratcheting through the passes, subpasses + // and output chunks of the read process private: /** * @brief Perform the necessary data preprocessing for parsing file later on. @@ -138,20 +140,101 @@ class reader::impl { std::optional> filter); /** - * @brief Create chunk information and start file reads + * @brief Preprocess step for the entire file. + * + * Only ever called once. This function reads in rowgroup and associated chunk + * information and computes the schedule of top level passes (see `pass_intermediate_data`). + * + * @param skip_rows The number of rows to skip in the requested set of rowgroups to be read + * @param num_rows The total number of rows to read out of the selected rowgroups + * @param row_group_indices Lists of row groups to read, one per source + * @param filter Optional AST expression to filter output rows + */ + void preprocess_file(int64_t skip_rows, + std::optional const& num_rows, + host_span const> row_group_indices, + std::optional> filter); + + /** + * @brief Ratchet the pass/subpass/chunk process forward. + * + * @param uses_custom_row_bounds Whether or not num_rows and skip_rows represents user-specified + * bounds + */ + void handle_chunking(bool uses_custom_row_bounds); + + /** + * @brief Setup step for the next input read pass. + * + * A 'pass' is defined as a subset of row groups read out of the globally + * requested set of all row groups. + * + * @param uses_custom_row_bounds Whether or not num_rows and skip_rows represents user-specific + * bounds + */ + void setup_next_pass(bool uses_custom_row_bounds); + + /** + * @brief Setup step for the next decompression subpass. + * + * @param uses_custom_row_bounds Whether or not num_rows and skip_rows represents user-specific + * bounds + * + * A 'subpass' is defined as a subset of pages within a pass that are + * decompressed and decoded as a batch. Subpasses may be further subdivided + * into output chunks. + */ + void setup_next_subpass(bool uses_custom_row_bounds); + + /** + * @brief Read a chunk of data and return an output table. + * + * This function is called internally and expects all preprocessing steps have already been done. + * + * @param uses_custom_row_bounds Whether or not num_rows and skip_rows represents user-specific + * bounds + * @param filter Optional AST expression to filter output rows + * @return The output table along with columns' metadata + */ + table_with_metadata read_chunk_internal( + bool uses_custom_row_bounds, + std::optional> filter); + + // utility functions + private: + /** + * @brief Read the set of column chunks to be processed for this pass. + * + * Does not decompress the chunk data. * * @return pair of boolean indicating if compressed chunks were found and a vector of futures for * read completion */ - std::pair>> read_and_decompress_column_chunks(); + std::pair>> read_column_chunks(); /** - * @brief Load and decompress the input file(s) into memory. + * @brief Read compressed data and page information for the current pass. */ - void load_and_decompress_data(); + void read_compressed_data(); /** - * @brief Perform some preprocessing for page data and also compute the split locations + * @brief Build string dictionary indices for a pass. 
+ * + */ + void build_string_dict_indices(); + + /** + * @brief For list columns, generate estimated row counts for pages in the current pass. + * + * The row counts in the pages that come out of the file only reflect the number of values in + * all of the rows in the page, not the number of rows themselves. In order to do subpass reading + * more accurately, we would like to have a more accurate guess of the real number of rows per + * page. + */ + void generate_list_column_row_count_estimates(); + + /** + * @brief Perform some preprocessing for subpass page data and also compute the split locations * {skip_rows, num_rows} for chunked reading. * * There are several pieces of information we can't compute directly from row counts in @@ -166,7 +249,7 @@ class reader::impl { * @param chunk_read_limit Limit on total number of bytes to be returned per read, * or `0` if there is no limit */ - void preprocess_pages(bool uses_custom_row_bounds, size_t chunk_read_limit); + void preprocess_subpass_pages(bool uses_custom_row_bounds, size_t chunk_read_limit); /** * @brief Allocate nesting information storage for all pages and set pointers to it. @@ -194,20 +277,6 @@ class reader::impl { */ void populate_metadata(table_metadata& out_metadata); - /** - * @brief Read a chunk of data and return an output table. - * - * This function is called internally and expects all preprocessing steps have already been done. - * - * @param uses_custom_row_bounds Whether or not num_rows and skip_rows represents user-specific - * bounds - * @param filter Optional AST expression to filter output rows - * @return The output table along with columns' metadata - */ - table_with_metadata read_chunk_internal( - bool uses_custom_row_bounds, - std::optional> filter); - /** * @brief Finalize the output table by adding empty columns for the non-selected columns in * schema. @@ -260,17 +329,18 @@ class reader::impl { */ void compute_input_passes(); - /** - * @brief Close out the existing pass (if any) and prepare for the next pass. - */ - void setup_next_pass(); - /** * @brief Given a set of pages that have had their sizes computed by nesting level and * a limit on total read size, generate a set of {skip_rows, num_rows} pairs representing * a set of reads that will generate output columns of total size <= `chunk_read_limit` bytes. */ - void compute_splits_for_pass(); + void compute_output_chunks_for_subpass(); + + [[nodiscard]] bool has_more_work() const + { + return _file_itm_data.num_passes() > 0 && + _file_itm_data._current_input_pass < _file_itm_data.num_passes(); + } private: rmm::cuda_stream_view _stream; @@ -311,13 +381,9 @@ class reader::impl { bool _file_preprocessed{false}; std::unique_ptr _pass_itm_data; - bool _pass_preprocessed{false}; std::size_t _output_chunk_read_limit{0}; // output chunk size limit in bytes std::size_t _input_pass_read_limit{0}; // input pass memory usage limit in bytes - - std::size_t _current_input_pass{0}; // current input pass index - std::size_t _chunk_count{0}; // how many output chunks we have produced }; } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_chunking.cu b/cpp/src/io/parquet/reader_impl_chunking.cu index 213fc380a34..1bfe5745b9e 100644 --- a/cpp/src/io/parquet/reader_impl_chunking.cu +++ b/cpp/src/io/parquet/reader_impl_chunking.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,11 @@ #include #include +#include +#include + +#include #include #include @@ -27,37 +31,61 @@ #include #include #include +#include #include +#include + +#include + +#include namespace cudf::io::parquet::detail { namespace { -struct cumulative_row_info { - size_t row_count; // cumulative row count +struct split_info { + row_range rows; + int64_t split_pos; +}; + +struct cumulative_page_info { + size_t row_index; // row index size_t size_bytes; // cumulative size in bytes int key; // schema index }; +// the minimum amount of memory we can safely expect to be enough to +// do a subpass decode. if the difference between the user specified limit and +// the actual memory used for compressed/temp data is > than this value, we will still use +// at least this many additional bytes. +// Example: +// - user has specified 1 GB limit +// - we have read in 900 MB of compressed data +// - that leaves us 100 MB of space for decompression batches +// - to keep the gpu busy, we really don't want to do less than 200 MB at a time so we're just going +// to use 200 MB of space +// even if that goes past the user-specified limit. +constexpr size_t minimum_subpass_expected_size = 200 * 1024 * 1024; + +// percentage of the total available input read limit that should be reserved for compressed +// data vs uncompressed data. +constexpr float input_limit_compression_reserve = 0.3f; + #if defined(CHUNKING_DEBUG) -void print_cumulative_page_info(cudf::detail::hostdevice_vector& pages, - rmm::device_uvector const& page_index, - rmm::device_uvector const& c_info, +void print_cumulative_page_info(device_span d_pages, + device_span d_chunks, + device_span d_c_info, rmm::cuda_stream_view stream) { - pages.device_to_host_sync(stream); + std::vector pages = cudf::detail::make_std_vector_sync(d_pages, stream); + std::vector chunks = cudf::detail::make_std_vector_sync(d_chunks, stream); + std::vector c_info = cudf::detail::make_std_vector_sync(d_c_info, stream); printf("------------\nCumulative sizes by page\n"); std::vector schemas(pages.size()); - std::vector h_page_index(pages.size()); - CUDF_CUDA_TRY(cudaMemcpy( - h_page_index.data(), page_index.data(), sizeof(int) * pages.size(), cudaMemcpyDefault)); - std::vector h_cinfo(pages.size()); - CUDF_CUDA_TRY(cudaMemcpy( - h_cinfo.data(), c_info.data(), sizeof(cumulative_row_info) * pages.size(), cudaMemcpyDefault)); auto schema_iter = cudf::detail::make_counting_transform_iterator( - 0, [&](size_type i) { return pages[h_page_index[i]].src_col_schema; }); + 0, [&](size_type i) { return pages[i].src_col_schema; }); thrust::copy(thrust::seq, schema_iter, schema_iter + pages.size(), schemas.begin()); auto last = thrust::unique(thrust::seq, schemas.begin(), schemas.end()); schemas.resize(last - schemas.begin()); @@ -66,38 +94,44 @@ void print_cumulative_page_info(cudf::detail::hostdevice_vector& pages for (size_t idx = 0; idx < schemas.size(); idx++) { printf("Schema %d\n", schemas[idx]); for (size_t pidx = 0; pidx < pages.size(); pidx++) { - auto const& page = pages[h_page_index[pidx]]; + auto const& page = pages[pidx]; if (page.flags & PAGEINFO_FLAGS_DICTIONARY || page.src_col_schema != schemas[idx]) { continue; } - printf("\tP: {%lu, %lu}\n", h_cinfo[pidx].row_count, h_cinfo[pidx].size_bytes); + bool const is_list = chunks[page.chunk_idx].max_level[level_type::REPETITION] > 0; + printf("\tP %s: {%lu, %lu, %lu}\n", + is_list ? 
"(L)" : "", + pidx, + c_info[pidx].row_index, + c_info[pidx].size_bytes); } } } -void print_cumulative_row_info(host_span sizes, +void print_cumulative_row_info(host_span sizes, std::string const& label, - std::optional> splits = std::nullopt) + std::optional> splits = std::nullopt) { if (splits.has_value()) { - printf("------------\nSplits\n"); + printf("------------\nSplits (skip_rows, num_rows)\n"); for (size_t idx = 0; idx < splits->size(); idx++) { printf("{%lu, %lu}\n", splits.value()[idx].skip_rows, splits.value()[idx].num_rows); } } - printf("------------\nCumulative sizes %s\n", label.c_str()); + printf("------------\nCumulative sizes %s (index, row_index, size_bytes, page_key)\n", + label.c_str()); for (size_t idx = 0; idx < sizes.size(); idx++) { - printf("{%lu, %lu, %d}", sizes[idx].row_count, sizes[idx].size_bytes, sizes[idx].key); + printf("{%lu, %lu, %lu, %d}", idx, sizes[idx].row_index, sizes[idx].size_bytes, sizes[idx].key); if (splits.has_value()) { // if we have a split at this row count and this is the last instance of this row count - auto start = thrust::make_transform_iterator( - splits->begin(), [](chunk_read_info const& i) { return i.skip_rows; }); + auto start = thrust::make_transform_iterator(splits->begin(), + [](row_range const& i) { return i.skip_rows; }); auto end = start + splits->size(); - auto split = std::find(start, end, sizes[idx].row_count); + auto split = std::find(start, end, sizes[idx].row_index); auto const split_index = [&]() -> int { if (split != end && - ((idx == sizes.size() - 1) || (sizes[idx + 1].row_count > sizes[idx].row_count))) { + ((idx == sizes.size() - 1) || (sizes[idx + 1].row_index > sizes[idx].row_index))) { return static_cast(std::distance(start, split)); } return idx == 0 ? 0 : -1; @@ -114,13 +148,13 @@ void print_cumulative_row_info(host_span sizes, #endif // CHUNKING_DEBUG /** - * @brief Functor which reduces two cumulative_row_info structs of the same key. + * @brief Functor which reduces two cumulative_page_info structs of the same key. */ -struct cumulative_row_sum { - cumulative_row_info operator() - __device__(cumulative_row_info const& a, cumulative_row_info const& b) const +struct cumulative_page_sum { + cumulative_page_info operator() + __device__(cumulative_page_info const& a, cumulative_page_info const& b) const { - return cumulative_row_info{a.row_count + b.row_count, a.size_bytes + b.size_bytes, a.key}; + return cumulative_page_info{0, a.size_bytes + b.size_bytes, a.key}; } }; @@ -178,32 +212,57 @@ __device__ size_t row_size_functor::operator()(size_t num_rows, boo * * Sums across all nesting levels. 
*/ -struct get_cumulative_row_info { - PageInfo const* const pages; - - __device__ cumulative_row_info operator()(size_type index) +struct get_page_output_size { + __device__ cumulative_page_info operator()(PageInfo const& page) const { - auto const& page = pages[index]; if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { - return cumulative_row_info{0, 0, page.src_col_schema}; + return cumulative_page_info{0, 0, page.src_col_schema}; } // total nested size, not counting string data - auto iter = - cudf::detail::make_counting_transform_iterator(0, [page, index] __device__(size_type i) { + auto iter = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([page] __device__(size_type i) { auto const& pni = page.nesting[i]; return cudf::type_dispatcher( data_type{pni.type}, row_size_functor{}, pni.size, pni.nullable); - }); - - size_t const row_count = static_cast(page.nesting[0].size); + })); return { - row_count, + 0, thrust::reduce(thrust::seq, iter, iter + page.num_output_nesting_levels) + page.str_bytes, page.src_col_schema}; } }; +/** + * @brief Functor which sets the (uncompressed) size of a page. + */ +struct get_page_input_size { + __device__ cumulative_page_info operator()(PageInfo const& page) const + { + // we treat dictionary page sizes as 0 for subpasses because we have already paid the price for + // them at the pass level. + if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { return {0, 0, page.src_col_schema}; } + return {0, static_cast(page.uncompressed_page_size), page.src_col_schema}; + } +}; + +/** + * @brief Functor which sets the absolute row index of a page in a cumulative_page_info struct + */ +struct set_row_index { + device_span chunks; + device_span pages; + device_span c_info; + + __device__ void operator()(size_t i) + { + auto const& page = pages[i]; + auto const& chunk = chunks[page.chunk_idx]; + size_t const page_start_row = chunk.start_row + page.chunk_row + page.num_rows; + c_info[i].row_index = page_start_row; + } +}; + /** * @brief Functor which computes the effective size of all input columns by page. * @@ -219,12 +278,12 @@ struct get_cumulative_row_info { * at that point. So we have to proceed as if we are taking the bytes from all 200 rows of that * page. Essentially, a conservative over-estimate of the real size. */ -struct row_total_size { - cumulative_row_info const* c_info; +struct page_total_size { + cumulative_page_info const* c_info; size_type const* key_offsets; size_t num_keys; - __device__ cumulative_row_info operator()(cumulative_row_info const& i) + __device__ cumulative_page_info operator()(cumulative_page_info const& i) const { // sum sizes for each input column at this row size_t sum = 0; @@ -232,71 +291,81 @@ struct row_total_size { auto const start = key_offsets[idx]; auto const end = key_offsets[idx + 1]; auto iter = cudf::detail::make_counting_transform_iterator( - 0, [&] __device__(size_type i) { return c_info[i].row_count; }); + 0, cuda::proclaim_return_type([&] __device__(size_type i) { + return c_info[i].row_index; + })); auto const page_index = - thrust::lower_bound(thrust::seq, iter + start, iter + end, i.row_count) - iter; + thrust::lower_bound(thrust::seq, iter + start, iter + end, i.row_index) - iter; sum += c_info[page_index].size_bytes; } - return {i.row_count, sum, i.key}; + return {i.row_index, sum, i.key}; } }; /** - * @brief Given a vector of cumulative {row_count, byte_size} pairs and a chunk read - * limit, determine the set of splits. 
+ * @brief Functor which returns the compressed data size for a chunk + */ +struct get_chunk_compressed_size { + __device__ size_t operator()(ColumnChunkDesc const& chunk) const { return chunk.compressed_size; } +}; + +/** + * @brief Find the first entry in the aggreggated_info that corresponds to the specified row * - * @param sizes Vector of cumulative {row_count, byte_size} pairs - * @param num_rows Total number of rows to read - * @param chunk_read_limit Limit on total number of bytes to be returned per read, for all columns */ -std::vector find_splits(std::vector const& sizes, - size_t num_rows, - size_t chunk_read_limit) +size_t find_start_index(cudf::host_span aggregated_info, + size_t start_row) { - // now we have an array of {row_count, real output bytes}. just walk through it and generate - // splits. - // TODO: come up with a clever way to do this entirely in parallel. For now, as long as batch - // sizes are reasonably large, this shouldn't iterate too many times - std::vector splits; - { - size_t cur_pos = 0; - size_t cur_cumulative_size = 0; - size_t cur_row_count = 0; - auto start = thrust::make_transform_iterator(sizes.begin(), [&](cumulative_row_info const& i) { - return i.size_bytes - cur_cumulative_size; - }); - auto end = start + sizes.size(); - while (cur_row_count < num_rows) { - int64_t split_pos = - thrust::lower_bound(thrust::seq, start + cur_pos, end, chunk_read_limit) - start; - - // if we're past the end, or if the returned bucket is > than the chunk_read_limit, move back - // one. - if (static_cast(split_pos) >= sizes.size() || - (sizes[split_pos].size_bytes - cur_cumulative_size > chunk_read_limit)) { - split_pos--; - } + auto start = thrust::make_transform_iterator( + aggregated_info.begin(), [&](cumulative_page_info const& i) { return i.row_index; }); + auto start_index = + thrust::lower_bound(thrust::host, start, start + aggregated_info.size(), start_row) - start; + + // cumulative_page_info.row_index is the -end- of the rows of a given page. so move forward until + // we find the next group of pages + while (start_index < (static_cast(aggregated_info.size()) - 1) && + (start_index < 0 || aggregated_info[start_index].row_index == start_row)) { + start_index++; + } - // best-try. if we can't find something that'll fit, we have to go bigger. we're doing this in - // a loop because all of the cumulative sizes for all the pages are sorted into one big list. - // so if we had two columns, both of which had an entry {1000, 10000}, that entry would be in - // the list twice. so we have to iterate until we skip past all of them. The idea is that we - // either do this, or we have to call unique() on the input first. 
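
Both the removed find_splits loop and its replacements, find_start_index plus find_next_split, walk one sorted list of cumulative {row index, size} entries and cut it into ranges that stay under a byte limit, always stepping past entries that share the same end row so that duplicate row indices coming from different columns land in the same range. A simplified host-only sketch of that walk, with standalone types and a linear scan in place of thrust::lower_bound, might be:

#include <cstddef>
#include <cstdio>
#include <vector>

// Simplified stand-in for cumulative_page_info: row_index is the *end* row of a page,
// size_bytes the cumulative output size once that page is included.
struct cum_info { std::size_t row_index; std::size_t size_bytes; };

// Pick the last entry whose incremental size still fits under the limit, then move
// past any entries that end at the same row (they must be taken together).
std::size_t next_split(std::size_t pos, std::size_t cur_row, std::size_t cur_size,
                       std::vector<cum_info> const& sizes, std::size_t limit)
{
  while (pos + 1 < sizes.size() && sizes[pos + 1].size_bytes - cur_size <= limit) { ++pos; }
  while (pos + 1 < sizes.size() && sizes[pos].row_index == cur_row) { ++pos; }
  return pos;
}

int main() {
  // Two columns: pages ending at rows {100, 200} and {200, 400}, worst-case adjusted sizes.
  std::vector<cum_info> sizes{{100, 40}, {200, 90}, {200, 120}, {400, 200}};
  std::size_t pos = 0, row = 0, size = 0;
  while (row < sizes.back().row_index) {
    pos = next_split(pos, row, size, sizes, /*limit=*/100);
    std::printf("split: rows [%zu, %zu), ~%zu bytes\n",
                row, sizes[pos].row_index, sizes[pos].size_bytes - size);
    row  = sizes[pos].row_index;
    size = sizes[pos].size_bytes;
  }
  return 0;
}

With the 100-byte limit this prints two splits; the second overshoots the limit on purpose because the two pages ending at row 200 cannot be separated, in the same spirit as the removed comment's note about having to go bigger and skip past duplicate row counts.
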
- while (split_pos < (static_cast(sizes.size()) - 1) && - (split_pos < 0 || sizes[split_pos].row_count == cur_row_count)) { - split_pos++; - } + return start_index; +} - auto const start_row = cur_row_count; - cur_row_count = sizes[split_pos].row_count; - splits.push_back(chunk_read_info{start_row, cur_row_count - start_row}); - cur_pos = split_pos; - cur_cumulative_size = sizes[split_pos].size_bytes; - } +/** + * @brief Given a current position and row index, find the next split based on the + * specified size limit + * + * @returns The inclusive index within `sizes` where the next split should happen + * + */ +int64_t find_next_split(int64_t cur_pos, + size_t cur_row_index, + size_t cur_cumulative_size, + cudf::host_span sizes, + size_t size_limit) +{ + auto const start = thrust::make_transform_iterator( + sizes.begin(), + [&](cumulative_page_info const& i) { return i.size_bytes - cur_cumulative_size; }); + auto const end = start + sizes.size(); + + int64_t split_pos = thrust::lower_bound(thrust::seq, start + cur_pos, end, size_limit) - start; + + // if we're past the end, or if the returned bucket is > than the chunk_read_limit, move back + // one. + if (static_cast(split_pos) >= sizes.size() || + (sizes[split_pos].size_bytes - cur_cumulative_size > size_limit)) { + split_pos--; } - // print_cumulative_row_info(sizes, "adjusted", splits); - return splits; + // cumulative_page_info.row_index is the -end- of the rows of a given page. so move forward until + // we find the next group of pages + while (split_pos < (static_cast(sizes.size()) - 1) && + (split_pos < 0 || sizes[split_pos].row_index == cur_row_index)) { + split_pos++; + } + + return split_pos; } /** @@ -340,15 +409,969 @@ template return static_cast(CompactProtocolReader::NumRequiredBits(max_level)); } -struct row_count_compare { - __device__ bool operator()(cumulative_row_info const& a, cumulative_row_info const& b) +struct row_count_less { + __device__ bool operator()(cumulative_page_info const& a, cumulative_page_info const& b) const + { + return a.row_index < b.row_index; + } +}; + +/** + * @brief return compressed and total size of the data in a row group + * + */ +std::pair get_row_group_size(RowGroup const& rg) +{ + auto compressed_size_iter = thrust::make_transform_iterator( + rg.columns.begin(), [](ColumnChunk const& c) { return c.meta_data.total_compressed_size; }); + + // the trick is that total temp space needed is tricky to know + auto const compressed_size = + std::reduce(compressed_size_iter, compressed_size_iter + rg.columns.size()); + auto const total_size = compressed_size + rg.total_byte_size; + return {compressed_size, total_size}; +} + +/** + * @brief For a set of cumulative_page_info data, adjust the size_bytes field + * such that it reflects the worst case for all pages that span the same rows. + * + * By doing this, we can now look at row X and know the total + * byte cost for all pages that span row X, not just the cost up to row X itself. + * + * This function is asynchronous. Call stream.synchronize() before using the + * results. + */ +std::pair, rmm::device_uvector> +adjust_cumulative_sizes(device_span c_info, + device_span pages, + rmm::cuda_stream_view stream) +{ + // sort by row count + rmm::device_uvector c_info_sorted = + make_device_uvector_async(c_info, stream, rmm::mr::get_current_device_resource()); + thrust::sort( + rmm::exec_policy_nosync(stream), c_info_sorted.begin(), c_info_sorted.end(), row_count_less{}); + + // page keys grouped by split. 
+ rmm::device_uvector page_keys_by_split{c_info.size(), stream}; + thrust::transform(rmm::exec_policy_nosync(stream), + c_info_sorted.begin(), + c_info_sorted.end(), + page_keys_by_split.begin(), + cuda::proclaim_return_type( + [] __device__(cumulative_page_info const& c) { return c.key; })); + + // generate key offsets (offsets to the start of each partition of keys). worst case is 1 page per + // key + rmm::device_uvector key_offsets(pages.size() + 1, stream); + auto page_keys = make_page_key_iterator(pages); + auto const key_offsets_end = thrust::reduce_by_key(rmm::exec_policy(stream), + page_keys, + page_keys + pages.size(), + thrust::make_constant_iterator(1), + thrust::make_discard_iterator(), + key_offsets.begin()) + .second; + size_t const num_unique_keys = key_offsets_end - key_offsets.begin(); + thrust::exclusive_scan( + rmm::exec_policy_nosync(stream), key_offsets.begin(), key_offsets.end(), key_offsets.begin()); + + // adjust the cumulative info such that for each row count, the size includes any pages that span + // that row count. this is so that if we have this case: + // page row counts + // Column A: 0 <----> 100 <----> 200 + // Column B: 0 <---------------> 200 <--------> 400 + // | + // if we decide to split at row 100, we don't really know the actual amount of bytes in column B + // at that point. So we have to proceed as if we are taking the bytes from all 200 rows of that + // page. + // + rmm::device_uvector aggregated_info(c_info.size(), stream); + thrust::transform(rmm::exec_policy_nosync(stream), + c_info_sorted.begin(), + c_info_sorted.end(), + aggregated_info.begin(), + page_total_size{c_info.data(), key_offsets.data(), num_unique_keys}); + return {std::move(aggregated_info), std::move(page_keys_by_split)}; +} + +struct page_span { + size_t start, end; +}; + +struct get_page_row_index { + device_span c_info; + + __device__ size_t operator()(size_t i) const { return c_info[i].row_index; } +}; + +/** + * @brief Return the span of page indices for a given column index that spans start_row and end_row + * + */ +template +struct get_page_span { + device_span page_offsets; + RowIndexIter page_row_index; + size_t const start_row; + size_t const end_row; + + get_page_span(device_span _page_offsets, + RowIndexIter _page_row_index, + size_t _start_row, + size_t _end_row) + : page_offsets(_page_offsets), + page_row_index(_page_row_index), + start_row(_start_row), + end_row(_end_row) + { + } + + __device__ page_span operator()(size_t column_index) const + { + auto const first_page_index = page_offsets[column_index]; + auto const column_page_start = page_row_index + first_page_index; + auto const column_page_end = page_row_index + page_offsets[column_index + 1]; + auto const num_pages = column_page_end - column_page_start; + + auto start_page = + (thrust::lower_bound(thrust::seq, column_page_start, column_page_end, start_row) - + column_page_start) + + first_page_index; + if (page_row_index[start_page] == start_row) { start_page++; } + + auto end_page = (thrust::lower_bound(thrust::seq, column_page_start, column_page_end, end_row) - + column_page_start) + + first_page_index; + if (end_page < (first_page_index + num_pages)) { end_page++; } + + return {static_cast(start_page), static_cast(end_page)}; + } +}; + +struct get_span_size { + __device__ size_t operator()(page_span const& s) const { return s.end - s.start; } +}; + +/** + * @brief Computes the next subpass within the current pass. 
+ * + * A subpass is a subset of the pages within the parent pass that is decompressed + * as a batch and decoded. Subpasses are the level at which we control memory intermediate + * memory usage. A pass consists of >= 1 subpass. We cannot compute all subpasses in one + * shot because we do not know how many rows we actually have in the pages of list columns. + * So we have to make an educated guess that fits within the memory limits, and then adjust + * for subsequent subpasses when we see how many rows we actually receive. + * + * @param c_info The cumulative page size information (row count and byte size) per column + * @param pages All of the pages in the pass + * @param page_offsets Offsets into the pages array representing the first page for each column + * @param start_row The row to start the subpass at + * @param size_limit The size limit in bytes of the subpass + * @param num_columns The number of columns + * @param stream The stream to execute cuda operations on + * @returns A tuple containing a vector of page_span structs indicating the page indices to include + * for each column to be processed, the total number of pages over all columns, and the total + * expected memory usage (including scratch space) + * + */ +std::tuple, size_t, size_t> compute_next_subpass( + device_span c_info, + device_span pages, + device_span page_offsets, + size_t start_row, + size_t size_limit, + size_t num_columns, + rmm::cuda_stream_view stream) +{ + auto [aggregated_info, page_keys_by_split] = adjust_cumulative_sizes(c_info, pages, stream); + + // bring back to the cpu + auto const h_aggregated_info = cudf::detail::make_std_vector_sync(aggregated_info, stream); + // print_cumulative_row_info(h_aggregated_info, "adjusted"); + + // TODO: if the user has explicitly specified skip_rows/num_rows we could be more intelligent + // about skipping subpasses/pages that do not fall within the range of values, but only if the + // data does not contain lists (because our row counts are only estimates in that case) + + // find the next split + auto const start_index = find_start_index(h_aggregated_info, start_row); + auto const cumulative_size = + start_row == 0 || start_index == 0 ? 
0 : h_aggregated_info[start_index - 1].size_bytes; + auto const end_index = + find_next_split(start_index, start_row, cumulative_size, h_aggregated_info, size_limit); + auto const end_row = h_aggregated_info[end_index].row_index; + + // for each column, collect the set of pages that spans start_row / end_row + rmm::device_uvector page_bounds(num_columns, stream); + auto iter = thrust::make_counting_iterator(size_t{0}); + auto page_row_index = + cudf::detail::make_counting_transform_iterator(0, get_page_row_index{c_info}); + thrust::transform(rmm::exec_policy_nosync(stream), + iter, + iter + num_columns, + page_bounds.begin(), + get_page_span{page_offsets, page_row_index, start_row, end_row}); + + // total page count over all columns + auto page_count_iter = thrust::make_transform_iterator(page_bounds.begin(), get_span_size{}); + size_t const total_pages = + thrust::reduce(rmm::exec_policy(stream), page_count_iter, page_count_iter + num_columns); + + return {cudf::detail::make_std_vector_sync(page_bounds, stream), + total_pages, + h_aggregated_info[end_index].size_bytes - cumulative_size}; +} + +std::vector compute_page_splits_by_row(device_span c_info, + device_span pages, + size_t skip_rows, + size_t num_rows, + size_t size_limit, + rmm::cuda_stream_view stream) +{ + auto [aggregated_info, page_keys_by_split] = adjust_cumulative_sizes(c_info, pages, stream); + + // bring back to the cpu + std::vector h_aggregated_info = + cudf::detail::make_std_vector_sync(aggregated_info, stream); + // print_cumulative_row_info(h_aggregated_info, "adjusted"); + + std::vector splits; + // note: we are working with absolute row indices so skip_rows represents the absolute min row + // index we care about + size_t cur_pos = find_start_index(h_aggregated_info, skip_rows); + size_t cur_row_index = skip_rows; + size_t cur_cumulative_size = 0; + auto const max_row = min(skip_rows + num_rows, h_aggregated_info.back().row_index); + while (cur_row_index < max_row) { + auto const split_pos = + find_next_split(cur_pos, cur_row_index, cur_cumulative_size, h_aggregated_info, size_limit); + + auto const start_row = cur_row_index; + cur_row_index = min(max_row, h_aggregated_info[split_pos].row_index); + splits.push_back({start_row, cur_row_index - start_row}); + cur_pos = split_pos; + cur_cumulative_size = h_aggregated_info[split_pos].size_bytes; + } + // print_cumulative_row_info(h_aggregated_info, "adjusted w/splits", splits); + + return splits; +} + +/** + * @brief Decompresses a set of pages contained in the set of chunks. + * + * This function handles the case where `pages` is only a subset of all available + * pages in `chunks`. + * + * @param chunks List of column chunk descriptors + * @param pages List of page information + * @param dict_pages If true, decompress dictionary pages only. Otherwise decompress non-dictionary + * pages only. 
+ * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return Device buffer to decompressed page data + */ +[[nodiscard]] rmm::device_buffer decompress_page_data( + cudf::detail::hostdevice_vector const& chunks, + cudf::detail::hostdevice_vector& pages, + bool dict_pages, + rmm::cuda_stream_view stream) +{ + auto for_each_codec_page = [&](Compression codec, std::function const& f) { + for (size_t p = 0; p < pages.size(); p++) { + if (chunks[pages[p].chunk_idx].codec == codec && + ((dict_pages && (pages[p].flags & PAGEINFO_FLAGS_DICTIONARY)) || + (!dict_pages && !(pages[p].flags & PAGEINFO_FLAGS_DICTIONARY)))) { + f(p); + } + } + }; + + // Brotli scratch memory for decompressing + rmm::device_buffer debrotli_scratch; + + // Count the exact number of compressed pages + size_t num_comp_pages = 0; + size_t total_decomp_size = 0; + + struct codec_stats { + Compression compression_type = UNCOMPRESSED; + size_t num_pages = 0; + int32_t max_decompressed_size = 0; + size_t total_decomp_size = 0; + }; + + std::array codecs{codec_stats{GZIP}, codec_stats{SNAPPY}, codec_stats{BROTLI}, codec_stats{ZSTD}}; + + auto is_codec_supported = [&codecs](int8_t codec) { + if (codec == UNCOMPRESSED) return true; + return std::find_if(codecs.begin(), codecs.end(), [codec](auto& cstats) { + return codec == cstats.compression_type; + }) != codecs.end(); + }; + CUDF_EXPECTS(std::all_of(chunks.begin(), + chunks.end(), + [&is_codec_supported](auto const& chunk) { + return is_codec_supported(chunk.codec); + }), + "Unsupported compression type"); + + for (auto& codec : codecs) { + for_each_codec_page(codec.compression_type, [&](size_t page) { + auto page_uncomp_size = pages[page].uncompressed_page_size; + total_decomp_size += page_uncomp_size; + codec.total_decomp_size += page_uncomp_size; + codec.max_decompressed_size = std::max(codec.max_decompressed_size, page_uncomp_size); + codec.num_pages++; + num_comp_pages++; + }); + if (codec.compression_type == BROTLI && codec.num_pages > 0) { + debrotli_scratch.resize(get_gpu_debrotli_scratch_size(codec.num_pages), stream); + } + } + + // Dispatch batches of pages to decompress for each codec. + // Buffer needs to be padded, required by `gpuDecodePageData`. + rmm::device_buffer decomp_pages( + cudf::util::round_up_safe(total_decomp_size, BUFFER_PADDING_MULTIPLE), stream); + + std::vector> comp_in; + comp_in.reserve(num_comp_pages); + std::vector> comp_out; + comp_out.reserve(num_comp_pages); + + // vectors to save v2 def and rep level data, if any + std::vector> copy_in; + copy_in.reserve(num_comp_pages); + std::vector> copy_out; + copy_out.reserve(num_comp_pages); + + rmm::device_uvector comp_res(num_comp_pages, stream); + thrust::fill(rmm::exec_policy_nosync(stream), + comp_res.begin(), + comp_res.end(), + compression_result{0, compression_status::FAILURE}); + + size_t decomp_offset = 0; + int32_t start_pos = 0; + for (auto const& codec : codecs) { + if (codec.num_pages == 0) { continue; } + + for_each_codec_page(codec.compression_type, [&](size_t page_idx) { + auto const dst_base = static_cast(decomp_pages.data()) + decomp_offset; + auto& page = pages[page_idx]; + // offset will only be non-zero for V2 pages + auto const offset = + page.lvl_bytes[level_type::DEFINITION] + page.lvl_bytes[level_type::REPETITION]; + // for V2 need to copy def and rep level info into place, and then offset the + // input and output buffers. otherwise we'd have to keep both the compressed + // and decompressed data. 
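The V2 handling here hinges on one layout fact: a V2 data page stores its definition and repetition level bytes uncompressed at the front of the page buffer, so those bytes are copied verbatim while only the remainder is handed to the decompressor. A minimal standalone sketch of that split, using hypothetical stand-in types rather than the real PageInfo:

#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical stand-ins for the PageInfo fields used by the V2 path (illustration only).
struct v2_page_view {
  uint8_t const* page_data;
  int32_t compressed_page_size;
  int32_t def_lvl_bytes;  // uncompressed definition-level bytes at the front of the page
  int32_t rep_lvl_bytes;  // uncompressed repetition-level bytes following them
};

struct byte_span {
  uint8_t const* data;
  std::size_t size;
};

// Split a V2 page into the level bytes (copied verbatim into the output buffer) and the
// compressed payload (offset into both input and output). For V1 pages the offset is 0,
// so the whole page goes through the decompressor unchanged.
std::pair<byte_span, byte_span> split_v2_page(v2_page_view const& p)
{
  auto const offset = static_cast<std::size_t>(p.def_lvl_bytes + p.rep_lvl_bytes);
  return {{p.page_data, offset},
          {p.page_data + offset, static_cast<std::size_t>(p.compressed_page_size) - offset}};
}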
+ if (offset != 0) { + copy_in.emplace_back(page.page_data, offset); + copy_out.emplace_back(dst_base, offset); + } + comp_in.emplace_back(page.page_data + offset, + static_cast(page.compressed_page_size - offset)); + comp_out.emplace_back(dst_base + offset, + static_cast(page.uncompressed_page_size - offset)); + page.page_data = dst_base; + decomp_offset += page.uncompressed_page_size; + }); + + host_span const> comp_in_view{comp_in.data() + start_pos, + codec.num_pages}; + auto const d_comp_in = cudf::detail::make_device_uvector_async( + comp_in_view, stream, rmm::mr::get_current_device_resource()); + host_span const> comp_out_view(comp_out.data() + start_pos, + codec.num_pages); + auto const d_comp_out = cudf::detail::make_device_uvector_async( + comp_out_view, stream, rmm::mr::get_current_device_resource()); + device_span d_comp_res_view(comp_res.data() + start_pos, codec.num_pages); + + switch (codec.compression_type) { + case GZIP: + gpuinflate(d_comp_in, d_comp_out, d_comp_res_view, gzip_header_included::YES, stream); + break; + case SNAPPY: + if (cudf::io::detail::nvcomp_integration::is_stable_enabled()) { + nvcomp::batched_decompress(nvcomp::compression_type::SNAPPY, + d_comp_in, + d_comp_out, + d_comp_res_view, + codec.max_decompressed_size, + codec.total_decomp_size, + stream); + } else { + gpu_unsnap(d_comp_in, d_comp_out, d_comp_res_view, stream); + } + break; + case ZSTD: + nvcomp::batched_decompress(nvcomp::compression_type::ZSTD, + d_comp_in, + d_comp_out, + d_comp_res_view, + codec.max_decompressed_size, + codec.total_decomp_size, + stream); + break; + case BROTLI: + gpu_debrotli(d_comp_in, + d_comp_out, + d_comp_res_view, + debrotli_scratch.data(), + debrotli_scratch.size(), + stream); + break; + default: CUDF_FAIL("Unexpected decompression dispatch"); break; + } + start_pos += codec.num_pages; + } + + CUDF_EXPECTS(thrust::all_of(rmm::exec_policy(stream), + comp_res.begin(), + comp_res.end(), + cuda::proclaim_return_type([] __device__(auto const& res) { + return res.status == compression_status::SUCCESS; + })), + "Error during decompression"); + + // now copy the uncompressed V2 def and rep level data + if (not copy_in.empty()) { + auto const d_copy_in = cudf::detail::make_device_uvector_async( + copy_in, stream, rmm::mr::get_current_device_resource()); + auto const d_copy_out = cudf::detail::make_device_uvector_async( + copy_out, stream, rmm::mr::get_current_device_resource()); + + gpu_copy_uncompressed_blocks(d_copy_in, d_copy_out, stream); + stream.synchronize(); + } + + pages.host_to_device_async(stream); + + stream.synchronize(); + return decomp_pages; +} + +struct flat_column_num_rows { + ColumnChunkDesc const* chunks; + + __device__ size_type operator()(PageInfo const& page) const + { + // ignore dictionary pages and pages belonging to any column containing repetition (lists) + if ((page.flags & PAGEINFO_FLAGS_DICTIONARY) || + (chunks[page.chunk_idx].max_level[level_type::REPETITION] > 0)) { + return 0; + } + return page.num_rows; + } +}; + +struct row_counts_nonzero { + __device__ bool operator()(size_type count) const { return count > 0; } +}; + +struct row_counts_different { + size_type const expected; + __device__ bool operator()(size_type count) const { return (count != 0) && (count != expected); } +}; + +/** + * @brief Detect malformed parquet input data. + * + * We have seen cases where parquet files can be oddly malformed. 
This function specifically + * detects one case in particular: + * + * - When you have a file containing N rows + * - For some reason, the sum total of the number of rows over all pages for a given column + * is != N + * + * @param pages All pages to be decoded + * @param chunks Chunk data + * @param expected_row_count Expected row count, if applicable + * @param stream CUDA stream used for device memory operations and kernel launches + */ +void detect_malformed_pages(device_span pages, + device_span chunks, + std::optional expected_row_count, + rmm::cuda_stream_view stream) +{ + // sum row counts for all non-dictionary, non-list columns. other columns will be indicated as 0 + rmm::device_uvector row_counts(pages.size(), + stream); // worst case: num keys == num pages + auto const size_iter = + thrust::make_transform_iterator(pages.begin(), flat_column_num_rows{chunks.data()}); + auto const row_counts_begin = row_counts.begin(); + auto page_keys = make_page_key_iterator(pages); + auto const row_counts_end = thrust::reduce_by_key(rmm::exec_policy(stream), + page_keys, + page_keys + pages.size(), + size_iter, + thrust::make_discard_iterator(), + row_counts_begin) + .second; + + // make sure all non-zero row counts are the same + rmm::device_uvector compacted_row_counts(pages.size(), stream); + auto const compacted_row_counts_begin = compacted_row_counts.begin(); + auto const compacted_row_counts_end = thrust::copy_if(rmm::exec_policy(stream), + row_counts_begin, + row_counts_end, + compacted_row_counts_begin, + row_counts_nonzero{}); + if (compacted_row_counts_end != compacted_row_counts_begin) { + size_t const found_row_count = static_cast(compacted_row_counts.element(0, stream)); + + // if we somehow don't match the expected row count from the row groups themselves + if (expected_row_count.has_value()) { + CUDF_EXPECTS(expected_row_count.value() == found_row_count, + "Encountered malformed parquet page data (unexpected row count in page data)"); + } + + // all non-zero row counts must be the same + auto const chk = + thrust::count_if(rmm::exec_policy(stream), + compacted_row_counts_begin, + compacted_row_counts_end, + row_counts_different{static_cast(found_row_count)}); + CUDF_EXPECTS(chk == 0, + "Encountered malformed parquet page data (row count mismatch in page data)"); + } +} + +struct decompression_info { + Compression codec; + size_t num_pages; + size_t max_page_decompressed_size; + size_t total_decompressed_size; +}; + +/** + * @brief Functor which retrieves per-page decompression information. + * + */ +struct get_decomp_info { + device_span chunks; + + __device__ decompression_info operator()(PageInfo const& p) const + { + return {static_cast(chunks[p.chunk_idx].codec), + 1, + static_cast(p.uncompressed_page_size), + static_cast(p.uncompressed_page_size)}; + } +}; + +/** + * @brief Functor which accumulates per-page decompression information. + * + */ +struct decomp_sum { + __device__ decompression_info operator()(decompression_info const& a, + decompression_info const& b) const + { + return {a.codec, + a.num_pages + b.num_pages, + std::max(a.max_page_decompressed_size, b.max_page_decompressed_size), + a.total_decompressed_size + b.total_decompressed_size}; + } +}; + +/** + * @brief Functor which returns total scratch space required based on computed decompression_info + * data. 
+ * + */ +struct get_decomp_scratch { + size_t operator()(decompression_info const& di) const { - return a.row_count < b.row_count; + switch (di.codec) { + case UNCOMPRESSED: + case GZIP: return 0; + + case BROTLI: return get_gpu_debrotli_scratch_size(di.num_pages); + + case SNAPPY: + if (cudf::io::detail::nvcomp_integration::is_stable_enabled()) { + return cudf::io::nvcomp::batched_decompress_temp_size( + cudf::io::nvcomp::compression_type::SNAPPY, + di.num_pages, + di.max_page_decompressed_size, + di.total_decompressed_size); + } else { + return 0; + } + break; + + case ZSTD: + return cudf::io::nvcomp::batched_decompress_temp_size( + cudf::io::nvcomp::compression_type::ZSTD, + di.num_pages, + di.max_page_decompressed_size, + di.total_decompressed_size); + + default: CUDF_FAIL("Invalid compression codec for parquet decompression"); + } } }; +/** + * @brief Add the cost of decompression codec scratch space to the per-page cumulative + * size information. + * + */ +void include_decompression_scratch_size(device_span chunks, + device_span pages, + device_span c_info, + rmm::cuda_stream_view stream) +{ + CUDF_EXPECTS(pages.size() == c_info.size(), + "Encountered page/cumulative_page_info size mismatch"); + + auto page_keys = make_page_key_iterator(pages); + + // per-codec page counts and decompression sizes + rmm::device_uvector decomp_info(pages.size(), stream); + auto decomp_iter = thrust::make_transform_iterator(pages.begin(), get_decomp_info{chunks}); + thrust::inclusive_scan_by_key(rmm::exec_policy_nosync(stream), + page_keys, + page_keys + pages.size(), + decomp_iter, + decomp_info.begin(), + thrust::equal_to{}, + decomp_sum{}); + + // retrieve to host so we can call nvcomp to get compression scratch sizes + std::vector h_decomp_info = + cudf::detail::make_std_vector_sync(decomp_info, stream); + std::vector temp_cost(pages.size()); + thrust::transform(thrust::host, + h_decomp_info.begin(), + h_decomp_info.end(), + temp_cost.begin(), + get_decomp_scratch{}); + + // add to the cumulative_page_info data + rmm::device_uvector d_temp_cost = cudf::detail::make_device_uvector_async( + temp_cost, stream, rmm::mr::get_current_device_resource()); + auto iter = thrust::make_counting_iterator(size_t{0}); + thrust::for_each(rmm::exec_policy_nosync(stream), + iter, + iter + pages.size(), + [temp_cost = d_temp_cost.begin(), c_info = c_info.begin()] __device__(size_t i) { + c_info[i].size_bytes += temp_cost[i]; + }); + stream.synchronize(); +} + } // anonymous namespace +void reader::impl::handle_chunking(bool uses_custom_row_bounds) +{ + // if this is our first time in here, setup the first pass. + if (!_pass_itm_data) { + // setup the next pass + setup_next_pass(uses_custom_row_bounds); + } + + auto& pass = *_pass_itm_data; + + // if we already have a subpass in flight. + if (pass.subpass != nullptr) { + // if it still has more chunks in flight, there's nothing more to do + if (pass.subpass->current_output_chunk < pass.subpass->output_chunk_read_info.size()) { + return; + } + + // increment rows processed + pass.processed_rows += pass.subpass->num_rows; + + // release the old subpass (will free memory) + pass.subpass.reset(); + + // otherwise we are done with the pass entirely + if (pass.processed_rows == pass.num_rows) { + // release the old pass + _pass_itm_data.reset(); + + _file_itm_data._current_input_pass++; + // no more passes. we are absolutely done with this file. 
+ if (_file_itm_data._current_input_pass == _file_itm_data.num_passes()) { return; } + + // setup the next pass + setup_next_pass(uses_custom_row_bounds); + } + } + + // setup the next sub pass + setup_next_subpass(uses_custom_row_bounds); +} + +void reader::impl::setup_next_pass(bool uses_custom_row_bounds) +{ + auto const num_passes = _file_itm_data.num_passes(); + + // always create the pass struct, even if we end up with no work. + // this will also cause the previous pass information to be deleted + _pass_itm_data = std::make_unique(); + + if (_file_itm_data.global_num_rows > 0 && not _file_itm_data.row_groups.empty() && + not _input_columns.empty() && _file_itm_data._current_input_pass < num_passes) { + auto& pass = *_pass_itm_data; + + // setup row groups to be loaded for this pass + auto const row_group_start = + _file_itm_data.input_pass_row_group_offsets[_file_itm_data._current_input_pass]; + auto const row_group_end = + _file_itm_data.input_pass_row_group_offsets[_file_itm_data._current_input_pass + 1]; + auto const num_row_groups = row_group_end - row_group_start; + pass.row_groups.resize(num_row_groups); + std::copy(_file_itm_data.row_groups.begin() + row_group_start, + _file_itm_data.row_groups.begin() + row_group_end, + pass.row_groups.begin()); + + CUDF_EXPECTS(_file_itm_data._current_input_pass < num_passes, + "Encountered an invalid read pass index"); + + auto const chunks_per_rowgroup = _input_columns.size(); + auto const num_chunks = chunks_per_rowgroup * num_row_groups; + + auto chunk_start = _file_itm_data.chunks.begin() + (row_group_start * chunks_per_rowgroup); + auto chunk_end = _file_itm_data.chunks.begin() + (row_group_end * chunks_per_rowgroup); + + pass.chunks = cudf::detail::hostdevice_vector(num_chunks, _stream); + std::copy(chunk_start, chunk_end, pass.chunks.begin()); + + // compute skip_rows / num_rows for this pass. + if (num_passes == 1) { + pass.skip_rows = _file_itm_data.global_skip_rows; + pass.num_rows = _file_itm_data.global_num_rows; + } else { + auto const global_start_row = _file_itm_data.global_skip_rows; + auto const global_end_row = global_start_row + _file_itm_data.global_num_rows; + auto const start_row = + std::max(_file_itm_data.input_pass_start_row_count[_file_itm_data._current_input_pass], + global_start_row); + auto const end_row = + std::min(_file_itm_data.input_pass_start_row_count[_file_itm_data._current_input_pass + 1], + global_end_row); + + // skip_rows is always global in the sense that it is relative to the first row of + // everything we will be reading, regardless of what pass we are on. + // num_rows is how many rows we are reading this pass. + pass.skip_rows = + global_start_row + + _file_itm_data.input_pass_start_row_count[_file_itm_data._current_input_pass]; + pass.num_rows = end_row - start_row; + } + + // load page information for the chunk. this retrieves the compressed bytes for all the + // pages, and their headers (which we can access without decompressing) + read_compressed_data(); + + // detect malformed columns. + // - we have seen some cases in the wild where we have a row group containing N + // rows, but the total number of rows in the pages for column X is != N. while it + // is possible to load this by just capping the number of rows read, we cannot tell + // which rows are invalid so we may be returning bad data. in addition, this mismatch + // confuses the chunked reader + detect_malformed_pages( + pass.pages, + pass.chunks, + uses_custom_row_bounds ? 
std::nullopt : std::make_optional(pass.num_rows), + _stream); + + // decompress dictionary data if applicable. + if (pass.has_compressed_data) { + pass.decomp_dict_data = decompress_page_data(pass.chunks, pass.pages, true, _stream); + } + + // store off how much memory we've used so far. This includes the compressed page data and the + // decompressed dictionary data. we will subtract this from the available total memory for the + // subpasses + auto chunk_iter = + thrust::make_transform_iterator(pass.chunks.d_begin(), get_chunk_compressed_size{}); + pass.base_mem_size = + pass.decomp_dict_data.size() + + thrust::reduce(rmm::exec_policy(_stream), chunk_iter, chunk_iter + pass.chunks.size()); + + // since there is only ever 1 dictionary per chunk (the first page), do it at the + // pass level. + build_string_dict_indices(); + + // if we are doing subpass reading, generate more accurate num_row estimates for list columns. + // this helps us to generate more accurate subpass splits. + if (_input_pass_read_limit != 0) { generate_list_column_row_count_estimates(); } + +#if defined(PARQUET_CHUNK_LOGGING) + printf("Pass: row_groups(%'lu), chunks(%'lu), pages(%'lu)\n", + pass.row_groups.size(), + pass.chunks.size(), + pass.pages.size()); + printf("\tskip_rows: %'lu\n", pass.skip_rows); + printf("\tnum_rows: %'lu\n", pass.num_rows); + printf("\tbase mem usage: %'lu\n", pass.base_mem_size); + auto const num_columns = _input_columns.size(); + for (size_t c_idx = 0; c_idx < num_columns; c_idx++) { + printf("\t\tColumn %'lu: num_pages(%'d)\n", + c_idx, + pass.page_offsets[c_idx + 1] - pass.page_offsets[c_idx]); + } +#endif + + _stream.synchronize(); + } +} + +void reader::impl::setup_next_subpass(bool uses_custom_row_bounds) +{ + auto& pass = *_pass_itm_data; + pass.subpass = std::make_unique(); + auto& subpass = *pass.subpass; + + auto const num_columns = _input_columns.size(); + + // if the user has passed a very small value (under the hardcoded minimum_subpass_expected_size), + // respect it. + auto const min_subpass_size = std::min(_input_pass_read_limit, minimum_subpass_expected_size); + + // what do we do if the base memory size (the compressed data) itself is approaching or larger + // than the overall read limit? we are still going to be decompressing in subpasses, but we have + // to assume some reasonable minimum size needed to safely decompress a single subpass. so always + // reserve at least that much space. this can result in using up to 2x the specified user limit + // but should only ever happen with unrealistically low numbers. + size_t const remaining_read_limit = + _input_pass_read_limit == 0 ? 0 + : pass.base_mem_size + min_subpass_size >= _input_pass_read_limit + ? min_subpass_size + : _input_pass_read_limit - pass.base_mem_size; + + auto [page_indices, total_pages, total_expected_size] = + [&]() -> std::tuple, size_t, size_t> { + // special case: if we contain no compressed data, or if we have no input limit, we can always + // just do 1 subpass since what we already have loaded is all the temporary memory we will ever + // use. 
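The remaining_read_limit computation above packs three cases into one expression. A standalone restatement of the same logic (illustrative function name only, not part of this change) may be easier to follow:

#include <cstddef>

// A read limit of 0 means "unlimited". Otherwise always reserve at least the minimum subpass
// size, even when the compressed pass data alone already approaches or exceeds the limit; in
// that degenerate case the reader may use up to roughly 2x the requested limit.
std::size_t compute_subpass_budget(std::size_t input_pass_read_limit,
                                   std::size_t base_mem_size,
                                   std::size_t min_subpass_size)
{
  if (input_pass_read_limit == 0) { return 0; }
  if (base_mem_size + min_subpass_size >= input_pass_read_limit) { return min_subpass_size; }
  return input_pass_read_limit - base_mem_size;
}

// e.g. a 1024 MiB limit with 200 MiB of compressed pass data leaves 824 MiB for the subpass,
// while a 100 MiB limit with 200 MiB of compressed pass data falls back to min_subpass_size.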
+ if (!pass.has_compressed_data || _input_pass_read_limit == 0) { + std::vector page_indices; + page_indices.reserve(num_columns); + auto iter = thrust::make_counting_iterator(0); + std::transform( + iter, iter + num_columns, std::back_inserter(page_indices), [&](size_t i) -> page_span { + return {static_cast(pass.page_offsets[i]), + static_cast(pass.page_offsets[i + 1])}; + }); + return {page_indices, pass.pages.size(), 0}; + } + // otherwise we have to look forward and choose a batch of pages + + // as subpasses get decoded, the initial estimates we have for list row counts + // get updated with accurate data, so regenerate cumulative size info and row + // indices + rmm::device_uvector c_info(pass.pages.size(), _stream); + auto page_keys = make_page_key_iterator(pass.pages); + auto page_size = thrust::make_transform_iterator(pass.pages.d_begin(), get_page_input_size{}); + thrust::inclusive_scan_by_key(rmm::exec_policy_nosync(_stream), + page_keys, + page_keys + pass.pages.size(), + page_size, + c_info.begin(), + thrust::equal_to{}, + cumulative_page_sum{}); + + // include scratch space needed for decompression. for certain codecs (eg ZSTD) this + // can be considerable. + include_decompression_scratch_size(pass.chunks, pass.pages, c_info, _stream); + + auto iter = thrust::make_counting_iterator(0); + thrust::for_each(rmm::exec_policy_nosync(_stream), + iter, + iter + pass.pages.size(), + set_row_index{pass.chunks, pass.pages, c_info}); + // print_cumulative_page_info(pass.pages, pass.chunks, c_info, _stream); + + // get the next batch of pages + return compute_next_subpass(c_info, + pass.pages, + pass.page_offsets, + pass.processed_rows + pass.skip_rows, + remaining_read_limit, + num_columns, + _stream); + }(); + + // fill out the subpass struct + subpass.pages = cudf::detail::hostdevice_vector(0, total_pages, _stream); + subpass.page_src_index = + cudf::detail::hostdevice_vector(total_pages, total_pages, _stream); + // copy the appropriate subset of pages from each column + size_t page_count = 0; + for (size_t c_idx = 0; c_idx < num_columns; c_idx++) { + auto const num_column_pages = page_indices[c_idx].end - page_indices[c_idx].start; + subpass.column_page_count.push_back(num_column_pages); + std::copy(pass.pages.begin() + page_indices[c_idx].start, + pass.pages.begin() + page_indices[c_idx].end, + std::back_inserter(subpass.pages)); + + // mapping back to original pages in the pass + thrust::sequence(thrust::host, + subpass.page_src_index.begin() + page_count, + subpass.page_src_index.begin() + page_count + num_column_pages, + page_indices[c_idx].start); + page_count += num_column_pages; + } + // print_hostdevice_vector(subpass.page_src_index); + + // decompress the data for the pages in this subpass. 
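To make the page_src_index mapping assembled just above concrete: each column contributes a contiguous span of pass-level page indices, and the concatenation of those spans records, for every page in the subpass, where it came from in the pass. A toy host-side illustration with a hypothetical span type (not the real page_span):

#include <cstddef>
#include <cstdio>
#include <vector>

struct span_t {  // hypothetical stand-in for page_span
  std::size_t start, end;
};

int main()
{
  // pages chosen for columns 0 and 1 of an imaginary pass
  std::vector<span_t> const spans{{2, 5}, {9, 12}};
  std::vector<std::size_t> page_src_index;
  for (auto const& s : spans) {
    for (std::size_t i = s.start; i < s.end; i++) { page_src_index.push_back(i); }
  }
  // prints "2 3 4 9 10 11": subpass page k maps back to pass page page_src_index[k]
  for (auto i : page_src_index) { std::printf("%zu ", i); }
  std::printf("\n");
  return 0;
}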
+ if (pass.has_compressed_data) { + subpass.decomp_page_data = decompress_page_data(pass.chunks, subpass.pages, false, _stream); + } + + subpass.pages.host_to_device_async(_stream); + subpass.page_src_index.host_to_device_async(_stream); + _stream.synchronize(); + + // buffers needed by the decode kernels + { + // nesting information (sizes, etc) stored -per page- + // note : even for flat schemas, we allocate 1 level of "nesting" info + allocate_nesting_info(); + + // level decode space + allocate_level_decode_space(); + } + subpass.pages.host_to_device_async(_stream); + + // preprocess pages (computes row counts for lists, computes output chunks and computes + // the actual row counts we will be able load out of this subpass) + preprocess_subpass_pages(uses_custom_row_bounds, _output_chunk_read_limit); + +#if defined(PARQUET_CHUNK_LOGGING) + printf("\tSubpass: skip_rows(%'lu), num_rows(%'lu), remaining read limit(%'lu)\n", + subpass.skip_rows, + subpass.num_rows, + remaining_read_limit); + printf("\t\tDecompressed size: %'lu\n", subpass.decomp_page_data.size()); + printf("\t\tTotal expected usage: %'lu\n", + total_expected_size == 0 ? subpass.decomp_page_data.size() + pass.base_mem_size + : total_expected_size + pass.base_mem_size); + for (size_t c_idx = 0; c_idx < num_columns; c_idx++) { + printf("\t\tColumn %'lu: pages(%'lu - %'lu)\n", + c_idx, + page_indices[c_idx].start, + page_indices[c_idx].end); + } + printf("\t\tOutput chunks:\n"); + for (size_t idx = 0; idx < subpass.output_chunk_read_info.size(); idx++) { + printf("\t\t\t%'lu: skip_rows(%'lu) num_rows(%'lu)\n", + idx, + subpass.output_chunk_read_info[idx].skip_rows, + subpass.output_chunk_read_info[idx].num_rows); + } +#endif +} + void reader::impl::create_global_chunk_info() { auto const num_rows = _file_itm_data.global_num_rows; @@ -380,6 +1403,14 @@ void reader::impl::create_global_chunk_info() schema.converted_type, schema.type_length); + // for lists, estimate the number of bytes per row. this is used by the subpass reader to + // determine where to split the decompression boundaries + float const list_bytes_per_row_est = + schema.max_repetition_level > 0 && row_group.num_rows > 0 + ? static_cast(col_meta.total_uncompressed_size) / + static_cast(row_group.num_rows) + : 0.0f; + chunks.push_back(ColumnChunkDesc(col_meta.total_compressed_size, nullptr, col_meta.num_values, @@ -398,7 +1429,8 @@ void reader::impl::create_global_chunk_info() schema.decimal_precision, clock_rate, i, - col.schema_idx)); + col.schema_idx, + list_bytes_per_row_est)); } remaining_rows -= row_group_rows; @@ -415,185 +1447,101 @@ void reader::impl::compute_input_passes() if (_input_pass_read_limit == 0) { _file_itm_data.input_pass_row_group_offsets.push_back(0); _file_itm_data.input_pass_row_group_offsets.push_back(row_groups_info.size()); + _file_itm_data.input_pass_start_row_count.push_back(0); + auto rg_row_count = cudf::detail::make_counting_transform_iterator(0, [&](size_t i) { + auto const& rgi = row_groups_info[i]; + auto const& row_group = _metadata->get_row_group(rgi.index, rgi.source_index); + return row_group.num_rows; + }); + _file_itm_data.input_pass_start_row_count.push_back( + std::reduce(rg_row_count, rg_row_count + row_groups_info.size())); return; } // generate passes. make sure to account for the case where a single row group doesn't fit within // - std::size_t const read_limit = - _input_pass_read_limit > 0 ? _input_pass_read_limit : std::numeric_limits::max(); + std::size_t const comp_read_limit = + _input_pass_read_limit > 0 + ? 
static_cast(_input_pass_read_limit * input_limit_compression_reserve) + : std::numeric_limits::max(); std::size_t cur_pass_byte_size = 0; std::size_t cur_rg_start = 0; std::size_t cur_row_count = 0; _file_itm_data.input_pass_row_group_offsets.push_back(0); - _file_itm_data.input_pass_row_count.push_back(0); + _file_itm_data.input_pass_start_row_count.push_back(0); for (size_t cur_rg_index = 0; cur_rg_index < row_groups_info.size(); cur_rg_index++) { auto const& rgi = row_groups_info[cur_rg_index]; auto const& row_group = _metadata->get_row_group(rgi.index, rgi.source_index); + // total compressed size and total size (compressed + uncompressed) for + auto const [compressed_rg_size, _ /*compressed + uncompressed*/] = + get_row_group_size(row_group); + // can we add this row group - if (cur_pass_byte_size + row_group.total_byte_size >= read_limit) { + if (cur_pass_byte_size + compressed_rg_size >= comp_read_limit) { // A single row group (the current one) is larger than the read limit: // We always need to include at least one row group, so end the pass at the end of the current // row group if (cur_rg_start == cur_rg_index) { _file_itm_data.input_pass_row_group_offsets.push_back(cur_rg_index + 1); - _file_itm_data.input_pass_row_count.push_back(cur_row_count + row_group.num_rows); + _file_itm_data.input_pass_start_row_count.push_back(cur_row_count + row_group.num_rows); cur_rg_start = cur_rg_index + 1; cur_pass_byte_size = 0; } // End the pass at the end of the previous row group else { _file_itm_data.input_pass_row_group_offsets.push_back(cur_rg_index); - _file_itm_data.input_pass_row_count.push_back(cur_row_count); + _file_itm_data.input_pass_start_row_count.push_back(cur_row_count); cur_rg_start = cur_rg_index; - cur_pass_byte_size = row_group.total_byte_size; + cur_pass_byte_size = compressed_rg_size; } } else { - cur_pass_byte_size += row_group.total_byte_size; + cur_pass_byte_size += compressed_rg_size; } cur_row_count += row_group.num_rows; } + // add the last pass if necessary if (_file_itm_data.input_pass_row_group_offsets.back() != row_groups_info.size()) { _file_itm_data.input_pass_row_group_offsets.push_back(row_groups_info.size()); - _file_itm_data.input_pass_row_count.push_back(cur_row_count); - } -} - -void reader::impl::setup_next_pass() -{ - // this will also cause the previous pass information to be deleted - _pass_itm_data = std::make_unique(); - - // setup row groups to be loaded for this pass - auto const row_group_start = _file_itm_data.input_pass_row_group_offsets[_current_input_pass]; - auto const row_group_end = _file_itm_data.input_pass_row_group_offsets[_current_input_pass + 1]; - auto const num_row_groups = row_group_end - row_group_start; - _pass_itm_data->row_groups.resize(num_row_groups); - std::copy(_file_itm_data.row_groups.begin() + row_group_start, - _file_itm_data.row_groups.begin() + row_group_end, - _pass_itm_data->row_groups.begin()); - - auto const num_passes = _file_itm_data.input_pass_row_group_offsets.size() - 1; - CUDF_EXPECTS(_current_input_pass < num_passes, "Encountered an invalid read pass index"); - - auto const chunks_per_rowgroup = _input_columns.size(); - auto const num_chunks = chunks_per_rowgroup * num_row_groups; - - auto chunk_start = _file_itm_data.chunks.begin() + (row_group_start * chunks_per_rowgroup); - auto chunk_end = _file_itm_data.chunks.begin() + (row_group_end * chunks_per_rowgroup); - - _pass_itm_data->chunks = cudf::detail::hostdevice_vector(num_chunks, _stream); - std::copy(chunk_start, chunk_end, 
_pass_itm_data->chunks.begin()); - - // adjust skip_rows and num_rows by what's available in the row groups we are processing - if (num_passes == 1) { - _pass_itm_data->skip_rows = _file_itm_data.global_skip_rows; - _pass_itm_data->num_rows = _file_itm_data.global_num_rows; - } else { - auto const global_start_row = _file_itm_data.global_skip_rows; - auto const global_end_row = global_start_row + _file_itm_data.global_num_rows; - auto const start_row = - std::max(_file_itm_data.input_pass_row_count[_current_input_pass], global_start_row); - auto const end_row = - std::min(_file_itm_data.input_pass_row_count[_current_input_pass + 1], global_end_row); - - // skip_rows is always global in the sense that it is relative to the first row of - // everything we will be reading, regardless of what pass we are on. - // num_rows is how many rows we are reading this pass. - _pass_itm_data->skip_rows = - global_start_row + _file_itm_data.input_pass_row_count[_current_input_pass]; - _pass_itm_data->num_rows = end_row - start_row; + _file_itm_data.input_pass_start_row_count.push_back(cur_row_count); } } -void reader::impl::compute_splits_for_pass() +void reader::impl::compute_output_chunks_for_subpass() { - auto const skip_rows = _pass_itm_data->skip_rows; - auto const num_rows = _pass_itm_data->num_rows; + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; // simple case : no chunk size, no splits if (_output_chunk_read_limit <= 0) { - _pass_itm_data->output_chunk_read_info = std::vector{{skip_rows, num_rows}}; + subpass.output_chunk_read_info.push_back({subpass.skip_rows, subpass.num_rows}); return; } - auto& pages = _pass_itm_data->pages_info; - - auto const& page_keys = _pass_itm_data->page_keys; - auto const& page_index = _pass_itm_data->page_index; - - // generate cumulative row counts and sizes - rmm::device_uvector c_info(page_keys.size(), _stream); - // convert PageInfo to cumulative_row_info - auto page_input = thrust::make_transform_iterator(page_index.begin(), - get_cumulative_row_info{pages.device_ptr()}); - thrust::inclusive_scan_by_key(rmm::exec_policy(_stream), - page_keys.begin(), - page_keys.end(), + // generate row_indices and cumulative output sizes for all pages + rmm::device_uvector c_info(subpass.pages.size(), _stream); + auto page_input = + thrust::make_transform_iterator(subpass.pages.d_begin(), get_page_output_size{}); + auto page_keys = make_page_key_iterator(subpass.pages); + thrust::inclusive_scan_by_key(rmm::exec_policy_nosync(_stream), + page_keys, + page_keys + subpass.pages.size(), page_input, c_info.begin(), thrust::equal_to{}, - cumulative_row_sum{}); - // print_cumulative_page_info(pages, page_index, c_info, stream); - - // sort by row count - rmm::device_uvector c_info_sorted{c_info, _stream}; - thrust::sort( - rmm::exec_policy(_stream), c_info_sorted.begin(), c_info_sorted.end(), row_count_compare{}); - - // std::vector h_c_info_sorted(c_info_sorted.size()); - // CUDF_CUDA_TRY(cudaMemcpy(h_c_info_sorted.data(), - // c_info_sorted.data(), - // sizeof(cumulative_row_info) * c_info_sorted.size(), - // cudaMemcpyDefault)); - // print_cumulative_row_info(h_c_info_sorted, "raw"); - - // generate key offsets (offsets to the start of each partition of keys). 
worst case is 1 page per - // key - rmm::device_uvector key_offsets(page_keys.size() + 1, _stream); - auto const key_offsets_end = thrust::reduce_by_key(rmm::exec_policy(_stream), - page_keys.begin(), - page_keys.end(), - thrust::make_constant_iterator(1), - thrust::make_discard_iterator(), - key_offsets.begin()) - .second; - size_t const num_unique_keys = key_offsets_end - key_offsets.begin(); - thrust::exclusive_scan( - rmm::exec_policy(_stream), key_offsets.begin(), key_offsets.end(), key_offsets.begin()); - - // adjust the cumulative info such that for each row count, the size includes any pages that span - // that row count. this is so that if we have this case: - // page row counts - // Column A: 0 <----> 100 <----> 200 - // Column B: 0 <---------------> 200 <--------> 400 - // | - // if we decide to split at row 100, we don't really know the actual amount of bytes in column B - // at that point. So we have to proceed as if we are taking the bytes from all 200 rows of that - // page. - // - rmm::device_uvector aggregated_info(c_info.size(), _stream); - thrust::transform(rmm::exec_policy(_stream), - c_info_sorted.begin(), - c_info_sorted.end(), - aggregated_info.begin(), - row_total_size{c_info.data(), key_offsets.data(), num_unique_keys}); - - // bring back to the cpu - std::vector h_aggregated_info(aggregated_info.size()); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_aggregated_info.data(), - aggregated_info.data(), - sizeof(cumulative_row_info) * c_info.size(), - cudaMemcpyDefault, - _stream.value())); - _stream.synchronize(); - - // generate the actual splits - _pass_itm_data->output_chunk_read_info = - find_splits(h_aggregated_info, num_rows, _output_chunk_read_limit); + cumulative_page_sum{}); + auto iter = thrust::make_counting_iterator(0); + thrust::for_each(rmm::exec_policy_nosync(_stream), + iter, + iter + subpass.pages.size(), + set_row_index{pass.chunks, subpass.pages, c_info}); + // print_cumulative_page_info(subpass.pages, c_info, _stream); + + // compute the splits + subpass.output_chunk_read_info = compute_page_splits_by_row( + c_info, subpass.pages, subpass.skip_rows, subpass.num_rows, _output_chunk_read_limit, _stream); } } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_chunking.hpp b/cpp/src/io/parquet/reader_impl_chunking.hpp index dfc239d8451..a9cf0e94ec8 100644 --- a/cpp/src/io/parquet/reader_impl_chunking.hpp +++ b/cpp/src/io/parquet/reader_impl_chunking.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,58 +30,105 @@ struct file_intermediate_data { // all row groups to read std::vector row_groups{}; - // all chunks from the selected row groups. We may end up reading these chunks progressively - // instead of all at once + // all chunks from the selected row groups. std::vector chunks{}; // an array of offsets into _file_itm_data::global_chunks. Each pair of offsets represents // the start/end of the chunks to be loaded for a given pass. std::vector input_pass_row_group_offsets{}; - // row counts per input-pass - std::vector input_pass_row_count{}; - // skip_rows/num_rows values for the entire file. these need to be adjusted per-pass because we - // may not be visiting every row group that contains these bounds + // start row counts per input-pass. 
this includes all rows in the row groups of the pass and + // is not capped by global_skip_rows and global_num_rows. + std::vector input_pass_start_row_count{}; + + size_t _current_input_pass{0}; // current input pass index + size_t _output_chunk_count{0}; // how many output chunks we have produced + + // skip_rows/num_rows values for the entire file. size_t global_skip_rows; size_t global_num_rows; + + [[nodiscard]] size_t num_passes() const + { + return input_pass_row_group_offsets.size() == 0 ? 0 : input_pass_row_group_offsets.size() - 1; + } }; /** - * @brief Struct to identify the range for each chunk of rows during a chunked reading pass. + * @brief Struct to identify a range of rows. */ -struct chunk_read_info { +struct row_range { + size_t skip_rows; + size_t num_rows; +}; + +/** + * @brief Passes are broken down into subpasses based on temporary memory constraints. + */ +struct subpass_intermediate_data { + rmm::device_buffer decomp_page_data; + + rmm::device_buffer level_decode_data{}; + cudf::detail::hostdevice_vector pages{}; + // for each page in the subpass, the index of our source page in the pass + cudf::detail::hostdevice_vector page_src_index{}; + // for each column in the file (indexed by _input_columns.size()) + // the number of associated pages for this subpass + std::vector column_page_count; + cudf::detail::hostdevice_vector page_nesting_info{}; + cudf::detail::hostdevice_vector page_nesting_decode_info{}; + + std::vector output_chunk_read_info; + std::size_t current_output_chunk{0}; + + // skip_rows and num_rows values for this particular subpass. in absolute row indices. size_t skip_rows; size_t num_rows; }; /** * @brief Struct to store pass-level data that remains constant for a single pass. + * + * A pass is defined as a set of rowgroups read but not yet decompressed. This set of + * rowgroups may represent less than all of the rowgroups to be read for the file. */ struct pass_intermediate_data { std::vector> raw_page_data; - rmm::device_buffer decomp_page_data; // rowgroup, chunk and page information for the current pass. + bool has_compressed_data{false}; std::vector row_groups{}; cudf::detail::hostdevice_vector chunks{}; - cudf::detail::hostdevice_vector pages_info{}; - cudf::detail::hostdevice_vector page_nesting_info{}; - cudf::detail::hostdevice_vector page_nesting_decode_info{}; + cudf::detail::hostdevice_vector pages{}; - rmm::device_uvector page_keys{0, rmm::cuda_stream_default}; - rmm::device_uvector page_index{0, rmm::cuda_stream_default}; - rmm::device_uvector str_dict_index{0, rmm::cuda_stream_default}; + // base memory used for the pass itself (compressed data in the loaded chunks and any + // decompressed dictionary pages) + size_t base_mem_size{0}; - std::vector output_chunk_read_info; - std::size_t current_output_chunk{0}; + // offsets to each group of input pages (by column/schema, indexed by _input_columns.size()) + // so if we had 2 columns/schemas, with page keys + // + // 1 1 1 1 1 2 2 2 + // + // page_offsets would be 0, 5, 8 + cudf::detail::hostdevice_vector page_offsets{}; + + rmm::device_buffer decomp_dict_data{0, rmm::cuda_stream_default}; + rmm::device_uvector str_dict_index{0, rmm::cuda_stream_default}; - rmm::device_buffer level_decode_data{}; int level_type_size{0}; - // skip_rows and num_rows values for this particular pass. these may be adjusted values from the - // global values stored in file_intermediate_data. + // skip_rows / num_rows for this pass. + // NOTE: skip_rows is the absolute row index in the file. 
size_t skip_rows; size_t num_rows; + // number of rows we have processed so far (out of num_rows). note that this + // only includes the number of rows we have processed before starting the current + // subpass. it does not get updated as a subpass iterates through output chunks. + size_t processed_rows{0}; + + // currently active subpass + std::unique_ptr subpass{}; }; } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index e10f2c00f40..ee3b1c466e0 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,6 @@ #include "error.hpp" #include "reader_impl.hpp" -#include -#include - #include #include #include @@ -49,6 +46,28 @@ namespace cudf::io::parquet::detail { namespace { +#if defined(PREPROCESS_DEBUG) +void print_pages(cudf::detail::hostdevice_vector& pages, rmm::cuda_stream_view _stream) +{ + pages.device_to_host_sync(_stream); + for (size_t idx = 0; idx < pages.size(); idx++) { + auto const& p = pages[idx]; + // skip dictionary pages + if (p.flags & PAGEINFO_FLAGS_DICTIONARY) { continue; } + printf( + "P(%lu, s:%d): chunk_row(%d), num_rows(%d), skipped_values(%d), skipped_leaf_values(%d), " + "str_bytes(%d)\n", + idx, + p.src_col_schema, + p.chunk_row, + p.num_rows, + p.skipped_values, + p.skipped_leaf_values, + p.str_bytes); + } +} +#endif // PREPROCESS_DEBUG + /** * @brief Generate depth remappings for repetition and definition levels. * @@ -269,7 +288,7 @@ void generate_depth_remappings(std::map, std::ve kernel_error error_code(stream); chunks.host_to_device_async(stream); - DecodePageHeaders(chunks.device_ptr(), chunks.size(), error_code.data(), stream); + DecodePageHeaders(chunks.device_ptr(), nullptr, chunks.size(), error_code.data(), stream); chunks.device_to_host_sync(stream); // It's required to ignore unsupported encodings in this function @@ -351,33 +370,37 @@ std::string encoding_to_string(Encoding encoding) } /** - * @brief Decode the page information from the given column chunks. + * @brief Decode the page information for a given pass. * - * @param chunks List of column chunk descriptors - * @param pages List of page information - * @param stream CUDA stream used for device memory operations and kernel launches - * @returns The size in bytes of level type data required + * @param pass_intermediate_data The struct containing pass information */ -int decode_page_headers(cudf::detail::hostdevice_vector& chunks, - cudf::detail::hostdevice_vector& pages, - rmm::cuda_stream_view stream) +void decode_page_headers(pass_intermediate_data& pass, + device_span unsorted_pages, + rmm::cuda_stream_view stream) { + cudf::detail::hostdevice_vector chunk_page_info(pass.chunks.size(), stream); + // IMPORTANT : if you change how pages are stored within a chunk (dist pages, then data pages), // please update preprocess_nested_columns to reflect this. 
- for (size_t c = 0, page_count = 0; c < chunks.size(); c++) { - chunks[c].max_num_pages = chunks[c].num_data_pages + chunks[c].num_dict_pages; - chunks[c].page_info = pages.device_ptr(page_count); - page_count += chunks[c].max_num_pages; + for (size_t c = 0, page_count = 0; c < pass.chunks.size(); c++) { + pass.chunks[c].max_num_pages = pass.chunks[c].num_data_pages + pass.chunks[c].num_dict_pages; + chunk_page_info[c].pages = &unsorted_pages[page_count]; + page_count += pass.chunks[c].max_num_pages; } kernel_error error_code(stream); - chunks.host_to_device_async(stream); - DecodePageHeaders(chunks.device_ptr(), chunks.size(), error_code.data(), stream); + pass.chunks.host_to_device_async(stream); + chunk_page_info.host_to_device_async(stream); + DecodePageHeaders(pass.chunks.device_ptr(), + chunk_page_info.device_ptr(), + pass.chunks.size(), + error_code.data(), + stream); if (error_code.value() != 0) { if (BitAnd(error_code.value(), decode_error::UNSUPPORTED_ENCODING) != 0) { auto const unsupported_str = - ". With unsupported encodings found: " + list_unsupported_encodings(pages, stream); + ". With unsupported encodings found: " + list_unsupported_encodings(pass.pages, stream); CUDF_FAIL("Parquet header parsing failed with code(s) " + error_code.str() + unsupported_str); } else { CUDF_FAIL("Parquet header parsing failed with code(s) " + error_code.str()); @@ -386,7 +409,7 @@ int decode_page_headers(cudf::detail::hostdevice_vector& chunks // compute max bytes needed for level data auto level_bit_size = cudf::detail::make_counting_transform_iterator( - 0, cuda::proclaim_return_type([chunks = chunks.d_begin()] __device__(int i) { + 0, cuda::proclaim_return_type([chunks = pass.chunks.d_begin()] __device__(int i) { auto c = chunks[i]; return static_cast( max(c.level_bits[level_type::REPETITION], c.level_bits[level_type::DEFINITION])); @@ -394,223 +417,243 @@ int decode_page_headers(cudf::detail::hostdevice_vector& chunks // max level data bit size. int const max_level_bits = thrust::reduce(rmm::exec_policy(stream), level_bit_size, - level_bit_size + chunks.size(), + level_bit_size + pass.chunks.size(), 0, thrust::maximum()); + pass.level_type_size = std::max(1, cudf::util::div_rounding_up_safe(max_level_bits, 8)); - return std::max(1, cudf::util::div_rounding_up_safe(max_level_bits, 8)); -} + // sort the pages in chunk/schema order. we use chunk.src_col_index instead of + // chunk.src_col_schema because the user may have reordered them (reading columns, "a" and "b" but + // returning them as "b" and "a") + // + // ordering of pages is by input column schema, repeated across row groups. so + // if we had 3 columns, each with 2 pages, and 1 row group, our schema values might look like + // + // 1, 1, 2, 2, 3, 3 + // + // However, if we had more than one row group, the pattern would be + // + // 1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3 + // ^ row group 0 | + // ^ row group 1 + // + // To process pages by key (exclusive_scan_by_key, reduce_by_key, etc), the ordering we actually + // want is + // + // 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 + // + // We also need to preserve key-relative page ordering, so we need to use a stable sort. 
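The block that follows performs this reordering on the device by sorting indices rather than the page structs themselves; the same idea in plain host C++ (illustrative only), using the key pattern from the comment above:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
  // two row groups, three columns, two pages each: keys arrive in row-group-major order
  std::vector<int> const page_keys{1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3};
  std::vector<int> order(page_keys.size());
  std::iota(order.begin(), order.end(), 0);
  // a stable sort of indices groups each column's pages together while preserving
  // their relative (key-relative) order within the column
  std::stable_sort(
    order.begin(), order.end(), [&](int a, int b) { return page_keys[a] < page_keys[b]; });
  for (int i : order) { std::printf("%d ", page_keys[i]); }  // 1 1 1 1 2 2 2 2 3 3 3 3
  std::printf("\n");
  return 0;
}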
+ { + rmm::device_uvector page_keys{unsorted_pages.size(), stream}; + thrust::transform(rmm::exec_policy_nosync(stream), + unsorted_pages.begin(), + unsorted_pages.end(), + page_keys.begin(), + [chunks = pass.chunks.d_begin()] __device__(PageInfo const& page) { + return chunks[page.chunk_idx].src_col_index; + }); + // we are doing this by sorting indices first and then transforming the output because nvcc + // started generating kernels using too much shared memory when trying to sort the pages + // directly. + rmm::device_uvector sort_indices(unsorted_pages.size(), stream); + thrust::sequence(rmm::exec_policy_nosync(stream), sort_indices.begin(), sort_indices.end(), 0); + thrust::stable_sort_by_key(rmm::exec_policy_nosync(stream), + page_keys.begin(), + page_keys.end(), + sort_indices.begin(), + thrust::less()); + pass.pages = cudf::detail::hostdevice_vector( + unsorted_pages.size(), unsorted_pages.size(), stream); + thrust::transform(rmm::exec_policy_nosync(stream), + sort_indices.begin(), + sort_indices.end(), + pass.pages.d_begin(), + [unsorted_pages = unsorted_pages.begin()] __device__(int32_t i) { + return unsorted_pages[i]; + }); + } -/** - * @brief Decompresses the page data, at page granularity. - * - * @param chunks List of column chunk descriptors - * @param pages List of page information - * @param stream CUDA stream used for device memory operations and kernel launches - * - * @return Device buffer to decompressed page data - */ -[[nodiscard]] rmm::device_buffer decompress_page_data( - cudf::detail::hostdevice_vector& chunks, - cudf::detail::hostdevice_vector& pages, - rmm::cuda_stream_view stream) -{ - auto for_each_codec_page = [&](Compression codec, std::function const& f) { - for (size_t c = 0, page_count = 0; c < chunks.size(); c++) { - const auto page_stride = chunks[c].max_num_pages; - if (chunks[c].codec == codec) { - for (int k = 0; k < page_stride; k++) { - f(page_count + k); - } - } - page_count += page_stride; - } - }; + // compute offsets to each group of input pages. 
+ // page_keys: 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 + // + // result: 0, 4, 8 + rmm::device_uvector page_counts(pass.pages.size() + 1, stream); + auto page_keys = make_page_key_iterator(pass.pages); + auto const page_counts_end = thrust::reduce_by_key(rmm::exec_policy(stream), + page_keys, + page_keys + pass.pages.size(), + thrust::make_constant_iterator(1), + thrust::make_discard_iterator(), + page_counts.begin()) + .second; + auto const num_page_counts = page_counts_end - page_counts.begin(); + pass.page_offsets = cudf::detail::hostdevice_vector(num_page_counts + 1, stream); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + page_counts.begin(), + page_counts.begin() + num_page_counts + 1, + pass.page_offsets.d_begin()); + + // setup dict_page for each chunk if necessary + thrust::for_each(rmm::exec_policy_nosync(stream), + pass.pages.d_begin(), + pass.pages.d_end(), + [chunks = pass.chunks.d_begin()] __device__(PageInfo const& p) { + if (p.flags & PAGEINFO_FLAGS_DICTIONARY) { + chunks[p.chunk_idx].dict_page = &p; + } + }); - // Brotli scratch memory for decompressing - rmm::device_buffer debrotli_scratch; + pass.page_offsets.device_to_host_async(stream); + pass.pages.device_to_host_async(stream); + pass.chunks.device_to_host_async(stream); + stream.synchronize(); +} - // Count the exact number of compressed pages - size_t num_comp_pages = 0; - size_t total_decomp_size = 0; +struct set_str_dict_index_count { + device_span str_dict_index_count; + device_span chunks; - struct codec_stats { - Compression compression_type = UNCOMPRESSED; - size_t num_pages = 0; - int32_t max_decompressed_size = 0; - size_t total_decomp_size = 0; - }; + __device__ void operator()(PageInfo const& page) + { + auto const& chunk = chunks[page.chunk_idx]; + if ((page.flags & PAGEINFO_FLAGS_DICTIONARY) && (chunk.data_type & 0x7) == BYTE_ARRAY && + (chunk.num_dict_pages > 0)) { + // there is only ever one dictionary page per chunk, so this is safe to do in parallel. + str_dict_index_count[page.chunk_idx] = page.num_input_values; + } + } +}; - std::array codecs{codec_stats{GZIP}, codec_stats{SNAPPY}, codec_stats{BROTLI}, codec_stats{ZSTD}}; +struct set_str_dict_index_ptr { + string_index_pair* const base; + device_span str_dict_index_offsets; + device_span chunks; - auto is_codec_supported = [&codecs](int8_t codec) { - if (codec == UNCOMPRESSED) return true; - return std::find_if(codecs.begin(), codecs.end(), [codec](auto& cstats) { - return codec == cstats.compression_type; - }) != codecs.end(); - }; - CUDF_EXPECTS(std::all_of(chunks.begin(), - chunks.end(), - [&is_codec_supported](auto const& chunk) { - return is_codec_supported(chunk.codec); - }), - "Unsupported compression type"); - - for (auto& codec : codecs) { - for_each_codec_page(codec.compression_type, [&](size_t page) { - auto page_uncomp_size = pages[page].uncompressed_page_size; - total_decomp_size += page_uncomp_size; - codec.total_decomp_size += page_uncomp_size; - codec.max_decompressed_size = std::max(codec.max_decompressed_size, page_uncomp_size); - codec.num_pages++; - num_comp_pages++; - }); - if (codec.compression_type == BROTLI && codec.num_pages > 0) { - debrotli_scratch.resize(get_gpu_debrotli_scratch_size(codec.num_pages), stream); + __device__ void operator()(size_t i) + { + auto& chunk = chunks[i]; + if ((chunk.data_type & 0x7) == BYTE_ARRAY && (chunk.num_dict_pages > 0)) { + chunk.str_dict_index = base + str_dict_index_offsets[i]; } } +}; - // Dispatch batches of pages to decompress for each codec. 
- // Buffer needs to be padded, required by `gpuDecodePageData`. - rmm::device_buffer decomp_pages( - cudf::util::round_up_safe(total_decomp_size, BUFFER_PADDING_MULTIPLE), stream); - - std::vector> comp_in; - comp_in.reserve(num_comp_pages); - std::vector> comp_out; - comp_out.reserve(num_comp_pages); - - // vectors to save v2 def and rep level data, if any - std::vector> copy_in; - copy_in.reserve(num_comp_pages); - std::vector> copy_out; - copy_out.reserve(num_comp_pages); - - rmm::device_uvector comp_res(num_comp_pages, stream); - thrust::fill(rmm::exec_policy(stream), - comp_res.begin(), - comp_res.end(), - compression_result{0, compression_status::FAILURE}); - - size_t decomp_offset = 0; - int32_t start_pos = 0; - for (auto const& codec : codecs) { - if (codec.num_pages == 0) { continue; } - - for_each_codec_page(codec.compression_type, [&](size_t page_idx) { - auto const dst_base = static_cast(decomp_pages.data()) + decomp_offset; - auto& page = pages[page_idx]; - // offset will only be non-zero for V2 pages - auto const offset = - page.lvl_bytes[level_type::DEFINITION] + page.lvl_bytes[level_type::REPETITION]; - // for V2 need to copy def and rep level info into place, and then offset the - // input and output buffers. otherwise we'd have to keep both the compressed - // and decompressed data. - if (offset != 0) { - copy_in.emplace_back(page.page_data, offset); - copy_out.emplace_back(dst_base, offset); - } - comp_in.emplace_back(page.page_data + offset, - static_cast(page.compressed_page_size - offset)); - comp_out.emplace_back(dst_base + offset, - static_cast(page.uncompressed_page_size - offset)); - page.page_data = dst_base; - decomp_offset += page.uncompressed_page_size; - }); +/** + * @brief Functor which computes an estimated row count for list pages. 
+ * + */ +struct set_list_row_count_estimate { + device_span chunks; - host_span const> comp_in_view{comp_in.data() + start_pos, - codec.num_pages}; - auto const d_comp_in = cudf::detail::make_device_uvector_async( - comp_in_view, stream, rmm::mr::get_current_device_resource()); - host_span const> comp_out_view(comp_out.data() + start_pos, - codec.num_pages); - auto const d_comp_out = cudf::detail::make_device_uvector_async( - comp_out_view, stream, rmm::mr::get_current_device_resource()); - device_span d_comp_res_view(comp_res.data() + start_pos, codec.num_pages); - - switch (codec.compression_type) { - case GZIP: - gpuinflate(d_comp_in, d_comp_out, d_comp_res_view, gzip_header_included::YES, stream); - break; - case SNAPPY: - if (cudf::io::detail::nvcomp_integration::is_stable_enabled()) { - nvcomp::batched_decompress(nvcomp::compression_type::SNAPPY, - d_comp_in, - d_comp_out, - d_comp_res_view, - codec.max_decompressed_size, - codec.total_decomp_size, - stream); - } else { - gpu_unsnap(d_comp_in, d_comp_out, d_comp_res_view, stream); - } - break; - case ZSTD: - nvcomp::batched_decompress(nvcomp::compression_type::ZSTD, - d_comp_in, - d_comp_out, - d_comp_res_view, - codec.max_decompressed_size, - codec.total_decomp_size, - stream); - break; - case BROTLI: - gpu_debrotli(d_comp_in, - d_comp_out, - d_comp_res_view, - debrotli_scratch.data(), - debrotli_scratch.size(), - stream); - break; - default: CUDF_FAIL("Unexpected decompression dispatch"); break; - } - start_pos += codec.num_pages; + __device__ void operator()(PageInfo& page) + { + if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { return; } + auto const& chunk = chunks[page.chunk_idx]; + auto const is_list = chunk.max_level[level_type::REPETITION] > 0; + if (!is_list) { return; } + + // For LIST pages that we have not yet decoded, page.num_rows is not an accurate number. + // so we instead estimate the number of rows as follows: + // - each chunk stores an estimated number of bytes per row E + // - estimate number of rows in a page = page.uncompressed_page_size / E + // + // it is not required that this number is accurate. we just want it to be somewhat close so that + // we get reasonable results as we choose subpass splits. + // + // all other columns can use page.num_rows directly as it will be accurate. + page.num_rows = static_cast(static_cast(page.uncompressed_page_size) / + chunk.list_bytes_per_row_est); } +}; + +/** + * @brief Set the expected row count on the final page for all columns. 
+ * + */ +struct set_final_row_count { + device_span pages; + device_span chunks; + device_span page_offsets; + size_t const max_row; - CUDF_EXPECTS(thrust::all_of(rmm::exec_policy(stream), - comp_res.begin(), - comp_res.end(), - [] __device__(auto const& res) { - return res.status == compression_status::SUCCESS; - }), - "Error during decompression"); - - // now copy the uncompressed V2 def and rep level data - if (not copy_in.empty()) { - auto const d_copy_in = cudf::detail::make_device_uvector_async( - copy_in, stream, rmm::mr::get_current_device_resource()); - auto const d_copy_out = cudf::detail::make_device_uvector_async( - copy_out, stream, rmm::mr::get_current_device_resource()); - - gpu_copy_uncompressed_blocks(d_copy_in, d_copy_out, stream); - stream.synchronize(); + __device__ void operator()(size_t i) + { + auto const last_page_index = page_offsets[i + 1] - 1; + auto const& page = pages[last_page_index]; + auto const& chunk = chunks[page.chunk_idx]; + size_t const page_start_row = chunk.start_row + page.chunk_row; + pages[last_page_index].num_rows = max_row - page_start_row; } +}; - // Update the page information in device memory with the updated value of - // page_data; it now points to the uncompressed data buffer - pages.host_to_device_async(stream); +} // anonymous namespace - return decomp_pages; +void reader::impl::build_string_dict_indices() +{ + auto& pass = *_pass_itm_data; + + // compute number of indices per chunk and a summed total + rmm::device_uvector str_dict_index_count(pass.chunks.size() + 1, _stream); + thrust::fill( + rmm::exec_policy_nosync(_stream), str_dict_index_count.begin(), str_dict_index_count.end(), 0); + thrust::for_each(rmm::exec_policy_nosync(_stream), + pass.pages.begin(), + pass.pages.end(), + set_str_dict_index_count{str_dict_index_count, pass.chunks}); + + size_t const total_str_dict_indexes = thrust::reduce( + rmm::exec_policy(_stream), str_dict_index_count.begin(), str_dict_index_count.end()); + if (total_str_dict_indexes == 0) { return; } + + // convert to offsets + rmm::device_uvector& str_dict_index_offsets = str_dict_index_count; + thrust::exclusive_scan(rmm::exec_policy_nosync(_stream), + str_dict_index_offsets.begin(), + str_dict_index_offsets.end(), + str_dict_index_offsets.begin(), + 0); + + // allocate and distribute pointers + pass.str_dict_index = cudf::detail::make_zeroed_device_uvector_async( + total_str_dict_indexes, _stream, rmm::mr::get_current_device_resource()); + + auto iter = thrust::make_counting_iterator(0); + thrust::for_each( + rmm::exec_policy_nosync(_stream), + iter, + iter + pass.chunks.size(), + set_str_dict_index_ptr{pass.str_dict_index.data(), str_dict_index_offsets, pass.chunks}); + + // compute the indices + BuildStringDictionaryIndex(pass.chunks.device_ptr(), pass.chunks.size(), _stream); + pass.chunks.device_to_host_sync(_stream); } -} // namespace - void reader::impl::allocate_nesting_info() { - auto const& chunks = _pass_itm_data->chunks; - auto& pages = _pass_itm_data->pages_info; - auto& page_nesting_info = _pass_itm_data->page_nesting_info; - auto& page_nesting_decode_info = _pass_itm_data->page_nesting_decode_info; + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; + + auto const num_columns = _input_columns.size(); + auto& pages = subpass.pages; + auto& page_nesting_info = subpass.page_nesting_info; + auto& page_nesting_decode_info = subpass.page_nesting_decode_info; + + // generate the number of nesting info structs needed per-page, by column + std::vector 
per_page_nesting_info_size(num_columns); + auto iter = thrust::make_counting_iterator(size_type{0}); + std::transform(iter, iter + num_columns, per_page_nesting_info_size.begin(), [&](size_type i) { + auto const schema_idx = _input_columns[i].schema_idx; + auto const& schema = _metadata->get_schema(schema_idx); + return max(schema.max_definition_level + 1, _metadata->get_output_nesting_depth(schema_idx)); + }); // compute total # of page_nesting infos needed and allocate space. doing this in one // buffer to keep it to a single gpu allocation - size_t const total_page_nesting_infos = std::accumulate( - chunks.host_ptr(), chunks.host_ptr() + chunks.size(), 0, [&](int total, auto& chunk) { - // the schema of the input column - auto const& schema = _metadata->get_schema(chunk.src_col_schema); - auto const per_page_nesting_info_size = max( - schema.max_definition_level + 1, _metadata->get_output_nesting_depth(chunk.src_col_schema)); - return total + (per_page_nesting_info_size * chunk.num_data_pages); + auto counting_iter = thrust::make_counting_iterator(size_t{0}); + size_t const total_page_nesting_infos = + std::accumulate(counting_iter, counting_iter + num_columns, 0, [&](int total, size_t index) { + return total + (per_page_nesting_info_size[index] * subpass.column_page_count[index]); }); page_nesting_info = @@ -621,41 +664,33 @@ void reader::impl::allocate_nesting_info() // update pointers in the PageInfos int target_page_index = 0; int src_info_index = 0; - for (size_t idx = 0; idx < chunks.size(); idx++) { - int src_col_schema = chunks[idx].src_col_schema; - auto& schema = _metadata->get_schema(src_col_schema); - auto const per_page_nesting_info_size = std::max( - schema.max_definition_level + 1, _metadata->get_output_nesting_depth(src_col_schema)); - - // skip my dict pages - target_page_index += chunks[idx].num_dict_pages; - for (int p_idx = 0; p_idx < chunks[idx].num_data_pages; p_idx++) { + for (size_t idx = 0; idx < _input_columns.size(); idx++) { + auto const src_col_schema = _input_columns[idx].schema_idx; + + for (size_t p_idx = 0; p_idx < subpass.column_page_count[idx]; p_idx++) { pages[target_page_index + p_idx].nesting = page_nesting_info.device_ptr() + src_info_index; pages[target_page_index + p_idx].nesting_decode = page_nesting_decode_info.device_ptr() + src_info_index; - pages[target_page_index + p_idx].nesting_info_size = per_page_nesting_info_size; + pages[target_page_index + p_idx].nesting_info_size = per_page_nesting_info_size[idx]; pages[target_page_index + p_idx].num_output_nesting_levels = _metadata->get_output_nesting_depth(src_col_schema); - src_info_index += per_page_nesting_info_size; + src_info_index += per_page_nesting_info_size[idx]; } - target_page_index += chunks[idx].num_data_pages; + target_page_index += subpass.column_page_count[idx]; } // fill in int nesting_info_index = 0; std::map, std::vector>> depth_remapping; - for (size_t idx = 0; idx < chunks.size(); idx++) { - int src_col_schema = chunks[idx].src_col_schema; + for (size_t idx = 0; idx < _input_columns.size(); idx++) { + auto const src_col_schema = _input_columns[idx].schema_idx; // schema of the input column auto& schema = _metadata->get_schema(src_col_schema); // real depth of the output cudf column hierarchy (1 == no nesting, 2 == 1 level, etc) - int max_depth = _metadata->get_output_nesting_depth(src_col_schema); - - // # of nesting infos stored per page for this column - auto const per_page_nesting_info_size = std::max(schema.max_definition_level + 1, max_depth); + int const 
max_output_depth = _metadata->get_output_nesting_depth(src_col_schema); // if this column has lists, generate depth remapping std::map, std::vector>> depth_remapping; @@ -666,18 +701,19 @@ void reader::impl::allocate_nesting_info() // fill in host-side nesting info int schema_idx = src_col_schema; auto cur_schema = _metadata->get_schema(schema_idx); - int cur_depth = max_depth - 1; + int cur_depth = max_output_depth - 1; while (schema_idx > 0) { - // stub columns (basically the inner field of a list scheme element) are not real columns. + // stub columns (basically the inner field of a list schema element) are not real columns. // we can ignore them for the purposes of output nesting info if (!cur_schema.is_stub()) { // initialize each page within the chunk - for (int p_idx = 0; p_idx < chunks[idx].num_data_pages; p_idx++) { + for (size_t p_idx = 0; p_idx < subpass.column_page_count[idx]; p_idx++) { PageNestingInfo* pni = - &page_nesting_info[nesting_info_index + (p_idx * per_page_nesting_info_size)]; + &page_nesting_info[nesting_info_index + (p_idx * per_page_nesting_info_size[idx])]; PageNestingDecodeInfo* nesting_info = - &page_nesting_decode_info[nesting_info_index + (p_idx * per_page_nesting_info_size)]; + &page_nesting_decode_info[nesting_info_index + + (p_idx * per_page_nesting_info_size[idx])]; // if we have lists, set our start and end depth remappings if (schema.max_repetition_level > 0) { @@ -712,7 +748,7 @@ void reader::impl::allocate_nesting_info() cur_schema = _metadata->get_schema(schema_idx); } - nesting_info_index += (per_page_nesting_info_size * chunks[idx].num_data_pages); + nesting_info_index += (per_page_nesting_info_size[idx] * subpass.column_page_count[idx]); } // copy nesting info to the device @@ -722,32 +758,33 @@ void reader::impl::allocate_nesting_info() void reader::impl::allocate_level_decode_space() { - auto& pages = _pass_itm_data->pages_info; + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; + + auto& pages = subpass.pages; // TODO: this could be made smaller if we ignored dictionary pages and pages with no // repetition data. 
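  // For scale (illustrative numbers only, not taken from the code): with a hypothetical
  // LEVEL_DECODE_BUF_SIZE of 4096 level values, 2-byte levels and 10,000 pages, the allocation
  // below comes to 4096 * 2 (def + rep) * 2 bytes = 16 KiB per page, or roughly 160 MB for the
  // whole set of pages, including dictionary pages and pages of columns with no repetition data.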
- size_t const per_page_decode_buf_size = - LEVEL_DECODE_BUF_SIZE * 2 * _pass_itm_data->level_type_size; - auto const decode_buf_size = per_page_decode_buf_size * pages.size(); - _pass_itm_data->level_decode_data = + size_t const per_page_decode_buf_size = LEVEL_DECODE_BUF_SIZE * 2 * pass.level_type_size; + auto const decode_buf_size = per_page_decode_buf_size * pages.size(); + subpass.level_decode_data = rmm::device_buffer(decode_buf_size, _stream, rmm::mr::get_current_device_resource()); // distribute the buffers - uint8_t* buf = static_cast(_pass_itm_data->level_decode_data.data()); + uint8_t* buf = static_cast(subpass.level_decode_data.data()); for (size_t idx = 0; idx < pages.size(); idx++) { auto& p = pages[idx]; p.lvl_decode_buf[level_type::DEFINITION] = buf; - buf += (LEVEL_DECODE_BUF_SIZE * _pass_itm_data->level_type_size); + buf += (LEVEL_DECODE_BUF_SIZE * pass.level_type_size); p.lvl_decode_buf[level_type::REPETITION] = buf; - buf += (LEVEL_DECODE_BUF_SIZE * _pass_itm_data->level_type_size); + buf += (LEVEL_DECODE_BUF_SIZE * pass.level_type_size); } } -std::pair>> reader::impl::read_and_decompress_column_chunks() +std::pair>> reader::impl::read_column_chunks() { auto const& row_groups_info = _pass_itm_data->row_groups; - auto const num_rows = _pass_itm_data->num_rows; auto& raw_page_data = _pass_itm_data->raw_page_data; auto& chunks = _pass_itm_data->chunks; @@ -767,13 +804,14 @@ std::pair>> reader::impl::read_and_decompres // Initialize column chunk information size_t total_decompressed_size = 0; - auto remaining_rows = num_rows; + // TODO: make this respect the pass-wide skip_rows/num_rows instead of the file-wide + // skip_rows/num_rows + // auto remaining_rows = num_rows; std::vector> read_chunk_tasks; size_type chunk_count = 0; for (auto const& rg : row_groups_info) { auto const& row_group = _metadata->get_row_group(rg.index, rg.source_index); auto const row_group_source = rg.source_index; - auto const row_group_rows = std::min(remaining_rows, row_group.num_rows); // generate ColumnChunkDesc objects for everything to be decoded (all input columns) for (size_t i = 0; i < num_input_columns; ++i) { @@ -795,7 +833,6 @@ std::pair>> reader::impl::read_and_decompres chunk_count++; } - remaining_rows -= row_group_rows; } // Read compressed chunk data to device memory @@ -808,22 +845,20 @@ std::pair>> reader::impl::read_and_decompres chunk_source_map, _stream)); - CUDF_EXPECTS(remaining_rows == 0, "All rows data must be read."); - return {total_decompressed_size > 0, std::move(read_chunk_tasks)}; } -void reader::impl::load_and_decompress_data() +void reader::impl::read_compressed_data() { + auto& pass = *_pass_itm_data; + // This function should never be called if `num_rows == 0`. 
CUDF_EXPECTS(_pass_itm_data->num_rows > 0, "Number of reading rows must not be zero."); - auto& raw_page_data = _pass_itm_data->raw_page_data; - auto& decomp_page_data = _pass_itm_data->decomp_page_data; - auto& chunks = _pass_itm_data->chunks; - auto& pages = _pass_itm_data->pages_info; + auto& chunks = pass.chunks; - auto const [has_compressed_data, read_chunks_tasks] = read_and_decompress_column_chunks(); + auto const [has_compressed_data, read_chunks_tasks] = read_column_chunks(); + pass.has_compressed_data = has_compressed_data; for (auto& task : read_chunks_tasks) { task.wait(); @@ -832,44 +867,12 @@ void reader::impl::load_and_decompress_data() // Process dataset chunk pages into output columns auto const total_pages = count_page_headers(chunks, _stream); if (total_pages <= 0) { return; } - pages = cudf::detail::hostdevice_vector(total_pages, total_pages, _stream); + rmm::device_uvector unsorted_pages(total_pages, _stream); // decoding of column/page information - _pass_itm_data->level_type_size = decode_page_headers(chunks, pages, _stream); - pages.device_to_host_sync(_stream); - if (has_compressed_data) { - decomp_page_data = decompress_page_data(chunks, pages, _stream); - // Free compressed data - for (size_t c = 0; c < chunks.size(); c++) { - if (chunks[c].codec != Compression::UNCOMPRESSED) { raw_page_data[c].reset(); } - } - } - - // build output column info - // walk the schema, building out_buffers that mirror what our final cudf columns will look - // like. important : there is not necessarily a 1:1 mapping between input columns and output - // columns. For example, parquet does not explicitly store a ColumnChunkDesc for struct - // columns. The "structiness" is simply implied by the schema. For example, this schema: - // required group field_id=1 name { - // required binary field_id=2 firstname (String); - // required binary field_id=3 middlename (String); - // required binary field_id=4 lastname (String); - // } - // will only contain 3 columns of data (firstname, middlename, lastname). But of course - // "name" is a struct column that we want to return, so we have to make sure that we - // create it ourselves. 
-  // std::vector output_info = build_output_column_info();
-
-  // the following two allocate functions modify the page data
-  {
-    // nesting information (sizes, etc) stored -per page-
-    // note : even for flat schemas, we allocate 1 level of "nesting" info
-    allocate_nesting_info();
-
-    // level decode space
-    allocate_level_decode_space();
-  }
-  pages.host_to_device_async(_stream);
+  decode_page_headers(pass, unsorted_pages, _stream);
+  CUDF_EXPECTS(pass.page_offsets.size() - 1 == static_cast<size_t>(_input_columns.size()),
+               "Encountered page_offsets / num_columns mismatch");
 }

 namespace {

@@ -880,28 +883,6 @@ struct cumulative_row_info {
   int key;  // schema index
 };

-#if defined(PREPROCESS_DEBUG)
-void print_pages(cudf::detail::hostdevice_vector<PageInfo>& pages, rmm::cuda_stream_view _stream)
-{
-  pages.device_to_host_sync(_stream);
-  for (size_t idx = 0; idx < pages.size(); idx++) {
-    auto const& p = pages[idx];
-    // skip dictionary pages
-    if (p.flags & PAGEINFO_FLAGS_DICTIONARY) { continue; }
-    printf(
-      "P(%lu, s:%d): chunk_row(%d), num_rows(%d), skipped_values(%d), skipped_leaf_values(%d), "
-      "str_bytes(%d)\n",
-      idx,
-      p.src_col_schema,
-      p.chunk_row,
-      p.num_rows,
-      p.skipped_values,
-      p.skipped_leaf_values,
-      p.str_bytes);
-  }
-}
-#endif  // PREPROCESS_DEBUG
-
 struct get_page_chunk_idx {
   __device__ size_type operator()(PageInfo const& page) { return page.chunk_idx; }
 };

@@ -910,14 +891,6 @@ struct get_page_num_rows {
   __device__ size_type operator()(PageInfo const& page) { return page.num_rows; }
 };

-struct get_page_column_index {
-  ColumnChunkDesc const* chunks;
-  __device__ size_type operator()(PageInfo const& page)
-  {
-    return chunks[page.chunk_idx].src_col_index;
-  }
-};
-
 struct input_col_info {
   int const schema_idx;
   size_type const nesting_depth;
@@ -950,13 +923,12 @@ struct get_page_nesting_size {
   size_type const max_depth;
   size_t const num_pages;
   PageInfo const* const pages;
-  int const* page_indices;

   __device__ size_type operator()(size_t index) const
   {
     auto const indices = reduction_indices{index, max_depth, num_pages};

-    auto const& page = pages[page_indices[indices.page_idx]];
+    auto const& page = pages[indices.page_idx];
     if (page.src_col_schema != input_cols[indices.col_idx].schema_idx ||
         page.flags & PAGEINFO_FLAGS_DICTIONARY ||
         indices.depth_idx >= input_cols[indices.col_idx].nesting_depth) {
@@ -995,12 +967,14 @@ struct chunk_row_output_iter {
   __device__ reference operator*() { return p->chunk_row; }
 };

 /**
  * @brief Writes to the page_start_value field of the PageNestingInfo struct, keyed by schema.
*/ struct start_offset_output_iterator { PageInfo const* pages; - int const* page_indices; size_t cur_index; input_col_info const* input_cols; size_type max_depth; @@ -1014,17 +988,16 @@ struct start_offset_output_iterator { constexpr void operator=(start_offset_output_iterator const& other) { - pages = other.pages; - page_indices = other.page_indices; - cur_index = other.cur_index; - input_cols = other.input_cols; - max_depth = other.max_depth; - num_pages = other.num_pages; + pages = other.pages; + cur_index = other.cur_index; + input_cols = other.input_cols; + max_depth = other.max_depth; + num_pages = other.num_pages; } constexpr start_offset_output_iterator operator+(size_t i) { - return {pages, page_indices, cur_index + i, input_cols, max_depth, num_pages}; + return start_offset_output_iterator{pages, cur_index + i, input_cols, max_depth, num_pages}; } constexpr start_offset_output_iterator& operator++() @@ -1041,7 +1014,7 @@ struct start_offset_output_iterator { { auto const indices = reduction_indices{index, max_depth, num_pages}; - PageInfo const& p = pages[page_indices[indices.page_idx]]; + PageInfo const& p = pages[indices.page_idx]; if (p.src_col_schema != input_cols[indices.col_idx].schema_idx || p.flags & PAGEINFO_FLAGS_DICTIONARY || indices.depth_idx >= input_cols[indices.col_idx].nesting_depth) { @@ -1051,114 +1024,20 @@ struct start_offset_output_iterator { } }; -struct flat_column_num_rows { - PageInfo const* pages; - ColumnChunkDesc const* chunks; - - __device__ size_type operator()(size_type pindex) const - { - PageInfo const& page = pages[pindex]; - // ignore dictionary pages and pages belonging to any column containing repetition (lists) - if ((page.flags & PAGEINFO_FLAGS_DICTIONARY) || - (chunks[page.chunk_idx].max_level[level_type::REPETITION] > 0)) { - return 0; - } - return page.num_rows; - } -}; - -struct row_counts_nonzero { - __device__ bool operator()(size_type count) const { return count > 0; } -}; - -struct row_counts_different { - size_type const expected; - __device__ bool operator()(size_type count) const { return (count != 0) && (count != expected); } -}; - -/** - * @brief Detect malformed parquet input data. - * - * We have seen cases where parquet files can be oddly malformed. This function specifically - * detects one case in particular: - * - * - When you have a file containing N rows - * - For some reason, the sum total of the number of rows over all pages for a given column - * is != N - * - * @param pages All pages to be decoded - * @param chunks Chunk data - * @param page_keys Keys (schema id) associated with each page, sorted by column - * @param page_index Page indices for iteration, sorted by column - * @param expected_row_count Expected row count, if applicable - * @param stream CUDA stream used for device memory operations and kernel launches - */ -void detect_malformed_pages(cudf::detail::hostdevice_vector& pages, - cudf::detail::hostdevice_vector const& chunks, - device_span page_keys, - device_span page_index, - std::optional expected_row_count, - rmm::cuda_stream_view stream) -{ - // sum row counts for all non-dictionary, non-list columns. 
other columns will be indicated as 0 - rmm::device_uvector row_counts(pages.size(), - stream); // worst case: num keys == num pages - auto const size_iter = thrust::make_transform_iterator( - page_index.begin(), flat_column_num_rows{pages.device_ptr(), chunks.device_ptr()}); - auto const row_counts_begin = row_counts.begin(); - auto const row_counts_end = thrust::reduce_by_key(rmm::exec_policy(stream), - page_keys.begin(), - page_keys.end(), - size_iter, - thrust::make_discard_iterator(), - row_counts_begin) - .second; - - // make sure all non-zero row counts are the same - rmm::device_uvector compacted_row_counts(pages.size(), stream); - auto const compacted_row_counts_begin = compacted_row_counts.begin(); - auto const compacted_row_counts_end = thrust::copy_if(rmm::exec_policy(stream), - row_counts_begin, - row_counts_end, - compacted_row_counts_begin, - row_counts_nonzero{}); - if (compacted_row_counts_end != compacted_row_counts_begin) { - size_t const found_row_count = static_cast(compacted_row_counts.element(0, stream)); - - // if we somehow don't match the expected row count from the row groups themselves - if (expected_row_count.has_value()) { - CUDF_EXPECTS(expected_row_count.value() == found_row_count, - "Encountered malformed parquet page data (unexpected row count in page data)"); - } - - // all non-zero row counts must be the same - auto const chk = - thrust::count_if(rmm::exec_policy(stream), - compacted_row_counts_begin, - compacted_row_counts_end, - row_counts_different{static_cast(found_row_count)}); - CUDF_EXPECTS(chk == 0, - "Encountered malformed parquet page data (row count mismatch in page data)"); - } -} - struct page_to_string_size { - PageInfo* pages; ColumnChunkDesc const* chunks; - __device__ size_t operator()(size_type page_idx) const + __device__ size_t operator()(PageInfo const& page) const { - auto const page = pages[page_idx]; auto const chunk = chunks[page.chunk_idx]; if (not is_string_col(chunk) || (page.flags & PAGEINFO_FLAGS_DICTIONARY) != 0) { return 0; } - return pages[page_idx].str_bytes; + return page.str_bytes; } }; struct page_offset_output_iter { PageInfo* p; - size_type const* index; using value_type = size_type; using difference_type = size_type; @@ -1166,78 +1045,148 @@ struct page_offset_output_iter { using reference = size_type&; using iterator_category = thrust::output_device_iterator_tag; - __host__ __device__ page_offset_output_iter operator+(int i) { return {p, index + i}; } + __host__ __device__ page_offset_output_iter operator+(int i) { return {p + i}; } __host__ __device__ page_offset_output_iter& operator++() { - index++; + p++; return *this; } - __device__ reference operator[](int i) { return p[index[i]].str_offset; } - __device__ reference operator*() { return p[*index].str_offset; } + __device__ reference operator[](int i) { return p[i].str_offset; } + __device__ reference operator*() { return p->str_offset; } }; +// update chunk_row field in subpass page from pass page +struct update_subpass_chunk_row { + device_span pass_pages; + device_span subpass_pages; + device_span page_src_index; -} // anonymous namespace + __device__ void operator()(size_t i) + { + subpass_pages[i].chunk_row = pass_pages[page_src_index[i]].chunk_row; + } +}; -void reader::impl::preprocess_pages(bool uses_custom_row_bounds, size_t chunk_read_limit) -{ - auto const skip_rows = _pass_itm_data->skip_rows; - auto const num_rows = _pass_itm_data->num_rows; - auto& chunks = _pass_itm_data->chunks; - auto& pages = _pass_itm_data->pages_info; +// update num_rows 
field from pass page to subpass page +struct update_pass_num_rows { + device_span pass_pages; + device_span subpass_pages; + device_span page_src_index; - // compute page ordering. - // - // ordering of pages is by input column schema, repeated across row groups. so - // if we had 3 columns, each with 2 pages, and 1 row group, our schema values might look like - // - // 1, 1, 2, 2, 3, 3 - // - // However, if we had more than one row group, the pattern would be - // - // 1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3 - // ^ row group 0 | - // ^ row group 1 - // - // To process pages by key (exclusive_scan_by_key, reduce_by_key, etc), the ordering we actually - // want is - // - // 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 - // - // We also need to preserve key-relative page ordering, so we need to use a stable sort. - rmm::device_uvector page_keys(pages.size(), _stream); - rmm::device_uvector page_index(pages.size(), _stream); + __device__ void operator()(size_t i) { - thrust::transform(rmm::exec_policy(_stream), - pages.device_ptr(), - pages.device_ptr() + pages.size(), - page_keys.begin(), - get_page_column_index{chunks.device_ptr()}); + pass_pages[page_src_index[i]].num_rows = subpass_pages[i].num_rows; + } +}; - thrust::sequence(rmm::exec_policy(_stream), page_index.begin(), page_index.end()); - thrust::stable_sort_by_key(rmm::exec_policy(_stream), - page_keys.begin(), - page_keys.end(), - page_index.begin(), - thrust::less()); +} // anonymous namespace + +void reader::impl::preprocess_file( + int64_t skip_rows, + std::optional const& num_rows, + host_span const> row_group_indices, + std::optional> filter) +{ + CUDF_EXPECTS(!_file_preprocessed, "Attempted to preprocess file more than once"); + + // if filter is not empty, then create output types as vector and pass for filtering. + std::vector output_types; + if (filter.has_value()) { + std::transform(_output_buffers.cbegin(), + _output_buffers.cend(), + std::back_inserter(output_types), + [](auto const& col) { return col.type; }); + } + std::tie( + _file_itm_data.global_skip_rows, _file_itm_data.global_num_rows, _file_itm_data.row_groups) = + _metadata->select_row_groups( + row_group_indices, skip_rows, num_rows, output_types, filter, _stream); + + if (_file_itm_data.global_num_rows > 0 && not _file_itm_data.row_groups.empty() && + not _input_columns.empty()) { + // fills in chunk information without physically loading or decompressing + // the associated data + create_global_chunk_info(); + + // compute schedule of input reads. 
+ compute_input_passes(); + } + +#if defined(PARQUET_CHUNK_LOGGING) + printf("==============================================\n"); + setlocale(LC_NUMERIC, ""); + printf("File: skip_rows(%'lu), num_rows(%'lu), input_read_limit(%'lu), output_read_limit(%'lu)\n", + _file_itm_data.global_skip_rows, + _file_itm_data.global_num_rows, + _input_pass_read_limit, + _output_chunk_read_limit); + printf("# Row groups: %'lu\n", _file_itm_data.row_groups.size()); + printf("# Input passes: %'lu\n", _file_itm_data.num_passes()); + printf("# Input columns: %'lu\n", _input_columns.size()); + for (size_t idx = 0; idx < _input_columns.size(); idx++) { + auto const& schema = _metadata->get_schema(_input_columns[idx].schema_idx); + auto const type_id = to_type_id(schema, _strings_to_categorical, _timestamp_type.id()); + printf("\tC(%'lu, %s): %s\n", + idx, + _input_columns[idx].name.c_str(), + cudf::type_to_name(cudf::data_type{type_id}).c_str()); + } + printf("# Output columns: %'lu\n", _output_buffers.size()); + for (size_t idx = 0; idx < _output_buffers.size(); idx++) { + printf("\tC(%'lu): %s\n", idx, cudf::io::detail::type_to_name(_output_buffers[idx]).c_str()); } +#endif + + _file_preprocessed = true; +} + +void reader::impl::generate_list_column_row_count_estimates() +{ + auto& pass = *_pass_itm_data; + thrust::for_each(rmm::exec_policy(_stream), + pass.pages.d_begin(), + pass.pages.d_end(), + set_list_row_count_estimate{pass.chunks}); + + // computes: + // PageInfo::chunk_row (the chunk-relative row index) for all pages in the pass. The start_row + // field in ColumnChunkDesc is the absolute row index for the whole file. chunk_row in PageInfo is + // relative to the beginning of the chunk. so in the kernels, chunk.start_row + page.chunk_row + // gives us the absolute row index + auto key_input = thrust::make_transform_iterator(pass.pages.d_begin(), get_page_chunk_idx{}); + auto page_input = thrust::make_transform_iterator(pass.pages.d_begin(), get_page_num_rows{}); + thrust::exclusive_scan_by_key(rmm::exec_policy_nosync(_stream), + key_input, + key_input + pass.pages.size(), + page_input, + chunk_row_output_iter{pass.pages.device_ptr()}); + + // finally, fudge the last page for each column such that it ends on the real known row count + // for the pass. this is so that as we march through the subpasses, we will find that every column + // cleanly ends up the expected row count at the row group boundary. + auto const& last_chunk = pass.chunks[pass.chunks.size() - 1]; + auto const num_columns = _input_columns.size(); + size_t const max_row = last_chunk.start_row + last_chunk.num_rows; + auto iter = thrust::make_counting_iterator(0); + thrust::for_each(rmm::exec_policy_nosync(_stream), + iter, + iter + num_columns, + set_final_row_count{pass.pages, pass.chunks, pass.page_offsets, max_row}); + + pass.chunks.device_to_host_async(_stream); + pass.pages.device_to_host_async(_stream); + _stream.synchronize(); +} + +void reader::impl::preprocess_subpass_pages(bool uses_custom_row_bounds, size_t chunk_read_limit) +{ + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; - // detect malformed columns. - // - we have seen some cases in the wild where we have a row group containing N - // rows, but the total number of rows in the pages for column X is != N. while it - // is possible to load this by just capping the number of rows read, we cannot tell - // which rows are invalid so we may be returning bad data. 
in addition, this mismatch - // confuses the chunked reader - detect_malformed_pages(pages, - chunks, - page_keys, - page_index, - uses_custom_row_bounds ? std::nullopt : std::make_optional(num_rows), - _stream); - - // iterate over all input columns and determine if they contain lists so we can further - // preprocess them. + // iterate over all input columns and determine if they contain lists. + // TODO: we could do this once at the file level instead of every time we get in here. the set of + // columns we are processing does not change over multiple passes/subpasses/output chunks. bool has_lists = false; for (size_t idx = 0; idx < _input_columns.size(); idx++) { auto const& input_col = _input_columns[idx]; @@ -1258,49 +1207,9 @@ void reader::impl::preprocess_pages(bool uses_custom_row_bounds, size_t chunk_re if (has_lists) { break; } } - // generate string dict indices if necessary - { - auto is_dict_chunk = [](ColumnChunkDesc const& chunk) { - return (chunk.data_type & 0x7) == BYTE_ARRAY && chunk.num_dict_pages > 0; - }; - - // Count the number of string dictionary entries - // NOTE: Assumes first page in the chunk is always the dictionary page - size_t total_str_dict_indexes = 0; - for (size_t c = 0, page_count = 0; c < chunks.size(); c++) { - if (is_dict_chunk(chunks[c])) { - total_str_dict_indexes += pages[page_count].num_input_values; - } - page_count += chunks[c].max_num_pages; - } - - // Build index for string dictionaries since they can't be indexed - // directly due to variable-sized elements - _pass_itm_data->str_dict_index = - cudf::detail::make_zeroed_device_uvector_async( - total_str_dict_indexes, _stream, rmm::mr::get_current_device_resource()); - - // Update chunks with pointers to string dict indices - for (size_t c = 0, page_count = 0, str_ofs = 0; c < chunks.size(); c++) { - input_column_info const& input_col = _input_columns[chunks[c].src_col_index]; - CUDF_EXPECTS(input_col.schema_idx == chunks[c].src_col_schema, - "Column/page schema index mismatch"); - if (is_dict_chunk(chunks[c])) { - chunks[c].str_dict_index = _pass_itm_data->str_dict_index.data() + str_ofs; - str_ofs += pages[page_count].num_input_values; - } - - // column_data_base will always point to leaf data, even for nested types. - page_count += chunks[c].max_num_pages; - } - - if (total_str_dict_indexes > 0) { - chunks.host_to_device_async(_stream); - BuildStringDictionaryIndex(chunks.device_ptr(), chunks.size(), _stream); - } - } - - // intermediate data we will need for further chunked reads + // in some cases we will need to do further preprocessing of pages. + // - if we have lists, the num_rows field in PageInfo will be incorrect coming out of the file + // - if we are doing a chunked read, we need to compute the size of all string data if (has_lists || chunk_read_limit > 0) { // computes: // PageNestingInfo::num_rows for each page. the true number of rows (taking repetition into @@ -1311,48 +1220,92 @@ void reader::impl::preprocess_pages(bool uses_custom_row_bounds, size_t chunk_re // if: // - user has passed custom row bounds // - we will be doing a chunked read - ComputePageSizes(pages, - chunks, + ComputePageSizes(subpass.pages, + pass.chunks, 0, // 0-max size_t. process all possible rows std::numeric_limits::max(), true, // compute num_rows chunk_read_limit > 0, // compute string sizes _pass_itm_data->level_type_size, _stream); + } - // computes: - // PageInfo::chunk_row (the absolute start row index) for all pages - // Note: this is doing some redundant work for pages in flat hierarchies. 
chunk_row has already
-    // been computed during header decoding. the overall amount of work here is very small though.
-    auto key_input  = thrust::make_transform_iterator(pages.device_ptr(), get_page_chunk_idx{});
-    auto page_input = thrust::make_transform_iterator(pages.device_ptr(), get_page_num_rows{});
-    thrust::exclusive_scan_by_key(rmm::exec_policy(_stream),
-                                  key_input,
-                                  key_input + pages.size(),
-                                  page_input,
-                                  chunk_row_output_iter{pages.device_ptr()});
-
-    // retrieve pages back
-    pages.device_to_host_sync(_stream);
+  // copy our now-correct row counts back to the base pages stored in the pass.
+  auto iter = thrust::make_counting_iterator(0);
+  thrust::for_each(rmm::exec_policy_nosync(_stream),
+                   iter,
+                   iter + subpass.pages.size(),
+                   update_pass_num_rows{pass.pages, subpass.pages, subpass.page_src_index});

-    // print_pages(pages, _stream);
-  }
+  // computes:
+  // PageInfo::chunk_row (the chunk-relative row index) for all pages in the pass. The start_row
+  // field in ColumnChunkDesc is the absolute row index for the whole file. chunk_row in PageInfo is
+  // relative to the beginning of the chunk. so in the kernels, chunk.start_row + page.chunk_row
+  // gives us the absolute row index
+  auto key_input  = thrust::make_transform_iterator(pass.pages.d_begin(), get_page_chunk_idx{});
+  auto page_input = thrust::make_transform_iterator(pass.pages.d_begin(), get_page_num_rows{});
+  thrust::exclusive_scan_by_key(rmm::exec_policy_nosync(_stream),
+                                key_input,
+                                key_input + pass.pages.size(),
+                                page_input,
+                                chunk_row_output_iter{pass.pages.device_ptr()});
+
+  // copy chunk row into the subpass pages
+  thrust::for_each(rmm::exec_policy_nosync(_stream),
+                   iter,
+                   iter + subpass.pages.size(),
+                   update_subpass_chunk_row{pass.pages, subpass.pages, subpass.page_src_index});
+
+  // retrieve pages back
+  pass.pages.device_to_host_async(_stream);
+  subpass.pages.device_to_host_async(_stream);
+  _stream.synchronize();

-  // preserve page ordering data for string decoder
-  _pass_itm_data->page_keys  = std::move(page_keys);
-  _pass_itm_data->page_index = std::move(page_index);
+  // at this point we have an accurate row count so we can compute how many rows we will actually be
+  // able to decode for this pass. we will have selected a set of pages for each column in the
+  // row group, but not every page will have the same number of rows. so, we can only read as many
+  // rows as the smallest batch (by column) we have decompressed.
+  size_t page_index = 0;
+  size_t max_row    = std::numeric_limits<size_t>::max();
+  auto const last_pass_row =
+    _file_itm_data.input_pass_start_row_count[_file_itm_data._current_input_pass + 1];
+  for (size_t idx = 0; idx < subpass.column_page_count.size(); idx++) {
+    auto const& last_page = subpass.pages[page_index + (subpass.column_page_count[idx] - 1)];
+    auto const& chunk     = pass.chunks[last_page.chunk_idx];
+
+    size_t max_col_row =
+      static_cast<size_t>(chunk.start_row + last_page.chunk_row + last_page.num_rows);
+    // special case. list rows can span page boundaries, but we can't tell if that is happening
+    // here because we have not yet decoded the pages. the very last row starting in the page may
+    // not terminate in the page. to handle this, only decode up to the second to last row in the
+    // subpass since we know it will safely complete.
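+    // (illustration: if the subpass pages for column A end at absolute row 1000 but those for
+    //  column B end at row 600, the subpass can only decode 600 rows; if B is a list column and
+    //  600 is not the end of the pass, the bound drops to 599 so the possibly-unterminated last
+    //  row is deferred to the next subpass.)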
+ bool const is_list = chunk.max_level[level_type::REPETITION] > 0; + if (is_list && max_col_row < last_pass_row) { + size_t const min_col_row = static_cast(chunk.start_row + last_page.chunk_row); + CUDF_EXPECTS((max_col_row - min_col_row) > 1, "Unexpected short subpass"); + max_col_row--; + } + + max_row = min(max_row, max_col_row); + + page_index += subpass.column_page_count[idx]; + } + subpass.skip_rows = pass.skip_rows + pass.processed_rows; + auto const pass_end = pass.skip_rows + pass.num_rows; + max_row = min(max_row, pass_end); + subpass.num_rows = max_row - subpass.skip_rows; - // compute splits for the pass - compute_splits_for_pass(); + // now split up the output into chunks as necessary + compute_output_chunks_for_subpass(); } void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses_custom_row_bounds) { - auto const& chunks = _pass_itm_data->chunks; - auto& pages = _pass_itm_data->pages_info; + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; // Should not reach here if there is no page data. - CUDF_EXPECTS(pages.size() > 0, "There is no page to parse"); + CUDF_EXPECTS(subpass.pages.size() > 0, "There are no pages present in the subpass"); // computes: // PageNestingInfo::batch_size for each level of nesting, for each page, taking row bounds into @@ -1360,13 +1313,13 @@ void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses // respect the user bounds. It is only necessary to do this second pass if uses_custom_row_bounds // is set (if the user has specified artificial bounds). if (uses_custom_row_bounds) { - ComputePageSizes(pages, - chunks, + ComputePageSizes(subpass.pages, + pass.chunks, skip_rows, num_rows, false, // num_rows is already computed false, // no need to compute string sizes - _pass_itm_data->level_type_size, + pass.level_type_size, _stream); // print_pages(pages, _stream); @@ -1403,8 +1356,6 @@ void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses // compute output column sizes by examining the pages of the -input- columns if (has_lists) { - auto& page_index = _pass_itm_data->page_index; - std::vector h_cols_info; h_cols_info.reserve(_input_columns.size()); std::transform(_input_columns.cbegin(), @@ -1423,7 +1374,7 @@ void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses auto const d_cols_info = cudf::detail::make_device_uvector_async( h_cols_info, _stream, rmm::mr::get_current_device_resource()); - auto const num_keys = _input_columns.size() * max_depth * pages.size(); + auto const num_keys = _input_columns.size() * max_depth * subpass.pages.size(); // size iterator. 
indexes pages by sorted order rmm::device_uvector size_input{num_keys, _stream}; thrust::transform( @@ -1432,9 +1383,9 @@ void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses thrust::make_counting_iterator(num_keys), size_input.begin(), get_page_nesting_size{ - d_cols_info.data(), max_depth, pages.size(), pages.device_ptr(), page_index.begin()}); + d_cols_info.data(), max_depth, subpass.pages.size(), subpass.pages.d_begin()}); auto const reduction_keys = - cudf::detail::make_counting_transform_iterator(0, get_reduction_key{pages.size()}); + cudf::detail::make_counting_transform_iterator(0, get_reduction_key{subpass.pages.size()}); cudf::detail::hostdevice_vector sizes{_input_columns.size() * max_depth, _stream}; // find the size of each column @@ -1452,7 +1403,7 @@ void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses reduction_keys + num_keys, size_input.cbegin(), start_offset_output_iterator{ - pages.device_ptr(), page_index.begin(), 0, d_cols_info.data(), max_depth, pages.size()}); + subpass.pages.d_begin(), 0, d_cols_info.data(), max_depth, subpass.pages.size()}); sizes.device_to_host_sync(_stream); for (size_type idx = 0; idx < static_cast(_input_columns.size()); idx++) { @@ -1483,30 +1434,30 @@ void reader::impl::allocate_columns(size_t skip_rows, size_t num_rows, bool uses std::vector reader::impl::calculate_page_string_offsets() { - auto& chunks = _pass_itm_data->chunks; - auto& pages = _pass_itm_data->pages_info; - auto const& page_keys = _pass_itm_data->page_keys; - auto const& page_index = _pass_itm_data->page_index; + auto& pass = *_pass_itm_data; + auto& subpass = *pass.subpass; + + auto page_keys = make_page_key_iterator(subpass.pages); std::vector col_sizes(_input_columns.size(), 0L); rmm::device_uvector d_col_sizes(col_sizes.size(), _stream); // use page_index to fetch page string sizes in the proper order - auto val_iter = thrust::make_transform_iterator( - page_index.begin(), page_to_string_size{pages.device_ptr(), chunks.device_ptr()}); + auto val_iter = thrust::make_transform_iterator(subpass.pages.d_begin(), + page_to_string_size{pass.chunks.d_begin()}); // do scan by key to calculate string offsets for each page thrust::exclusive_scan_by_key(rmm::exec_policy_nosync(_stream), - page_keys.begin(), - page_keys.end(), + page_keys, + page_keys + subpass.pages.size(), val_iter, - page_offset_output_iter{pages.device_ptr(), page_index.data()}); + page_offset_output_iter{subpass.pages.device_ptr()}); // now sum up page sizes rmm::device_uvector reduce_keys(col_sizes.size(), _stream); thrust::reduce_by_key(rmm::exec_policy_nosync(_stream), - page_keys.begin(), - page_keys.end(), + page_keys, + page_keys + subpass.pages.size(), val_iter, reduce_keys.begin(), d_col_sizes.begin()); diff --git a/cpp/src/io/utilities/column_buffer.cpp b/cpp/src/io/utilities/column_buffer.cpp index 36303a60aa9..951217dc442 100644 --- a/cpp/src/io/utilities/column_buffer.cpp +++ b/cpp/src/io/utilities/column_buffer.cpp @@ -26,6 +26,9 @@ #include +#include +#include + namespace cudf::io::detail { void gather_column_buffer::allocate_strings_data(rmm::cuda_stream_view stream) @@ -129,6 +132,30 @@ string_policy column_buffer_base::empty_like(string_policy const& return new_buff; } +template +std::string type_to_name(column_buffer_base const& buffer) +{ + if (buffer.type.id() == cudf::type_id::LIST) { + return "List<" + (type_to_name(buffer.children[0])) + ">"; + } + + if (buffer.type.id() == cudf::type_id::STRUCT) { + std::ostringstream out; + + out 
<< "Struct<"; + auto iter = thrust::make_counting_iterator(0); + std::transform( + iter, + iter + buffer.children.size(), + std::ostream_iterator(out, ","), + [&buffer](size_type i) { return type_to_name(buffer.children[i]); }); + out << ">"; + return out.str(); + } + + return cudf::type_to_name(buffer.type); +} + template std::unique_ptr make_column(column_buffer_base& buffer, column_name_info* schema_info, @@ -336,6 +363,10 @@ template std::unique_ptr empty_like(pointer_column_buffer& rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); +template std::string type_to_name(string_column_buffer const& buffer); +template std::string type_to_name(pointer_column_buffer const& buffer); + template class column_buffer_base; template class column_buffer_base; + } // namespace cudf::io::detail diff --git a/cpp/src/io/utilities/column_buffer.hpp b/cpp/src/io/utilities/column_buffer.hpp index 2ee7c17e480..57ee1043ee9 100644 --- a/cpp/src/io/utilities/column_buffer.hpp +++ b/cpp/src/io/utilities/column_buffer.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -253,6 +253,16 @@ std::unique_ptr empty_like(column_buffer_base& buffer, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); +/** + * @brief Given a column_buffer, produce a formatted name string describing the type. + * + * @param buffer The column buffer + * + * @return A string describing the type of the buffer suitable for printing + */ +template +std::string type_to_name(column_buffer_base const& buffer); + } // namespace detail } // namespace io } // namespace cudf diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 24085eb5e10..d40b2410ca3 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -293,7 +293,7 @@ ConfigureTest( ConfigureTest( PARQUET_TEST io/parquet_test.cpp - io/parquet_chunked_reader_test.cpp + io/parquet_chunked_reader_test.cu io/parquet_chunked_writer_test.cpp io/parquet_common.cpp io/parquet_misc_test.cpp diff --git a/cpp/tests/io/parquet_chunked_reader_test.cpp b/cpp/tests/io/parquet_chunked_reader_test.cu similarity index 73% rename from cpp/tests/io/parquet_chunked_reader_test.cpp rename to cpp/tests/io/parquet_chunked_reader_test.cu index 05fb9a3ec48..dea44f0e7c3 100644 --- a/cpp/tests/io/parquet_chunked_reader_test.cpp +++ b/cpp/tests/io/parquet_chunked_reader_test.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ * limitations under the License. 
*/ +#include "parquet_common.hpp" + #include #include #include @@ -44,14 +46,12 @@ #include #include +#include #include #include namespace { -// Global environment for temporary files -auto const temp_env = static_cast( - ::testing::AddGlobalTestEnvironment(new cudf::test::TempDirTestEnvironment)); using int32s_col = cudf::test::fixed_width_column_wrapper; using int64s_col = cudf::test::fixed_width_column_wrapper; @@ -953,64 +953,296 @@ TEST_F(ParquetChunkedReaderTest, TestChunkedReadNullCount) } while (reader.has_next()); } -TEST_F(ParquetChunkedReaderTest, InputLimitSimple) +constexpr size_t input_limit_expected_file_count = 4; + +std::vector input_limit_get_test_names(std::string const& base_filename) { - auto const filepath = temp_env->get_temp_filepath("input_limit_10_rowgroups.parquet"); - - // This results in 10 grow groups, at 4001150 bytes per row group - constexpr int num_rows = 25'000'000; - auto value_iter = cudf::detail::make_counting_transform_iterator(0, [](int i) { return i; }); - cudf::test::fixed_width_column_wrapper expected(value_iter, value_iter + num_rows); - cudf::io::parquet_writer_options opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, - cudf::table_view{{expected}}) - // note: it is unnecessary to force compression to NONE here because the size we are using in - // the row group is the uncompressed data size. But forcing the dictionary policy to - // dictionary_policy::NEVER is necessary to prevent changes in the - // decompressed-but-not-yet-decoded data. - .dictionary_policy(cudf::io::dictionary_policy::NEVER); - - cudf::io::write_parquet(opts); - - { - // no chunking - auto const [result, num_chunks] = chunked_read(filepath, 0, 0); - EXPECT_EQ(num_chunks, 1); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->get_column(0)); - } + return {base_filename + "_a.parquet", + base_filename + "_b.parquet", + base_filename + "_c.parquet", + base_filename + "_d.parquet"}; +} - { - // 25 chunks of 100k rows each - auto const [result, num_chunks] = chunked_read(filepath, 0, 1); - EXPECT_EQ(num_chunks, 25); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->get_column(0)); - } +void input_limit_test_write_one(std::string const& filepath, + cudf::table_view const& t, + cudf::io::compression_type compression, + cudf::io::dictionary_policy dict_policy) +{ + cudf::io::parquet_writer_options out_opts = + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, t) + .compression(compression) + .dictionary_policy(dict_policy); + cudf::io::write_parquet(out_opts); +} - { - // 25 chunks of 100k rows each - auto const [result, num_chunks] = chunked_read(filepath, 0, 4000000); - EXPECT_EQ(num_chunks, 25); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->get_column(0)); - } +void input_limit_test_write(std::vector const& test_filenames, + cudf::table_view const& t) +{ + CUDF_EXPECTS(test_filenames.size() == 4, "Unexpected count of test filenames"); + CUDF_EXPECTS(test_filenames.size() == input_limit_expected_file_count, + "Unexpected count of test filenames"); + + // no compression + input_limit_test_write_one( + test_filenames[0], t, cudf::io::compression_type::NONE, cudf::io::dictionary_policy::NEVER); + // compression with a codec that uses a lot of scratch space at decode time (2.5x the total + // decompressed buffer size) + input_limit_test_write_one( + test_filenames[1], t, cudf::io::compression_type::ZSTD, cudf::io::dictionary_policy::NEVER); + // compression with a codec that uses no scratch space at decode time + 
input_limit_test_write_one( + test_filenames[2], t, cudf::io::compression_type::SNAPPY, cudf::io::dictionary_policy::NEVER); + input_limit_test_write_one( + test_filenames[3], t, cudf::io::compression_type::SNAPPY, cudf::io::dictionary_policy::ALWAYS); +} - { - // 25 chunks of 100k rows each - auto const [result, num_chunks] = chunked_read(filepath, 0, 4100000); - EXPECT_EQ(num_chunks, 25); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->get_column(0)); - } +void input_limit_test_read(std::vector const& test_filenames, + cudf::table_view const& t, + size_t output_limit, + size_t input_limit, + int const expected_chunk_counts[input_limit_expected_file_count]) +{ + CUDF_EXPECTS(test_filenames.size() == input_limit_expected_file_count, + "Unexpected count of test filenames"); - { - // 12 chunks of 200k rows each, plus 1 final chunk of 100k rows. - auto const [result, num_chunks] = chunked_read(filepath, 0, 8002301); - EXPECT_EQ(num_chunks, 13); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->get_column(0)); + for (size_t idx = 0; idx < test_filenames.size(); idx++) { + auto result = chunked_read(test_filenames[idx], output_limit, input_limit); + CUDF_EXPECTS(result.second == expected_chunk_counts[idx], + "Unexpected number of chunks produced in chunk read"); + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*result.first, t); } +} + +struct ParquetChunkedReaderInputLimitConstrainedTest : public cudf::test::BaseFixture {}; + +TEST_F(ParquetChunkedReaderInputLimitConstrainedTest, SingleFixedWidthColumn) +{ + auto base_path = temp_env->get_temp_filepath("single_col_fixed_width"); + auto test_filenames = input_limit_get_test_names(base_path); + + constexpr auto num_rows = 1'000'000; + auto iter1 = thrust::make_constant_iterator(15); + cudf::test::fixed_width_column_wrapper col1(iter1, iter1 + num_rows); + auto tbl = cudf::table_view{{col1}}; + + input_limit_test_write(test_filenames, tbl); + + // semi-reasonable limit + constexpr int expected_a[] = {1, 17, 4, 1}; + input_limit_test_read(test_filenames, tbl, 0, 2 * 1024 * 1024, expected_a); + // an unreasonable limit + constexpr int expected_b[] = {1, 50, 50, 1}; + input_limit_test_read(test_filenames, tbl, 0, 1, expected_b); +} + +TEST_F(ParquetChunkedReaderInputLimitConstrainedTest, MixedColumns) +{ + auto base_path = temp_env->get_temp_filepath("mixed_columns"); + auto test_filenames = input_limit_get_test_names(base_path); + + constexpr auto num_rows = 1'000'000; + + auto iter1 = thrust::make_counting_iterator(0); + cudf::test::fixed_width_column_wrapper col1(iter1, iter1 + num_rows); + + auto iter2 = thrust::make_counting_iterator(0); + cudf::test::fixed_width_column_wrapper col2(iter2, iter2 + num_rows); + + auto const strings = std::vector{"abc", "de", "fghi"}; + auto const str_iter = cudf::detail::make_counting_transform_iterator(0, [&](int32_t i) { + if (i < 250000) { return strings[0]; } + if (i < 750000) { return strings[1]; } + return strings[2]; + }); + auto col3 = strings_col(str_iter, str_iter + num_rows); + + auto tbl = cudf::table_view{{col1, col2, col3}}; + + input_limit_test_write(test_filenames, tbl); + constexpr int expected_a[] = {1, 50, 10, 7}; + input_limit_test_read(test_filenames, tbl, 0, 2 * 1024 * 1024, expected_a); + constexpr int expected_b[] = {1, 50, 50, 50}; + input_limit_test_read(test_filenames, tbl, 0, 1, expected_b); +} + +struct ParquetChunkedReaderInputLimitTest : public cudf::test::BaseFixture {}; + +struct offset_gen { + int const group_size; + __device__ int operator()(int i) { return i * group_size; } +}; + 
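For reference, a host-side sketch of the offsets that offset_gen above describes; this is editorial and not part of the change, and the tests below generate the equivalent sequence on the device with cudf::detail::make_counting_transform_iterator and thrust::copy.

#include <vector>

// With group_size = 4 the offsets are 0, 4, 8, ...; num_rows + 1 offsets describe num_rows
// fixed-length lists of 4 elements each.
std::vector<int> make_fixed_list_offsets(int num_rows, int group_size)
{
  std::vector<int> offsets(num_rows + 1);
  for (int i = 0; i <= num_rows; ++i) {
    offsets[i] = i * group_size;  // the same formula the __device__ offset_gen functor applies
  }
  return offsets;
}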
+template +struct value_gen { + __device__ T operator()(int i) { return i % 1024; } +}; +TEST_F(ParquetChunkedReaderInputLimitTest, List) +{ + auto base_path = temp_env->get_temp_filepath("list"); + auto test_filenames = input_limit_get_test_names(base_path); + + constexpr int num_rows = 50'000'000; + constexpr int list_size = 4; + + auto const stream = cudf::get_default_stream(); + + auto offset_iter = cudf::detail::make_counting_transform_iterator(0, offset_gen{list_size}); + auto offset_col = cudf::make_fixed_width_column( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, cudf::mask_state::UNALLOCATED); + thrust::copy(rmm::exec_policy(stream), + offset_iter, + offset_iter + num_rows + 1, + offset_col->mutable_view().begin()); + + // list + constexpr int num_ints = num_rows * list_size; + auto value_iter = cudf::detail::make_counting_transform_iterator(0, value_gen{}); + auto value_col = cudf::make_fixed_width_column( + cudf::data_type{cudf::type_id::INT32}, num_ints, cudf::mask_state::UNALLOCATED); + thrust::copy(rmm::exec_policy(stream), + value_iter, + value_iter + num_ints, + value_col->mutable_view().begin()); + auto col1 = + cudf::make_lists_column(num_rows, + std::move(offset_col), + std::move(value_col), + 0, + cudf::create_null_mask(num_rows, cudf::mask_state::UNALLOCATED), + stream); + + auto tbl = cudf::table_view{{*col1}}; + + input_limit_test_write(test_filenames, tbl); + + // even though we have a very large limit here, there are two cases where we actually produce + // splits. + // - uncompressed data (with no dict). This happens because the code has to make a guess at how + // much + // space to reserve for compressed/uncompressed data prior to reading. It does not know that + // everything it will be reading in this case is uncompressed already, so this guess ends up + // causing it to generate two top level passes. in practice, this shouldn't matter because we + // never really see uncompressed data in the wild. + // + // - ZSTD (with no dict). In this case, ZSTD simple requires a huge amount of temporary + // space: 2.5x the total + // size of the decompressed data. so 2 GB is actually not enough to hold the whole thing at + // once. + // + // Note that in the dictionary cases, both of these revert down to 1 chunk because the + // dictionaries dramatically shrink the size of the uncompressed data. + constexpr int expected_a[] = {2, 2, 1, 1}; + input_limit_test_read(test_filenames, tbl, 0, size_t{2} * 1024 * 1024 * 1024, expected_a); + // smaller limit + constexpr int expected_b[] = {6, 6, 2, 1}; + input_limit_test_read(test_filenames, tbl, 0, 512 * 1024 * 1024, expected_b); + // include output chunking as well + constexpr int expected_c[] = {11, 11, 9, 8}; + input_limit_test_read(test_filenames, tbl, 128 * 1024 * 1024, 512 * 1024 * 1024, expected_c); +} + +struct char_values { + __device__ int8_t operator()(int i) { - // 1 big chunk - auto const [result, num_chunks] = chunked_read(filepath, 0, size_t{1} * 1024 * 1024 * 1024); - EXPECT_EQ(num_chunks, 1); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->get_column(0)); + int const index = (i / 2) % 3; + // generate repeating 3-runs of 2 values each. aabbcc + return index == 0 ? 'a' : (index == 1 ? 
'b' : 'c'); } +}; +TEST_F(ParquetChunkedReaderInputLimitTest, Mixed) +{ + auto base_path = temp_env->get_temp_filepath("mixed_types"); + auto test_filenames = input_limit_get_test_names(base_path); + + constexpr int num_rows = 50'000'000; + constexpr int list_size = 4; + constexpr int str_size = 3; + + auto const stream = cudf::get_default_stream(); + + auto offset_iter = cudf::detail::make_counting_transform_iterator(0, offset_gen{list_size}); + auto offset_col = cudf::make_fixed_width_column( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, cudf::mask_state::UNALLOCATED); + thrust::copy(rmm::exec_policy(stream), + offset_iter, + offset_iter + num_rows + 1, + offset_col->mutable_view().begin()); + + // list + constexpr int num_ints = num_rows * list_size; + auto value_iter = cudf::detail::make_counting_transform_iterator(0, value_gen{}); + auto value_col = cudf::make_fixed_width_column( + cudf::data_type{cudf::type_id::INT32}, num_ints, cudf::mask_state::UNALLOCATED); + thrust::copy(rmm::exec_policy(stream), + value_iter, + value_iter + num_ints, + value_col->mutable_view().begin()); + auto col1 = + cudf::make_lists_column(num_rows, + std::move(offset_col), + std::move(value_col), + 0, + cudf::create_null_mask(num_rows, cudf::mask_state::UNALLOCATED), + stream); + + // strings + constexpr int num_chars = num_rows * str_size; + auto str_offset_iter = cudf::detail::make_counting_transform_iterator(0, offset_gen{str_size}); + auto str_offset_col = cudf::make_fixed_width_column( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, cudf::mask_state::UNALLOCATED); + thrust::copy(rmm::exec_policy(stream), + str_offset_iter, + str_offset_iter + num_rows + 1, + str_offset_col->mutable_view().begin()); + auto str_iter = cudf::detail::make_counting_transform_iterator(0, char_values{}); + rmm::device_buffer str_chars(num_chars, stream); + thrust::copy(rmm::exec_policy(stream), + str_iter, + str_iter + num_chars, + static_cast(str_chars.data())); + auto col2 = + cudf::make_strings_column(num_rows, + std::move(str_offset_col), + std::move(str_chars), + 0, + cudf::create_null_mask(num_rows, cudf::mask_state::UNALLOCATED)); + + // doubles + auto double_iter = cudf::detail::make_counting_transform_iterator(0, value_gen{}); + auto col3 = cudf::make_fixed_width_column( + cudf::data_type{cudf::type_id::FLOAT64}, num_rows, cudf::mask_state::UNALLOCATED); + thrust::copy(rmm::exec_policy(stream), + double_iter, + double_iter + num_rows, + col3->mutable_view().begin()); + + auto tbl = cudf::table_view{{*col1, *col2, *col3}}; + + input_limit_test_write(test_filenames, tbl); + + // even though we have a very large limit here, there are two cases where we actually produce + // splits. + // - uncompressed data (with no dict). This happens because the code has to make a guess at how + // much + // space to reserve for compressed/uncompressed data prior to reading. It does not know that + // everything it will be reading in this case is uncompressed already, so this guess ends up + // causing it to generate two top level passes. in practice, this shouldn't matter because we + // never really see uncompressed data in the wild. + // + // - ZSTD (with no dict). In this case, ZSTD simple requires a huge amount of temporary + // space: 2.5x the total + // size of the decompressed data. so 2 GB is actually not enough to hold the whole thing at + // once. 
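+  // (rough arithmetic for this test: ~27 bytes of raw values per row across the three columns
+  //  times 50 million rows is on the order of 1.3 GB of decompressed data before offsets and
+  //  level data, so ~2.5x that in ZSTD scratch is well beyond the 2 GB input limit.)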
+ // + // Note that in the dictionary cases, both of these revert down to 1 chunk because the + // dictionaries dramatically shrink the size of the uncompressed data. + constexpr int expected_a[] = {3, 3, 1, 1}; + input_limit_test_read(test_filenames, tbl, 0, size_t{2} * 1024 * 1024 * 1024, expected_a); + // smaller limit + constexpr int expected_b[] = {10, 11, 4, 1}; + input_limit_test_read(test_filenames, tbl, 0, 512 * 1024 * 1024, expected_b); + // include output chunking as well + constexpr int expected_c[] = {20, 21, 15, 14}; + input_limit_test_read(test_filenames, tbl, 128 * 1024 * 1024, 512 * 1024 * 1024, expected_c); }
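For context, the chunked_read() helper used throughout these tests is defined with the shared test utilities and is not shown in this diff; it presumably drives the public chunked reader roughly as in the sketch below. The wrapper name and defaults here are assumptions; only cudf::io::chunked_parquet_reader and its read loop come from the public API.

#include <cudf/concatenate.hpp>
#include <cudf/io/parquet.hpp>
#include <cudf/table/table.hpp>

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Read `path` in chunks, returning the reassembled table and the number of chunks produced.
std::pair<std::unique_ptr<cudf::table>, int> chunked_read_sketch(std::string const& path,
                                                                 std::size_t output_limit_bytes,
                                                                 std::size_t input_limit_bytes)
{
  auto const options =
    cudf::io::parquet_reader_options::builder(cudf::io::source_info{path}).build();
  // output_limit_bytes caps the size of each returned chunk; input_limit_bytes caps the amount of
  // compressed/decompressed data resident for a single input pass
  cudf::io::chunked_parquet_reader reader(output_limit_bytes, input_limit_bytes, options);

  std::vector<std::unique_ptr<cudf::table>> chunks;
  while (reader.has_next()) {
    chunks.emplace_back(reader.read_chunk().tbl);
  }

  std::vector<cudf::table_view> views;
  views.reserve(chunks.size());
  for (auto const& chunk : chunks) {
    views.push_back(chunk->view());
  }
  return {cudf::concatenate(views), static_cast<int>(chunks.size())};
}

As the test cases above suggest, passing 0 for a limit leaves that dimension unconstrained, which is why the uncompressed, dictionary-encoded files often come back as a single chunk.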