Skip to content

Commit

Permalink
[FEA] Adds option to recover from invalid JSON lines in JSON tokenizer (
Browse files Browse the repository at this point in the history
#13344)

This PR adds the option to recover from invalid JSON lines to the JSON tokenizer. 

**New option and behaviour:**
- We add the option `recovery_mode` (of type `json_recovery_mode_t`) to `json_reader_options`. When this option is set to `RECOVER_WITH_NULL` for a JSON lines input, the reader will recover from a parsing error encountered on an invalid JSON line and continue parsing the next line.
- When the new option is not enabled, we expect the behaviour of existing functionality to remain untouched.
- When recovering from invalid JSON lines is enabled, all newline characters that are not enclosed in quotes (i.e., newline characters outside of `strings` and `field names`) are interpreted as delimiters of a JSON line. In a future PR, we will introduce a new option that reflects this behaviour for JSON lines inputs that should not recover from errors. Hence, this PR introduces the `JSON_LINES_STRICT` enum but does not yet hook it up.

**Implementation details:**
- When recovering from invalid JSON lines is enabled, `get_token_stream()` will delimit each JSON line with a `LineEnd` token to facilitate the identification of tokens that belong to an invalid JSON line.
- We extend the logical stack and introduce a new operation, `reset()`. A `reset()` operation resets the logical stack to an empty stack. This is necessary to reset the stack of the pushdown automaton (PDA) after an invalid JSON line to make sure the stack in subsequent lines is not corrupted.
- We modify the transition and translation table of the finite-state transducer (FST) that is used to generate the push-down automaton's (PDA) stack context operations to emit such a `reset()` operation, iff `recovery` is enabled.
- We modify the transition and translation table of the finite-state transducer (FST) that is used to simulate the full PDA to (1) recover after an invalid JSON line and (2) emit the `LineEnd` token, iff `recovery` is enabled.
- To clean up JSON lines that contain tokens belonging to an invalid line, a token *post-processing* stage is needed. The *post-processing* will replace sequences of `LineEnd` `token*` `ErrorBegin` with the sequence `StructBegin` `StructEnd` (i.e., effectively a `null` row) for record-oriented inputs.
- This post-processing is implemented by running an FST on the reverse token stream, discarding all tokens between `ErrorBegin` and the next `LineEnd`, and emitting a `StructBegin` `StructEnd` pair at the end of each such invalid line.

This is an initial PR to address #12532.

Authors:
  - Elias Stehle (https://github.com/elstehle)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Karthikeyan (https://github.com/karthikeyann)

URL: #13344
  • Loading branch information
elstehle authored Jul 14, 2023
1 parent 6a30b69 commit 2436e0b
Show file tree
Hide file tree
Showing 12 changed files with 1,603 additions and 554 deletions.
31 changes: 23 additions & 8 deletions cpp/benchmarks/io/fst.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@ auto make_test_json_data(nvbench::state& state)
// Type used to represent the atomic symbol type used within the finite-state machine
using SymbolT = char;
// Type sufficiently large to index symbols within the input and output (may be unsigned)
using SymbolOffsetT = uint32_t;
// Helper class to set up transition table, symbol group lookup table, and translation table
using DfaFstT = cudf::io::fst::detail::Dfa<char, NUM_SYMBOL_GROUPS, TT_NUM_STATES>;
constexpr std::size_t single_item = 1;
using SymbolOffsetT = uint32_t;
constexpr std::size_t single_item = 1;
constexpr auto max_translation_table_size = TT_NUM_STATES * NUM_SYMBOL_GROUPS;

} // namespace

Expand All @@ -94,7 +93,11 @@ void BM_FST_JSON(nvbench::state& state)
cudf::detail::hostdevice_vector<SymbolOffsetT> out_indexes_gpu(d_input.size(), stream_view);

// Run algorithm
DfaFstT parser{pda_sgs, pda_state_tt, pda_out_tt, stream.value()};
auto parser = cudf::io::fst::detail::make_fst(
cudf::io::fst::detail::make_symbol_group_lut(pda_sgs),
cudf::io::fst::detail::make_transition_table(pda_state_tt),
cudf::io::fst::detail::make_translation_table<max_translation_table_size>(pda_out_tt),
stream);

state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value()));
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
Expand Down Expand Up @@ -129,7 +132,11 @@ void BM_FST_JSON_no_outidx(nvbench::state& state)
cudf::detail::hostdevice_vector<SymbolOffsetT> out_indexes_gpu(d_input.size(), stream_view);

// Run algorithm
DfaFstT parser{pda_sgs, pda_state_tt, pda_out_tt, stream.value()};
auto parser = cudf::io::fst::detail::make_fst(
cudf::io::fst::detail::make_symbol_group_lut(pda_sgs),
cudf::io::fst::detail::make_transition_table(pda_state_tt),
cudf::io::fst::detail::make_translation_table<max_translation_table_size>(pda_out_tt),
stream);

