Skip to content

Commit

Permalink
Merge branch 'branch-24.08' of github.com:rapidsai/cudf into pylibcud…
Browse files Browse the repository at this point in the history
…f-io-writers
  • Loading branch information
lithomas1 committed Jun 24, 2024
2 parents 53b821c + 0c6b828 commit 186a2fb
Show file tree
Hide file tree
Showing 39 changed files with 857 additions and 235 deletions.
13 changes: 13 additions & 0 deletions cpp/include/cudf/io/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,19 @@ struct column_name_info {
}

column_name_info() = default;

/**
* @brief Compares two column name info structs for equality
*
* @param rhs column name info struct to compare against
* @return boolean indicating if this and rhs are equal
*/
bool operator==(column_name_info const& rhs) const
{
return ((name == rhs.name) && (is_nullable == rhs.is_nullable) &&
(is_binary == rhs.is_binary) && (type_length == rhs.type_length) &&
(children == rhs.children));
};
};

/**
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/groupby/hash/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,8 @@ std::unique_ptr<table> groupby(table_view const& keys,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto const num_keys = keys.num_rows();
// convert to int64_t to avoid potential overflow with large `keys`
auto const num_keys = static_cast<int64_t>(keys.num_rows());
auto const null_keys_are_equal = null_equality::EQUAL;
auto const has_null = nullate::DYNAMIC{cudf::has_nested_nulls(keys)};

Expand Down
121 changes: 105 additions & 16 deletions cpp/src/io/json/read_json.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include "io/json/nested_json.hpp"
#include "read_json.hpp"

#include <cudf/concatenate.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/detail/utilities/stream_pool.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/detail/json.hpp>
Expand Down Expand Up @@ -76,15 +78,15 @@ device_span<char> ingest_raw_input(device_span<char> buffer,
auto constexpr num_delimiter_chars = 1;

if (compression == compression_type::NONE) {
std::vector<size_type> delimiter_map{};
std::vector<size_t> delimiter_map{};
std::vector<size_t> prefsum_source_sizes(sources.size());
std::vector<std::unique_ptr<datasource::buffer>> h_buffers;
delimiter_map.reserve(sources.size());
size_t bytes_read = 0;
std::transform_inclusive_scan(sources.begin(),
sources.end(),
prefsum_source_sizes.begin(),
std::plus<int>{},
std::plus<size_t>{},
[](std::unique_ptr<datasource> const& s) { return s->size(); });
auto upper =
std::upper_bound(prefsum_source_sizes.begin(), prefsum_source_sizes.end(), range_offset);
Expand Down Expand Up @@ -259,6 +261,33 @@ datasource::owning_buffer<rmm::device_uvector<char>> get_record_range_raw_input(
readbufspan.size() - first_delim_pos - shift_for_nonzero_offset);
}

table_with_metadata read_batch(host_span<std::unique_ptr<datasource>> sources,
json_reader_options const& reader_opts,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
datasource::owning_buffer<rmm::device_uvector<char>> bufview =
get_record_range_raw_input(sources, reader_opts, stream);

// If input JSON buffer has single quotes and option to normalize single quotes is enabled,
// invoke pre-processing FST
if (reader_opts.is_enabled_normalize_single_quotes()) {
normalize_single_quotes(bufview, stream, rmm::mr::get_current_device_resource());
}

// If input JSON buffer has unquoted spaces and tabs and option to normalize whitespaces is
// enabled, invoke pre-processing FST
if (reader_opts.is_enabled_normalize_whitespace()) {
normalize_whitespace(bufview, stream, rmm::mr::get_current_device_resource());
}

auto buffer =
cudf::device_span<char const>(reinterpret_cast<char const*>(bufview.data()), bufview.size());
stream.synchronize();
return device_parse_nested_json(buffer, reader_opts, stream, mr);
}

table_with_metadata read_json(host_span<std::unique_ptr<datasource>> sources,
json_reader_options const& reader_opts,
rmm::cuda_stream_view stream,
Expand All @@ -278,25 +307,85 @@ table_with_metadata read_json(host_span<std::unique_ptr<datasource>> sources,
"Multiple inputs are supported only for JSON Lines format");
}

datasource::owning_buffer<rmm::device_uvector<char>> bufview =
get_record_range_raw_input(sources, reader_opts, stream);
std::for_each(sources.begin(), sources.end(), [](auto const& source) {
CUDF_EXPECTS(source->size() < std::numeric_limits<int>::max(),
"The size of each source file must be less than INT_MAX bytes");
});

// If input JSON buffer has single quotes and option to normalize single quotes is enabled,
// invoke pre-processing FST
if (reader_opts.is_enabled_normalize_single_quotes()) {
normalize_single_quotes(bufview, stream, rmm::mr::get_current_device_resource());
constexpr size_t batch_size_ub = std::numeric_limits<int>::max();
size_t const chunk_offset = reader_opts.get_byte_range_offset();
size_t chunk_size = reader_opts.get_byte_range_size();
chunk_size = !chunk_size ? sources_size(sources, 0, 0) : chunk_size;

// Identify the position of starting source file from which to begin batching based on
// byte range offset. If the offset is larger than the sum of all source
// sizes, then start_source is total number of source files i.e. no file is read
size_t const start_source = [&]() {
size_t sum = 0;
for (size_t src_idx = 0; src_idx < sources.size(); ++src_idx) {
if (sum + sources[src_idx]->size() > chunk_offset) return src_idx;
sum += sources[src_idx]->size();
}
return sources.size();
}();

// Construct batches of source files, with starting position of batches indicated by
// batch_positions. The size of each batch i.e. the sum of sizes of the source files in the batch
// is capped at INT_MAX bytes.
size_t cur_size = 0;
std::vector<size_t> batch_positions;
std::vector<size_t> batch_sizes;
batch_positions.push_back(0);
for (size_t i = start_source; i < sources.size(); i++) {
cur_size += sources[i]->size();
if (cur_size >= batch_size_ub) {
batch_positions.push_back(i);
batch_sizes.push_back(cur_size - sources[i]->size());
cur_size = sources[i]->size();
}
}
batch_positions.push_back(sources.size());
batch_sizes.push_back(cur_size);

// If input JSON buffer has unquoted spaces and tabs and option to normalize whitespaces is
// enabled, invoke pre-processing FST
if (reader_opts.is_enabled_normalize_whitespace()) {
normalize_whitespace(bufview, stream, rmm::mr::get_current_device_resource());
// If there is a single batch, then we can directly return the table without the
// unnecessary concatenate
if (batch_sizes.size() == 1) return read_batch(sources, reader_opts, stream, mr);

std::vector<cudf::io::table_with_metadata> partial_tables;
json_reader_options batched_reader_opts{reader_opts};

// Dispatch individual batches to read_batch and push the resulting table into
// partial_tables array. Note that the reader options need to be updated for each
// batch to adjust byte range offset and byte range size.
for (size_t i = 0; i < batch_sizes.size(); i++) {
batched_reader_opts.set_byte_range_size(std::min(batch_sizes[i], chunk_size));
partial_tables.emplace_back(read_batch(
host_span<std::unique_ptr<datasource>>(sources.begin() + batch_positions[i],
batch_positions[i + 1] - batch_positions[i]),
batched_reader_opts,
stream,
rmm::mr::get_current_device_resource()));
if (chunk_size <= batch_sizes[i]) break;
chunk_size -= batch_sizes[i];
batched_reader_opts.set_byte_range_offset(0);
}

auto buffer =
cudf::device_span<char const>(reinterpret_cast<char const*>(bufview.data()), bufview.size());
stream.synchronize();
return device_parse_nested_json(buffer, reader_opts, stream, mr);
auto expects_schema_equality =
std::all_of(partial_tables.begin() + 1,
partial_tables.end(),
[&gt = partial_tables[0].metadata.schema_info](auto& ptbl) {
return ptbl.metadata.schema_info == gt;
});
CUDF_EXPECTS(expects_schema_equality,
"Mismatch in JSON schema across batches in multi-source multi-batch reading");

auto partial_table_views = std::vector<cudf::table_view>(partial_tables.size());
std::transform(partial_tables.begin(),
partial_tables.end(),
partial_table_views.begin(),
[](auto const& table) { return table.tbl->view(); });
return table_with_metadata{cudf::concatenate(partial_table_views, stream, mr),
{partial_tables[0].metadata.schema_info}};
}

} // namespace cudf::io::json::detail
4 changes: 2 additions & 2 deletions cpp/src/io/text/byte_range_info.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,7 +31,7 @@ std::vector<byte_range_info> create_byte_range_infos_consecutive(int64_t total_b
auto range_size = util::div_rounding_up_safe(total_bytes, range_count);
auto ranges = std::vector<byte_range_info>();

ranges.reserve(range_size);
ranges.reserve(range_count);

for (int64_t i = 0; i < range_count; i++) {
auto offset = i * range_size;
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ ConfigureTest(
LARGE_STRINGS_TEST
large_strings/concatenate_tests.cpp
large_strings/case_tests.cpp
large_strings/json_tests.cpp
large_strings/large_strings_fixture.cpp
large_strings/merge_tests.cpp
large_strings/parquet_tests.cpp
Expand Down
58 changes: 58 additions & 0 deletions cpp/tests/large_strings/json_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "large_strings_fixture.hpp"

#include <cudf/io/json.hpp>
#include <cudf/utilities/span.hpp>

struct JsonLargeReaderTest : public cudf::test::StringsLargeTest {};

TEST_F(JsonLargeReaderTest, MultiBatch)
{
std::string json_string = R"(
{ "a": { "y" : 6}, "b" : [1, 2, 3], "c": 11 }
{ "a": { "y" : 6}, "b" : [4, 5 ], "c": 12 }
{ "a": { "y" : 6}, "b" : [6 ], "c": 13 }
{ "a": { "y" : 6}, "b" : [7 ], "c": 14 })";
constexpr size_t expected_file_size = std::numeric_limits<int>::max() / 2;
std::size_t const log_repetitions =
static_cast<std::size_t>(std::ceil(std::log2(expected_file_size / json_string.size())));

json_string.reserve(json_string.size() * (1UL << log_repetitions));
std::size_t numrows = 4;
for (std::size_t i = 0; i < log_repetitions; i++) {
json_string += json_string;
numrows <<= 1;
}

constexpr int num_sources = 2;
std::vector<cudf::host_span<char>> hostbufs(
num_sources, cudf::host_span<char>(json_string.data(), json_string.size()));

// Initialize parsing options (reading json lines)
cudf::io::json_reader_options json_lines_options =
cudf::io::json_reader_options::builder(
cudf::io::source_info{
cudf::host_span<cudf::host_span<char>>(hostbufs.data(), hostbufs.size())})
.lines(true)
.compression(cudf::io::compression_type::NONE)
.recovery_mode(cudf::io::json_recovery_mode_t::FAIL);

// Read full test data via existing, nested JSON lines reader
cudf::io::table_with_metadata current_reader_table = cudf::io::read_json(json_lines_options);
ASSERT_EQ(current_reader_table.tbl->num_rows(), numrows * num_sources);
}
3 changes: 2 additions & 1 deletion java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7838,11 +7838,12 @@ void testSumWithStrings() {
.build();
Table result = t.groupBy(0).aggregate(
GroupByAggregation.sum().onColumn(1));
Table sorted = result.orderBy(OrderByArg.asc(0));
Table expected = new Table.TestBuilder()
.column("1-URGENT", "3-MEDIUM")
.column(5289L + 5303L, 5203L + 5206L)
.build()) {
assertTablesAreEqual(expected, result);
assertTablesAreEqual(expected, sorted);
}
}

Expand Down
10 changes: 4 additions & 6 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,7 @@ def __contains__(self, item):
hash(item)
return item in self._values

def _copy_type_metadata(
self, other: Self, *, override_dtypes=None
) -> Self:
def _copy_type_metadata(self: Self, other: Self) -> Self:
raise NotImplementedError

def get_level_values(self, level):
Expand Down Expand Up @@ -1122,7 +1120,7 @@ def difference(self, other, sort=None):
res_name = _get_result_name(self.name, other.name)

if is_mixed_with_object_dtype(self, other) or len(other) == 0:
difference = self.copy().unique()
difference = self.unique()
difference.name = res_name
if sort is True:
return difference.sort_values()
Expand Down Expand Up @@ -1746,7 +1744,7 @@ def rename(self, name, inplace=False):
self.name = name
return None
else:
out = self.copy(deep=True)
out = self.copy(deep=False)
out.name = name
return out

Expand Down Expand Up @@ -2070,7 +2068,7 @@ def dropna(self, how="any"):
raise ValueError(f"{how=} must be 'any' or 'all'")
try:
if not self.hasnans:
return self.copy()
return self.copy(deep=False)
except NotImplementedError:
pass
# This is to be consistent with IndexedFrame.dropna to handle nans
Expand Down
33 changes: 32 additions & 1 deletion python/cudf/cudf/core/_internals/timezones.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,50 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
from __future__ import annotations

import datetime
import os
import zoneinfo
from functools import lru_cache
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd

import cudf
from cudf._lib.timezone import make_timezone_transition_table
from cudf.core.column.column import as_column

if TYPE_CHECKING:
from cudf.core.column.datetime import DatetimeColumn
from cudf.core.column.timedelta import TimeDeltaColumn


def get_compatible_timezone(dtype: pd.DatetimeTZDtype) -> pd.DatetimeTZDtype:
"""Convert dtype.tz object to zoneinfo object if possible."""
tz = dtype.tz
if isinstance(tz, zoneinfo.ZoneInfo):
return dtype
if cudf.get_option("mode.pandas_compatible"):
raise NotImplementedError(
f"{tz} must be a zoneinfo.ZoneInfo object in pandas_compatible mode."
)
elif (tzname := getattr(tz, "zone", None)) is not None:
# pytz-like
key = tzname
elif (tz_file := getattr(tz, "_filename", None)) is not None:
# dateutil-like
key = tz_file.split("zoneinfo/")[-1]
elif isinstance(tz, datetime.tzinfo):
# Try to get UTC-like tzinfos
reference = datetime.datetime.now()
key = tz.tzname(reference)
if not (isinstance(key, str) and key.lower() == "utc"):
raise NotImplementedError(f"cudf does not support {tz}")
else:
raise NotImplementedError(f"cudf does not support {tz}")
new_tz = zoneinfo.ZoneInfo(key)
return pd.DatetimeTZDtype(dtype.unit, new_tz)


@lru_cache(maxsize=20)
def get_tz_data(zone_name: str) -> tuple[DatetimeColumn, TimeDeltaColumn]:
"""
Expand Down Expand Up @@ -87,6 +116,8 @@ def _read_tzfile_as_columns(
)

if not transition_times_and_offsets:
from cudf.core.column.column import as_column

# this happens for UTC-like zones
min_date = np.int64(np.iinfo("int64").min + 1).astype("M8[s]")
return (as_column([min_date]), as_column([np.timedelta64(0, "s")]))
Expand Down
Loading

0 comments on commit 186a2fb

Please sign in to comment.