Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): reject multithreading on excessive ',\n' fields #6906

Merged
merged 1 commit into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 51 additions & 16 deletions polars/polars-io/src/csv/parser.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use memchr::memchr2_iter;
use num::traits::Pow;
use polars_core::prelude::*;

Expand Down Expand Up @@ -32,12 +33,36 @@ pub(crate) fn next_line_position(
quote_char: Option<u8>,
eol_char: u8,
) -> Option<usize> {
fn accept_line(
    line: &[u8],
    expected_fields: usize,
    delimiter: u8,
    eol_char: u8,
    quote_char: Option<u8>,
) -> bool {
    // A candidate line is accepted only when every field looks sane —
    // i.e. no single field itself contains an implausible number of
    // delimiter/eol bytes (a sign we landed inside a quoted multiline
    // field) — and the total field count matches the schema exactly.
    let mut n_fields = 0usize;
    let fields_sane = SplitFields::new(line, delimiter, quote_char, eol_char).all(|(field, _)| {
        n_fields += 1;
        memchr2_iter(delimiter, eol_char, field).count() < expected_fields
    });
    fields_sane && n_fields == expected_fields
}

// we check 3 subsequent lines for `accept_line` before we accept
// if 3 groups are rejected we reject completely
let mut rejected_line_groups = 0u8;

let mut total_pos = 0;
if input.is_empty() {
return None;
}
let mut lines_checked = 0u16;
loop {
if rejected_line_groups >= 3 {
return None;
}
lines_checked += 1;
// headers might have an extra value
// So if we have churned through enough lines
Expand All @@ -53,29 +78,39 @@ pub(crate) fn next_line_position(
}
debug_assert!(pos <= input.len());
let new_input = unsafe { input.get_unchecked(pos..) };
let line = SplitLines::new(new_input, quote_char.unwrap_or(b'"'), eol_char).next();

let count_fields =
|line: &[u8]| SplitFields::new(line, delimiter, quote_char, eol_char).count();
let mut lines = SplitLines::new(new_input, quote_char.unwrap_or(b'"'), eol_char);
let line = lines.next();

match (line, expected_fields) {
// count the fields, and determine if they are equal to what we expect from the schema
(Some(line), Some(expected_fields)) if { count_fields(line) == expected_fields } => {
return Some(total_pos + pos)
}
(Some(_), Some(_)) => {
debug_assert!(pos < input.len());
unsafe {
input = input.get_unchecked(pos + 1..);
(Some(line), Some(expected_fields)) => {
if accept_line(line, expected_fields, delimiter, eol_char, quote_char) {
let mut valid = true;
for line in lines.take(2) {
if !accept_line(line, expected_fields, delimiter, eol_char, quote_char) {
valid = false;
break;
}
}
if valid {
return Some(total_pos + pos);
} else {
rejected_line_groups += 1;
}
} else {
debug_assert!(pos < input.len());
unsafe {
input = input.get_unchecked(pos + 1..);
}
total_pos += pos + 1;
}
total_pos += pos + 1;
}
// don't count the fields
(Some(_), None) => return Some(total_pos + pos),
// no new line found, check latest line (without eol) for number of fields
(None, Some(expected_fields)) if { count_fields(new_input) == expected_fields } => {
return Some(total_pos + pos)
}
// // no new line found, check latest line (without eol) for number of fields
// (None, Some(expected_fields)) if { count_fields(new_input) == expected_fields } => {
// return Some(total_pos + pos)
// }
_ => return None,
}
}
Expand Down
35 changes: 17 additions & 18 deletions polars/polars-io/src/csv/read_impl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,30 +408,29 @@ impl<'a> CoreReader<'a> {

let chunk_size = std::cmp::min(self.chunk_size, total_rows);
let n_chunks = total_rows / chunk_size;
if logging {
eprintln!(
"no. of chunks: {n_chunks} processed by: {n_threads} threads at 1 chunk/thread",
);
}

let n_file_chunks = if streaming { n_chunks } else { *n_threads };

// split the file by the nearest new line characters such that every thread processes
// approximately the same number of rows.
Ok((
get_file_chunks(
bytes,
n_file_chunks,
self.schema.len(),
self.delimiter,
self.quote_char,
self.eol_char,
),
chunk_size,
total_rows,
starting_point_offset,

let chunks = get_file_chunks(
bytes,
))
n_file_chunks,
self.schema.len(),
self.delimiter,
self.quote_char,
self.eol_char,
);

if logging {
eprintln!(
"no. of chunks: {} processed by: {n_threads} threads.",
chunks.len()
);
}

Ok((chunks, chunk_size, total_rows, starting_point_offset, bytes))
}

fn get_projection(&mut self) -> Vec<usize> {
Expand Down
28 changes: 28 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import date, datetime, time, timedelta, timezone
from pathlib import Path

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -1168,3 +1169,30 @@ def test_read_web_file() -> None:
url = "https://raw.githubusercontent.com/pola-rs/polars/master/examples/datasets/foods1.csv"
df = pl.read_csv(url)
assert df.shape == (27, 4)


@pytest.mark.slow()
def test_csv_multiline_splits() -> None:
    # Create a very unlikely csv file with many newlines embedded in a
    # single quoted field (e.g. 5000). Polars must reject multi-threading
    # here, as it cannot find proper file chunks without sequentially
    # parsing the data.

    np.random.seed(0)
    f = io.BytesIO()

    def some_multiline_str(n: int) -> str:
        """Build a quoted CSV field spanning ``n`` embedded newlines."""
        strs = []
        strs.append('"')
        # Sample lengths in [0, 5) so that it is likely the multiline
        # field also contains more than 3 separators on a single line.
        for length in np.random.randint(0, 5, n):
            strs.append("xx," * length)

        strs.append('"')
        return "\n".join(strs)

    for _ in range(4):
        f.write(f"field1,field2,{some_multiline_str(5000)}\n".encode())

    f.seek(0)
    assert pl.read_csv(f, has_header=False).shape == (4, 3)