state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value()));
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
Expand Down Expand Up @@ -162,7 +169,11 @@ void BM_FST_JSON_no_out(nvbench::state& state)
cudf::detail::hostdevice_vector<SymbolOffsetT> output_gpu_size(single_item, stream_view);

// Run algorithm
DfaFstT parser{pda_sgs, pda_state_tt, pda_out_tt, stream.value()};
auto parser = cudf::io::fst::detail::make_fst(
cudf::io::fst::detail::make_symbol_group_lut(pda_sgs),
cudf::io::fst::detail::make_transition_table(pda_state_tt),
cudf::io::fst::detail::make_translation_table<max_translation_table_size>(pda_out_tt),
stream);

state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value()));
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
Expand Down Expand Up @@ -196,7 +207,11 @@ void BM_FST_JSON_no_str(nvbench::state& state)
cudf::detail::hostdevice_vector<SymbolOffsetT> out_indexes_gpu(d_input.size(), stream_view);

// Run algorithm
DfaFstT parser{pda_sgs, pda_state_tt, pda_out_tt, stream.value()};
auto parser = cudf::io::fst::detail::make_fst(
cudf::io::fst::detail::make_symbol_group_lut(pda_sgs),
cudf::io::fst::detail::make_transition_table(pda_state_tt),
cudf::io::fst::detail::make_translation_table<max_translation_table_size>(pda_out_tt),
stream);

state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value()));
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/cudf/io/detail/tokenize_json.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ enum token_t : PdaTokenT {
ValueEnd,
/// Beginning-of-error token (on first encounter of a parsing error)
ErrorBegin,
/// Delimiting a JSON line for error recovery
LineEnd,
/// Total number of tokens
NUM_TOKENS
};
Expand Down
37 changes: 37 additions & 0 deletions cpp/include/cudf/io/json.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ struct schema_element {
std::map<std::string, schema_element> child_types;
};

/**
 * @brief Control the error recovery behavior of the json parser
 *
 * Selects what the parser does when it encounters input in an invalid format:
 * either it does not recover (`FAIL`), or it recovers and replaces the invalid
 * records with null (`RECOVER_WITH_NULL`).
 */
enum class json_recovery_mode_t {
FAIL, ///< Does not recover from an error when encountering an invalid format
RECOVER_WITH_NULL ///< Recovers from an error, replacing invalid records with null
};

/**
* @brief Input arguments to the `read_json` interface.
*
Expand Down Expand Up @@ -105,6 +113,9 @@ class json_reader_options {
// Whether to keep the quote characters of string values
bool _keep_quotes = false;

// Whether to recover after an invalid JSON line
json_recovery_mode_t _recovery_mode = json_recovery_mode_t::FAIL;

/**
* @brief Constructor from source info.
*
Expand Down Expand Up @@ -235,6 +246,13 @@ class json_reader_options {
*/
bool is_enabled_keep_quotes() const { return _keep_quotes; }

/**
 * @brief Returns the recovery mode the JSON reader applies to invalid JSON lines.
 *
 * @returns The `json_recovery_mode_t` value currently stored in these options.
 */
json_recovery_mode_t recovery_mode() const
{
  return _recovery_mode;
}

/**
* @brief Set data types for columns to be read.
*
Expand Down Expand Up @@ -305,6 +323,13 @@ class json_reader_options {
* of string values
*/
void enable_keep_quotes(bool val) { _keep_quotes = val; }

/**
 * @brief Sets how the JSON reader behaves when it encounters an invalid JSON line.
 *
 * @param val The recovery mode to store in these options.
 */
void set_recovery_mode(json_recovery_mode_t val)
{
  _recovery_mode = val;
}
};

/**
Expand Down Expand Up @@ -449,6 +474,18 @@ class json_reader_options_builder {
return *this;
}

/**
 * @brief Sets the behavior of the JSON reader on invalid JSON lines.
 *
 * @param val The recovery mode to store in the options being built.
 * @return this for chaining
 */
json_reader_options_builder& recovery_mode(json_recovery_mode_t val)
{
  // Route through the public setter of the options object being built
  options.set_recovery_mode(val);
  return *this;
}

/**
* @brief move json_reader_options member once it's built.
*/
Expand Down
53 changes: 37 additions & 16 deletions cpp/src/io/fst/agent_dfa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,18 @@ class DFASimulationCallbackWrapper {
if (!write) out_count = 0;
}

