From a69a8a43b5e6d63b5bacd4b5ad03773a0277b2f1 Mon Sep 17 00:00:00 2001 From: Devavret Makkar Date: Fri, 23 Jul 2021 21:05:07 +0530 Subject: [PATCH] Add multi-thread reading to GDS reads (#8752) Adds a thread pool and new `device_read_async` API to datasource. This API has been added to parquet and ORC readers. `device_read` also now uses `device_read_async` with synchronization so using `device_read` with large enough size should also be faster now. As a result, Avro reader should also get the benefits of multi-threaded reading. Authors: - Devavret Makkar (https://github.com/devavret) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Ram (Ramakrishna Prabhu) ((https://github.com/rgsl888prabhu) URL: https://github.com/rapidsai/cudf/pull/8752 --- cpp/include/cudf/io/datasource.hpp | 29 ++ cpp/src/io/orc/reader_impl.cu | 13 +- cpp/src/io/parquet/reader_impl.cu | 33 +- cpp/src/io/parquet/reader_impl.hpp | 14 +- cpp/src/io/utilities/datasource.cpp | 13 +- cpp/src/io/utilities/file_io_utilities.cpp | 56 ++- cpp/src/io/utilities/file_io_utilities.hpp | 25 ++ cpp/src/io/utilities/thread_pool.hpp | 406 +++++++++++++++++++++ 8 files changed, 560 insertions(+), 29 deletions(-) create mode 100644 cpp/src/io/utilities/thread_pool.hpp diff --git a/cpp/include/cudf/io/datasource.hpp b/cpp/include/cudf/io/datasource.hpp index c1aff818121..93f68d43aff 100644 --- a/cpp/include/cudf/io/datasource.hpp +++ b/cpp/include/cudf/io/datasource.hpp @@ -30,6 +30,7 @@ #include #include +#include #include namespace cudf { @@ -209,6 +210,34 @@ class datasource { CUDF_FAIL("datasource classes that support device_read must override it."); } + /** + * @brief Asynchronously reads a selected range into a preallocated device buffer + * + * Returns a future value that contains the number of bytes read. Calling `get()` method of the + * return value synchronizes this function. + * + * For optimal performance, should only be called when `is_device_read_preferred` returns `true`. + * Data source implementations that don't support direct device reads don't need to override this + * function. + * + * @throws cudf::logic_error when the object does not support direct device reads, i.e. + * `supports_device_read` returns `false`. + * + * @param offset Number of bytes from the start + * @param size Number of bytes to read + * @param dst Address of the existing device memory + * @param stream CUDA stream to use + * + * @return The number of bytes read as a future value (can be smaller than size) + */ + virtual std::future device_read_async(size_t offset, + size_t size, + uint8_t* dst, + rmm::cuda_stream_view stream) + { + CUDF_FAIL("datasource classes that support device_read_async must override it."); + } + /** * @brief Returns the size of the data in the source. * diff --git a/cpp/src/io/orc/reader_impl.cu b/cpp/src/io/orc/reader_impl.cu index 8e47da98a7c..033a2d9aff5 100644 --- a/cpp/src/io/orc/reader_impl.cu +++ b/cpp/src/io/orc/reader_impl.cu @@ -1132,6 +1132,7 @@ table_with_metadata reader::impl::read(size_type skip_rows, size_t num_rowgroups = 0; int stripe_idx = 0; + std::vector, size_t>> read_tasks; for (auto const& stripe_source_mapping : selected_stripes) { // Iterate through the source files selected stripes for (auto const& stripe : stripe_source_mapping.stripe_info) { @@ -1170,10 +1171,11 @@ table_with_metadata reader::impl::read(size_type skip_rows, } if (_metadata->per_file_metadata[stripe_source_mapping.source_idx] .source->is_device_read_preferred(len)) { - CUDF_EXPECTS( - _metadata->per_file_metadata[stripe_source_mapping.source_idx].source->device_read( - offset, len, d_dst, stream) == len, - "Unexpected discrepancy in bytes read."); + read_tasks.push_back( + std::make_pair(_metadata->per_file_metadata[stripe_source_mapping.source_idx] + .source->device_read_async(offset, len, d_dst, stream), + len)); + } else { const auto buffer = _metadata->per_file_metadata[stripe_source_mapping.source_idx].source->host_read( @@ -1246,6 +1248,9 @@ table_with_metadata reader::impl::read(size_type skip_rows, stripe_idx++; } } + for (auto& task : read_tasks) { + CUDF_EXPECTS(task.first.get() == task.second, "Unexpected discrepancy in bytes read."); + } // Process dataset chunk pages into output columns if (stripe_data.size() != 0) { diff --git a/cpp/src/io/parquet/reader_impl.cu b/cpp/src/io/parquet/reader_impl.cu index 3bf11063035..9f9bdfd4755 100644 --- a/cpp/src/io/parquet/reader_impl.cu +++ b/cpp/src/io/parquet/reader_impl.cu @@ -823,7 +823,7 @@ void generate_depth_remappings(std::map, std::ve /** * @copydoc cudf::io::detail::parquet::read_column_chunks */ -void reader::impl::read_column_chunks( +std::future reader::impl::read_column_chunks( std::vector>& page_data, hostdevice_vector& chunks, // TODO const? size_t begin_chunk, @@ -833,6 +833,7 @@ void reader::impl::read_column_chunks( rmm::cuda_stream_view stream) { // Transfer chunk data, coalescing adjacent chunks + std::vector> read_tasks; for (size_t chunk = begin_chunk; chunk < end_chunk;) { const size_t io_offset = column_chunk_offsets[chunk]; size_t io_size = chunks[chunk].compressed_size; @@ -854,7 +855,11 @@ void reader::impl::read_column_chunks( if (io_size != 0) { auto& source = _sources[chunk_source_map[chunk]]; if (source->is_device_read_preferred(io_size)) { - page_data[chunk] = source->device_read(io_offset, io_size, stream); + auto buffer = rmm::device_buffer(io_size, stream); + auto fut_read_size = source->device_read_async( + io_offset, io_size, static_cast(buffer.data()), stream); + read_tasks.emplace_back(std::move(fut_read_size)); + page_data[chunk] = datasource::buffer::create(std::move(buffer)); } else { auto const buffer = source->host_read(io_offset, io_size); page_data[chunk] = @@ -869,6 +874,12 @@ void reader::impl::read_column_chunks( chunk = next_chunk; } } + auto sync_fn = [](decltype(read_tasks) read_tasks) { + for (auto& task : read_tasks) { + task.wait(); + } + }; + return std::async(std::launch::deferred, sync_fn, std::move(read_tasks)); } /** @@ -1435,6 +1446,7 @@ table_with_metadata reader::impl::read(size_type skip_rows, // Initialize column chunk information size_t total_decompressed_size = 0; auto remaining_rows = num_rows; + std::vector> read_rowgroup_tasks; for (const auto& rg : selected_row_groups) { const auto& row_group = _metadata->get_row_group(rg.index, rg.source_index); auto const row_group_start = rg.start_row; @@ -1502,16 +1514,19 @@ table_with_metadata reader::impl::read(size_type skip_rows, } } // Read compressed chunk data to device memory - read_column_chunks(page_data, - chunks, - io_chunk_idx, - chunks.size(), - column_chunk_offsets, - chunk_source_map, - stream); + read_rowgroup_tasks.push_back(read_column_chunks(page_data, + chunks, + io_chunk_idx, + chunks.size(), + column_chunk_offsets, + chunk_source_map, + stream)); remaining_rows -= row_group.num_rows; } + for (auto& task : read_rowgroup_tasks) { + task.wait(); + } assert(remaining_rows <= 0); // Process dataset chunk pages into output columns diff --git a/cpp/src/io/parquet/reader_impl.hpp b/cpp/src/io/parquet/reader_impl.hpp index b93107aa9b2..4c3bd75b724 100644 --- a/cpp/src/io/parquet/reader_impl.hpp +++ b/cpp/src/io/parquet/reader_impl.hpp @@ -91,13 +91,13 @@ class reader::impl { * @param stream CUDA stream used for device memory operations and kernel launches. * */ - void read_column_chunks(std::vector>& page_data, - hostdevice_vector& chunks, - size_t begin_chunk, - size_t end_chunk, - const std::vector& column_chunk_offsets, - std::vector const& chunk_source_map, - rmm::cuda_stream_view stream); + std::future read_column_chunks(std::vector>& page_data, + hostdevice_vector& chunks, + size_t begin_chunk, + size_t end_chunk, + const std::vector& column_chunk_offsets, + std::vector const& chunk_source_map, + rmm::cuda_stream_view stream); /** * @brief Returns the number of total pages from the given column chunks diff --git a/cpp/src/io/utilities/datasource.cpp b/cpp/src/io/utilities/datasource.cpp index 4b23d008344..7afffaede9e 100644 --- a/cpp/src/io/utilities/datasource.cpp +++ b/cpp/src/io/utilities/datasource.cpp @@ -41,7 +41,7 @@ class file_source : public datasource { bool supports_device_read() const override { return _cufile_in != nullptr; } - bool is_device_read_preferred(size_t size) const + bool is_device_read_preferred(size_t size) const override { return _cufile_in != nullptr && _cufile_in->is_cufile_io_preferred(size); } @@ -67,6 +67,17 @@ class file_source : public datasource { return _cufile_in->read(offset, read_size, dst, stream); } + std::future device_read_async(size_t offset, + size_t size, + uint8_t* dst, + rmm::cuda_stream_view stream) override + { + CUDF_EXPECTS(supports_device_read(), "Device reads are not supported for this file."); + + auto const read_size = std::min(size, _file.size() - offset); + return _cufile_in->read_async(offset, read_size, dst, stream); + } + size_t size() const override { return _file.size(); } protected: diff --git a/cpp/src/io/utilities/file_io_utilities.cpp b/cpp/src/io/utilities/file_io_utilities.cpp index b5fb9fb51bc..7b55bf82f15 100644 --- a/cpp/src/io/utilities/file_io_utilities.cpp +++ b/cpp/src/io/utilities/file_io_utilities.cpp @@ -14,12 +14,14 @@ * limitations under the License. */ #include "file_io_utilities.hpp" +#include #include #include #include +#include namespace cudf { namespace io { @@ -166,8 +168,11 @@ void cufile_registered_file::register_handle() cufile_registered_file::~cufile_registered_file() { shim->handle_deregister(cf_handle); } cufile_input_impl::cufile_input_impl(std::string const& filepath) - : shim{cufile_shim::instance()}, cf_file(shim, filepath, O_RDONLY | O_DIRECT) + : shim{cufile_shim::instance()}, + cf_file(shim, filepath, O_RDONLY | O_DIRECT), + pool(16) // The benefit from multithreaded read plateaus around 16 threads { + pool.sleep_duration = 10; } std::unique_ptr cufile_input_impl::read(size_t offset, @@ -175,21 +180,56 @@ std::unique_ptr cufile_input_impl::read(size_t offset, rmm::cuda_stream_view stream) { rmm::device_buffer out_data(size, stream); - CUDF_EXPECTS(shim->read(cf_file.handle(), out_data.data(), size, offset, 0) != -1, - "cuFile error reading from a file"); - + auto read_size = read(offset, size, reinterpret_cast(out_data.data()), stream); + out_data.resize(read_size, stream); return datasource::buffer::create(std::move(out_data)); } +std::future cufile_input_impl::read_async(size_t offset, + size_t size, + uint8_t* dst, + rmm::cuda_stream_view stream) +{ + int device; + cudaGetDevice(&device); + + auto read_slice = [=](void* dst, size_t size, size_t offset) -> ssize_t { + cudaSetDevice(device); + auto read_size = shim->read(cf_file.handle(), dst, size, offset, 0); + CUDF_EXPECTS(read_size != -1, "cuFile error reading from a file"); + return read_size; + }; + + std::vector> slice_tasks; + constexpr size_t max_slice_bytes = 4 * 1024 * 1024; + size_t n_slices = util::div_rounding_up_safe(size, max_slice_bytes); + size_t slice_size = max_slice_bytes; + size_t slice_offset = 0; + for (size_t t = 0; t < n_slices; ++t) { + void* dst_slice = dst + slice_offset; + + if (t == n_slices - 1) { slice_size = size % max_slice_bytes; } + slice_tasks.push_back(pool.submit(read_slice, dst_slice, slice_size, offset + slice_offset)); + + slice_offset += slice_size; + } + auto waiter = [](decltype(slice_tasks) slice_tasks) -> size_t { + return std::accumulate(slice_tasks.begin(), slice_tasks.end(), 0, [](auto sum, auto& task) { + return sum + task.get(); + }); + }; + // The future returned from this function is deferred, not async becasue we want to avoid creating + // threads for each read_async call. This overhead is significant in case of multiple small reads. + return std::async(std::launch::deferred, waiter, std::move(slice_tasks)); +} + size_t cufile_input_impl::read(size_t offset, size_t size, uint8_t* dst, rmm::cuda_stream_view stream) { - CUDF_EXPECTS(shim->read(cf_file.handle(), dst, size, offset, 0) != -1, - "cuFile error reading from a file"); - // always read the requested size for now - return size; + auto result = read_async(offset, size, dst, stream); + return result.get(); } cufile_output_impl::cufile_output_impl(std::string const& filepath) diff --git a/cpp/src/io/utilities/file_io_utilities.hpp b/cpp/src/io/utilities/file_io_utilities.hpp index e92191095e3..fdf8d012b0e 100644 --- a/cpp/src/io/utilities/file_io_utilities.hpp +++ b/cpp/src/io/utilities/file_io_utilities.hpp @@ -17,6 +17,8 @@ #pragma once #ifdef CUFILE_FOUND +#include "thread_pool.hpp" + #include #include #endif @@ -106,6 +108,23 @@ class cufile_input : public cufile_io_base { * @return The number of bytes read */ virtual size_t read(size_t offset, size_t size, uint8_t* dst, rmm::cuda_stream_view stream) = 0; + + /** + * @brief Asynchronously reads into existing device memory. + * + * @throws cudf::logic_error on cuFile error + * + * @param offset Number of bytes from the start + * @param size Number of bytes to read + * @param dst Address of the existing device memory + * @param stream CUDA stream to use + * + * @return The number of bytes read as an std::future + */ + virtual std::future read_async(size_t offset, + size_t size, + uint8_t* dst, + rmm::cuda_stream_view stream) = 0; }; /** @@ -202,9 +221,15 @@ class cufile_input_impl final : public cufile_input { size_t read(size_t offset, size_t size, uint8_t* dst, rmm::cuda_stream_view stream) override; + std::future read_async(size_t offset, + size_t size, + uint8_t* dst, + rmm::cuda_stream_view stream) override; + private: cufile_shim const* shim = nullptr; cufile_registered_file const cf_file; + cudf::detail::thread_pool pool; }; /** diff --git a/cpp/src/io/utilities/thread_pool.hpp b/cpp/src/io/utilities/thread_pool.hpp new file mode 100644 index 00000000000..33690c53758 --- /dev/null +++ b/cpp/src/io/utilities/thread_pool.hpp @@ -0,0 +1,406 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +/** + * Modified from https://github.com/bshoshany/thread-pool + * @copyright Copyright (c) 2021 Barak Shoshany. Licensed under the MIT license. + * See file LICENSE for detail or copy at https://opensource.org/licenses/MIT + */ + +#include // std::atomic +#include // std::chrono +#include // std::int_fast64_t, std::uint_fast32_t +#include // std::function +#include // std::future, std::promise +#include // std::shared_ptr, std::unique_ptr +#include // std::mutex, std::scoped_lock +#include // std::queue +#include // std::this_thread, std::thread +#include // std::decay_t, std::enable_if_t, std::is_void_v, std::invoke_result_t +#include // std::move, std::swap + +namespace cudf { +namespace detail { + +/** + * @brief A C++17 thread pool class. The user submits tasks to be executed into a queue. Whenever a + * thread becomes available, it pops a task from the queue and executes it. Each task is + * automatically assigned a future, which can be used to wait for the task to finish executing + * and/or obtain its eventual return value. + */ +class thread_pool { + typedef std::uint_fast32_t ui32; + + public: + /** + * @brief Construct a new thread pool. + * + * @param _thread_count The number of threads to use. The default value is the total number of + * hardware threads available, as reported by the implementation. With a hyperthreaded CPU, this + * will be twice the number of CPU cores. If the argument is zero, the default value will be used + * instead. + */ + thread_pool(const ui32& _thread_count = std::thread::hardware_concurrency()) + : thread_count(_thread_count ? _thread_count : std::thread::hardware_concurrency()), + threads(new std::thread[_thread_count ? _thread_count : std::thread::hardware_concurrency()]) + { + create_threads(); + } + + /** + * @brief Destruct the thread pool. Waits for all tasks to complete, then destroys all threads. + * Note that if the variable paused is set to true, then any tasks still in the queue will never + * be executed. + */ + ~thread_pool() + { + wait_for_tasks(); + running = false; + destroy_threads(); + } + + /** + * @brief Get the number of tasks currently waiting in the queue to be executed by the threads. + * + * @return The number of queued tasks. + */ + size_t get_tasks_queued() const + { + const std::scoped_lock lock(queue_mutex); + return tasks.size(); + } + + /** + * @brief Get the number of tasks currently being executed by the threads. + * + * @return The number of running tasks. + */ + ui32 get_tasks_running() const { return tasks_total - (ui32)get_tasks_queued(); } + + /** + * @brief Get the total number of unfinished tasks - either still in the queue, or running in a + * thread. + * + * @return The total number of tasks. + */ + ui32 get_tasks_total() const { return tasks_total; } + + /** + * @brief Get the number of threads in the pool. + * + * @return The number of threads. + */ + ui32 get_thread_count() const { return thread_count; } + + /** + * @brief Parallelize a loop by splitting it into blocks, submitting each block separately to the + * thread pool, and waiting for all blocks to finish executing. The loop will be equivalent to: + * for (T i = first_index; i <= last_index; i++) loop(i); + * + * @tparam T The type of the loop index. Should be a signed or unsigned integer. + * @tparam F The type of the function to loop through. + * @param first_index The first index in the loop (inclusive). + * @param last_index The last index in the loop (inclusive). + * @param loop The function to loop through. Should take exactly one argument, the loop index. + * @param num_tasks The maximum number of tasks to split the loop into. The default is to use the + * number of threads in the pool. + */ + template + void parallelize_loop(T first_index, T last_index, const F& loop, ui32 num_tasks = 0) + { + if (num_tasks == 0) num_tasks = thread_count; + if (last_index < first_index) std::swap(last_index, first_index); + size_t total_size = last_index - first_index + 1; + size_t block_size = total_size / num_tasks; + if (block_size == 0) { + block_size = 1; + num_tasks = (ui32)total_size > 1 ? (ui32)total_size : 1; + } + std::atomic blocks_running = 0; + for (ui32 t = 0; t < num_tasks; t++) { + T start = (T)(t * block_size + first_index); + T end = (t == num_tasks - 1) ? last_index : (T)((t + 1) * block_size + first_index - 1); + blocks_running++; + push_task([start, end, &loop, &blocks_running] { + for (T i = start; i <= end; i++) + loop(i); + blocks_running--; + }); + } + while (blocks_running != 0) { + sleep_or_yield(); + } + } + + /** + * @brief Push a function with no arguments or return value into the task queue. + * + * @tparam F The type of the function. + * @param task The function to push. + */ + template + void push_task(const F& task) + { + tasks_total++; + { + const std::scoped_lock lock(queue_mutex); + tasks.push(std::function(task)); + } + } + + /** + * @brief Push a function with arguments, but no return value, into the task queue. + * @details The function is wrapped inside a lambda in order to hide the arguments, as the tasks + * in the queue must be of type std::function, so they cannot have any arguments or return + * value. If no arguments are provided, the other overload will be used, in order to avoid the + * (slight) overhead of using a lambda. + * + * @tparam F The type of the function. + * @tparam A The types of the arguments. + * @param task The function to push. + * @param args The arguments to pass to the function. + */ + template + void push_task(const F& task, const A&... args) + { + push_task([task, args...] { task(args...); }); + } + + /** + * @brief Reset the number of threads in the pool. Waits for all currently running tasks to be + * completed, then destroys all threads in the pool and creates a new thread pool with the new + * number of threads. Any tasks that were waiting in the queue before the pool was reset will then + * be executed by the new threads. If the pool was paused before resetting it, the new pool will + * be paused as well. + * + * @param _thread_count The number of threads to use. The default value is the total number of + * hardware threads available, as reported by the implementation. With a hyperthreaded CPU, this + * will be twice the number of CPU cores. If the argument is zero, the default value will be used + * instead. + */ + void reset(const ui32& _thread_count = std::thread::hardware_concurrency()) + { + bool was_paused = paused; + paused = true; + wait_for_tasks(); + running = false; + destroy_threads(); + thread_count = _thread_count ? _thread_count : std::thread::hardware_concurrency(); + threads.reset(new std::thread[thread_count]); + paused = was_paused; + create_threads(); + running = true; + } + + /** + * @brief Submit a function with zero or more arguments and no return value into the task queue, + * and get an std::future that will be set to true upon completion of the task. + * + * @tparam F The type of the function. + * @tparam A The types of the zero or more arguments to pass to the function. + * @param task The function to submit. + * @param args The zero or more arguments to pass to the function. + * @return A future to be used later to check if the function has finished its execution. + */ + template , std::decay_t...>>>> + std::future submit(const F& task, const A&... args) + { + std::shared_ptr> promise(new std::promise); + std::future future = promise->get_future(); + push_task([task, args..., promise] { + try { + task(args...); + promise->set_value(true); + } catch (...) { + promise->set_exception(std::current_exception()); + }; + }); + return future; + } + + /** + * @brief Submit a function with zero or more arguments and a return value into the task queue, + * and get a future for its eventual returned value. + * + * @tparam F The type of the function. + * @tparam A The types of the zero or more arguments to pass to the function. + * @tparam R The return type of the function. + * @param task The function to submit. + * @param args The zero or more arguments to pass to the function. + * @return A future to be used later to obtain the function's returned value, waiting for it to + * finish its execution if needed. + */ + template , std::decay_t...>, + typename = std::enable_if_t>> + std::future submit(const F& task, const A&... args) + { + std::shared_ptr> promise(new std::promise); + std::future future = promise->get_future(); + push_task([task, args..., promise] { + try { + promise->set_value(task(args...)); + } catch (...) { + promise->set_exception(std::current_exception()); + }; + }); + return future; + } + + /** + * @brief Wait for tasks to be completed. Normally, this function waits for all tasks, both those + * that are currently running in the threads and those that are still waiting in the queue. + * However, if the variable paused is set to true, this function only waits for the currently + * running tasks (otherwise it would wait forever). To wait for a specific task, use submit() + * instead, and call the wait() member function of the generated future. + */ + void wait_for_tasks() + { + while (true) { + if (!paused) { + if (tasks_total == 0) break; + } else { + if (get_tasks_running() == 0) break; + } + sleep_or_yield(); + } + } + + /** + * @brief An atomic variable indicating to the workers to pause. When set to true, the workers + * temporarily stop popping new tasks out of the queue, although any tasks already executed will + * keep running until they are done. Set to false again to resume popping tasks. + */ + std::atomic paused = false; + + /** + * @brief The duration, in microseconds, that the worker function should sleep for when it cannot + * find any tasks in the queue. If set to 0, then instead of sleeping, the worker function will + * execute std::this_thread::yield() if there are no tasks in the queue. The default value is + * 1000. + */ + ui32 sleep_duration = 1000; + + private: + /** + * @brief Create the threads in the pool and assign a worker to each thread. + */ + void create_threads() + { + for (ui32 i = 0; i < thread_count; i++) { + threads[i] = std::thread(&thread_pool::worker, this); + } + } + + /** + * @brief Destroy the threads in the pool by joining them. + */ + void destroy_threads() + { + for (ui32 i = 0; i < thread_count; i++) { + threads[i].join(); + } + } + + /** + * @brief Try to pop a new task out of the queue. + * + * @param task A reference to the task. Will be populated with a function if the queue is not + * empty. + * @return true if a task was found, false if the queue is empty. + */ + bool pop_task(std::function& task) + { + const std::scoped_lock lock(queue_mutex); + if (tasks.empty()) + return false; + else { + task = std::move(tasks.front()); + tasks.pop(); + return true; + } + } + + /** + * @brief Sleep for sleep_duration microseconds. If that variable is set to zero, yield instead. + * + */ + void sleep_or_yield() + { + if (sleep_duration) + std::this_thread::sleep_for(std::chrono::microseconds(sleep_duration)); + else + std::this_thread::yield(); + } + + /** + * @brief A worker function to be assigned to each thread in the pool. Continuously pops tasks out + * of the queue and executes them, as long as the atomic variable running is set to true. + */ + void worker() + { + while (running) { + std::function task; + if (!paused && pop_task(task)) { + task(); + tasks_total--; + } else { + sleep_or_yield(); + } + } + } + + /** + * @brief A mutex to synchronize access to the task queue by different threads. + */ + mutable std::mutex queue_mutex; + + /** + * @brief An atomic variable indicating to the workers to keep running. When set to false, the + * workers permanently stop working. + */ + std::atomic running = true; + + /** + * @brief A queue of tasks to be executed by the threads. + */ + std::queue> tasks; + + /** + * @brief The number of threads in the pool. + */ + ui32 thread_count; + + /** + * @brief A smart pointer to manage the memory allocated for the threads. + */ + std::unique_ptr threads; + + /** + * @brief An atomic variable to keep track of the total number of unfinished tasks - either still + * in the queue, or running in a thread. + */ + std::atomic tasks_total = 0; +}; + +} // namespace detail +} // namespace cudf