template <typename CharIndexT, typename StateIndexT, typename SymbolIndexT>
template <typename CharIndexT, typename StateIndexT, typename SymbolIndexT, typename SymbolT>
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const character_index,
StateIndexT const old_state,
StateIndexT const new_state,
SymbolIndexT const symbol_id)
SymbolIndexT const symbol_id,
SymbolT const read_symbol)
{
uint32_t const count = transducer_table(old_state, symbol_id);
uint32_t const count = transducer_table(old_state, symbol_id, read_symbol);
if (write) {
for (uint32_t out_char = 0; out_char < count; out_char++) {
out_it[out_count + out_char] = transducer_table(old_state, symbol_id, out_char);
out_it[out_count + out_char] =
transducer_table(old_state, symbol_id, out_char, read_symbol);
out_idx_it[out_count + out_char] = offset + character_index;
}
}
Expand Down Expand Up @@ -127,9 +129,10 @@ class StateVectorTransitionOp {
{
}

template <typename CharIndexT, typename SymbolIndexT>
template <typename CharIndexT, typename SymbolIndexT, typename SymbolT>
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index,
SymbolIndexT const read_symbol_id) const
SymbolIndexT const& read_symbol_id,
SymbolT const& read_symbol) const
{
for (int32_t i = 0; i < NUM_INSTANCES; ++i) {
state_vector[i] = transition_table(state_vector[i], read_symbol_id);
Expand All @@ -154,15 +157,16 @@ struct StateTransitionOp {
{
}

template <typename CharIndexT, typename SymbolIndexT>
template <typename CharIndexT, typename SymbolIndexT, typename SymbolT>
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index,
SymbolIndexT const& read_symbol_id)
SymbolIndexT const& read_symbol_id,
SymbolT const& read_symbol)
{
// Remember what state we were in before we made the transition
StateIndexT previous_state = state;

state = transition_table(state, read_symbol_id);
callback_op.ReadSymbol(character_index, previous_state, state, read_symbol_id);
callback_op.ReadSymbol(character_index, previous_state, state, read_symbol_id, read_symbol);
}
};

Expand Down Expand Up @@ -230,7 +234,7 @@ struct AgentDFA {
for (int32_t i = 0; i < NUM_SYMBOLS; ++i) {
if (IS_FULL_BLOCK || threadIdx.x * SYMBOLS_PER_THREAD + i < max_num_chars) {
auto matched_id = symbol_matcher(chars[i]);
callback_op.ReadSymbol(i, matched_id);
callback_op.ReadSymbol(i, matched_id, chars[i]);
}
}
}
Expand All @@ -253,15 +257,16 @@ struct AgentDFA {
//---------------------------------------------------------------------
// LOADING FULL BLOCK OF CHARACTERS, NON-ALIASED
//---------------------------------------------------------------------
__device__ __forceinline__ void LoadBlock(CharT const* d_chars,
template <typename CharInItT>
__device__ __forceinline__ void LoadBlock(CharInItT d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
cub::Int2Type<true> /*IS_FULL_BLOCK*/,
cub::Int2Type<1> /*ALIGNMENT*/)
{
CharT thread_chars[SYMBOLS_PER_THREAD];

CharT const* d_block_symbols = d_chars + block_offset;
CharInItT d_block_symbols = d_chars + block_offset;
cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_block_symbols, thread_chars);

#pragma unroll
Expand All @@ -273,7 +278,8 @@ struct AgentDFA {
//---------------------------------------------------------------------
// LOADING PARTIAL BLOCK OF CHARACTERS, NON-ALIASED
//---------------------------------------------------------------------
__device__ __forceinline__ void LoadBlock(CharT const* d_chars,
template <typename CharInItT>
__device__ __forceinline__ void LoadBlock(CharInItT d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
cub::Int2Type<false> /*IS_FULL_BLOCK*/,
Expand All @@ -286,7 +292,7 @@ struct AgentDFA {
// Last unit to be loaded is IDIV_CEIL(#SYM, SYMBOLS_PER_UNIT)
OffsetT num_total_chars = num_total_symbols - block_offset;

CharT const* d_block_symbols = d_chars + block_offset;
CharInItT d_block_symbols = d_chars + block_offset;
cub::LoadDirectStriped<BLOCK_THREADS>(
threadIdx.x, d_block_symbols, thread_chars, num_total_chars);

Expand Down Expand Up @@ -372,11 +378,26 @@ struct AgentDFA {
}
}

/**
 * @brief Dispatches to the full-block or the partial-block `LoadBlock` overload
 * (both with alignment tag `cub::Int2Type<1>`), depending on whether a full tile
 * starting at `block_offset` ends strictly before `num_total_symbols`.
 *
 * @param d_chars Input iterator that is forwarded to the selected overload
 * @param block_offset Offset at which this block's tile starts
 * @param num_total_symbols Total number of symbols in the input
 */
template <typename CharInItT>
__device__ __forceinline__ void LoadBlock(CharInItT d_chars,
                                          OffsetT const block_offset,
                                          OffsetT const num_total_symbols)
{
  // Only treat the tile as full if it ends strictly before the end of the input
  bool const is_full_tile = (block_offset + SYMBOLS_PER_UINT_BLOCK) < num_total_symbols;

  if (!is_full_tile) {
    // Partial tile: the selected overload receives the total symbol count for bounds handling
    LoadBlock(
      d_chars, block_offset, num_total_symbols, cub::Int2Type<false>(), cub::Int2Type<1>());
    return;
  }

  LoadBlock(d_chars, block_offset, num_total_symbols, cub::Int2Type<true>(), cub::Int2Type<1>());
}

template <int32_t NUM_STATES, typename SymbolMatcherT, typename TransitionTableT>
__device__ __forceinline__ void GetThreadStateTransitionVector(
SymbolMatcherT const& symbol_matcher,
TransitionTableT const& transition_table,
CharT const* d_chars,
SymbolItT d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
std::array<StateIndexT, NUM_STATES>& state_vector)
Expand Down Expand Up @@ -416,7 +437,7 @@ struct AgentDFA {
__device__ __forceinline__ void GetThreadStateTransitions(
SymbolMatcherT const& symbol_matcher,
TransitionTableT const& transition_table,
CharT const* d_chars,
SymbolItT d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
StateIndexT& state,
Expand Down
51 changes: 30 additions & 21 deletions cpp/src/io/fst/logical_stack.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ namespace cudf::io::fst {
* @brief Describes the kind of stack operation.
*/
enum class stack_op_type : int8_t {
READ = 0, ///< Operation reading what is currently on top of the stack
PUSH = 1, ///< Operation pushing a new item on top of the stack
POP = 2 ///< Operation popping the item currently on top of the stack
READ = 0, ///< Operation reading what is currently on top of the stack
PUSH = 1, ///< Operation pushing a new item on top of the stack
POP = 2, ///< Operation popping the item currently on top of the stack
RESET = 3 ///< Operation popping all items currently on the stack
};

namespace detail {
Expand Down Expand Up @@ -119,9 +120,9 @@ struct StackSymbolToStackOp {
{
stack_op_type stack_op = symbol_to_stack_op_type(stack_symbol);
// PUSH => +1, POP => -1, READ => 0
int32_t level_delta = stack_op == stack_op_type::PUSH ? 1
: stack_op == stack_op_type::POP ? -1
: 0;
int32_t level_delta = (stack_op == stack_op_type::PUSH) ? 1
: (stack_op == stack_op_type::POP) ? -1
: 0;
return StackOpT{static_cast<decltype(StackOpT::stack_level)>(level_delta), stack_symbol};
}

Expand All @@ -133,14 +134,20 @@ struct StackSymbolToStackOp {
* @brief Binary reduction operator to compute the absolute stack level from relative stack levels
* (i.e., +1 for a PUSH, -1 for a POP operation).
*/
template <typename StackSymbolToStackOpTypeT>
struct AddStackLevelFromStackOp {
template <typename StackLevelT, typename ValueT>
constexpr CUDF_HOST_DEVICE StackOp<StackLevelT, ValueT> operator()(
StackOp<StackLevelT, ValueT> const& lhs, StackOp<StackLevelT, ValueT> const& rhs) const
{
StackLevelT new_level = lhs.stack_level + rhs.stack_level;
StackLevelT new_level = (symbol_to_stack_op_type(rhs.value) == stack_op_type::RESET)
? 0
: (lhs.stack_level + rhs.stack_level);
return StackOp<StackLevelT, ValueT>{new_level, rhs.value};
}

/// Function object returning a stack operation type for a given stack symbol
StackSymbolToStackOpTypeT symbol_to_stack_op_type;
};

/**
Expand Down Expand Up @@ -323,13 +330,14 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols,

// Getting temporary storage requirements for the prefix sum of the stack level after each
// operation
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(nullptr,
stack_level_scan_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp{},
num_symbols_in,
stream));
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(
nullptr,
stack_level_scan_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
stream));

// Getting temporary storage requirements for the stable radix sort (sorting by stack level of the
// operations)
Expand Down Expand Up @@ -393,13 +401,14 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols,
d_kv_operations = cub::DoubleBuffer<StackOpT>{d_kv_ops_current.data(), d_kv_ops_alt.data()};

// Compute prefix sum of the stack level after each operation
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(temp_storage.data(),
total_temp_storage_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp{},
num_symbols_in,
stream));
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(
temp_storage.data(),
total_temp_storage_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
stream));

// Stable radix sort, sorting by stack level of the operations
d_kv_operations_unsigned = cub::DoubleBuffer<StackOpUnsignedT>{
Expand Down
Loading

0 comments on commit 2436e0b

Please sign in to comment.