Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Add DataFrame::new_with_broadcast and simplify column uniqueness checks #18285

Merged
merged 4 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 93 additions & 113 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,37 @@ pub enum UniqueKeepStrategy {
Any,
}

fn ensure_names_unique<T, F>(items: &[T], mut get_name: F) -> PolarsResult<()>
where
F: FnMut(&T) -> &str,
{
// Always unique.
if items.len() <= 1 {
return Ok(());
}

if items.len() <= 4 {
// Too small to be worth spawning a hashmap for, this is at most 6 comparisons.
for i in 0..items.len() - 1 {
let name = get_name(&items[i]);
for other in items.iter().skip(i + 1) {
if name == get_name(other) {
polars_bail!(duplicate = name);
}
}
}
} else {
let mut names = PlHashSet::with_capacity(items.len());
for item in items {
let name = get_name(item);
if !names.insert(name) {
polars_bail!(duplicate = name);
}
}
}
Ok(())
}

/// A contiguous growable collection of `Series` that have the same length.
///
/// ## Use declarations
Expand Down Expand Up @@ -221,89 +252,62 @@ impl DataFrame {
/// let df = DataFrame::new(vec![s0, s1])?;
/// # Ok::<(), PolarsError>(())
/// ```
pub fn new<S: IntoSeries>(columns: Vec<S>) -> PolarsResult<Self> {
let mut first_len = None;
pub fn new(columns: Vec<Series>) -> PolarsResult<Self> {
ensure_names_unique(&columns, |s| s.name())?;

let shape_err = |&first_name, &first_len, &name, &len| {
polars_bail!(
ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} \
while series {:?} has length {}",
first_name, first_len, name, len
);
};
if columns.len() > 1 {
let first_len = columns[0].len();
for col in &columns {
polars_ensure!(
col.len() == first_len,
ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}",
columns[0].len(), first_len, col.name(), col.len()
);
}
}

let series_cols = if S::is_series() {
// SAFETY:
// we are guarded by the type system here.
#[allow(clippy::transmute_undefined_repr)]
let series_cols = unsafe { std::mem::transmute::<Vec<S>, Vec<Series>>(columns) };
let mut names = PlHashSet::with_capacity(series_cols.len());

for s in &series_cols {
let name = s.name();

match first_len {
Some(len) => {
if s.len() != len {
let first_series = &series_cols.first().unwrap();
return shape_err(
&first_series.name(),
&first_series.len(),
&name,
&s.len(),
);
}
},
None => first_len = Some(s.len()),
}
Ok(DataFrame { columns })
}

if !names.insert(name) {
polars_bail!(duplicate = name);
}
}
// we drop early as the brchk thinks the &str borrows are used when calling the drop
// of both `series_cols` and `names`
drop(names);
series_cols
} else {
let mut series_cols: Vec<Series> = Vec::with_capacity(columns.len());
let mut names = PlHashSet::with_capacity(columns.len());

// check for series length equality and convert into series in one pass
for s in columns {
let series = s.into_series();
// we have aliasing borrows so we must allocate a string
let name = series.name().to_string();

match first_len {
Some(len) => {
if series.len() != len {
let first_series = &series_cols.first().unwrap();
return shape_err(
&first_series.name(),
&first_series.len(),
&name.as_str(),
&series.len(),
);
}
},
None => first_len = Some(series.len()),
}
/// Converts a sequence of columns into a DataFrame, broadcasting length-1
/// columns to match the other columns.
pub fn new_with_broadcast(columns: Vec<Series>) -> PolarsResult<Self> {
ensure_names_unique(&columns, |s| s.name())?;
unsafe { Self::new_with_broadcast_no_checks(columns) }
}

if names.contains(&name) {
polars_bail!(duplicate = name);
/// Converts a sequence of columns into a DataFrame, broadcasting length-1
/// columns to match the other columns.
///
/// # Safety
/// Does not check that the column names are unique (which they must be).
pub unsafe fn new_with_broadcast_no_checks(mut columns: Vec<Series>) -> PolarsResult<Self> {
// The length of the longest non-unit length column determines the
// broadcast length. If all columns are unit-length the broadcast length
// is one.
let broadcast_len = columns
.iter()
.map(|s| s.len())
.filter(|l| *l != 1)
.max()
.unwrap_or(1);

for col in &mut columns {
// Length not equal to the broadcast len, needs broadcast or is an error.
let len = col.len();
if len != broadcast_len {
if len != 1 {
let name = col.name().to_owned();
let longest_column = columns.iter().max_by_key(|c| c.len()).unwrap().name();
polars_bail!(
ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}",
name, len, longest_column, broadcast_len
);
}

series_cols.push(series);
names.insert(name);
*col = col.new_from_index(0, broadcast_len);
}
drop(names);
series_cols
};

Ok(DataFrame {
columns: series_cols,
})
}
Ok(unsafe { DataFrame::new_no_checks(columns) })
}

/// Creates an empty `DataFrame` usable in a compile time context (such as static initializers).
Expand Down Expand Up @@ -442,16 +446,7 @@ impl DataFrame {
/// It is the callers responsibility to uphold the contract of all `Series`
/// having an equal length, if not this may panic down the line.
pub unsafe fn new_no_length_checks(columns: Vec<Series>) -> PolarsResult<DataFrame> {
let mut names = PlHashSet::with_capacity(columns.len());
for column in &columns {
let name = column.name();
if !names.insert(name) {
polars_bail!(duplicate = name);
}
}
// we drop early as the brchk thinks the &str borrows are used when calling the drop
// of both `columns` and `names`
drop(names);
ensure_names_unique(&columns, |s| s.name())?;
Ok(DataFrame { columns })
}

Expand Down Expand Up @@ -637,12 +632,7 @@ impl DataFrame {
ShapeMismatch: "{} column names provided for a DataFrame of width {}",
names.len(), self.width()
);
let unique_names: PlHashSet<&str> =
PlHashSet::from_iter(names.iter().map(|name| name.as_ref()));
polars_ensure!(
unique_names.len() == self.width(),
Duplicate: "duplicate column names found"
);
ensure_names_unique(names, |s| s.as_ref())?;

let columns = mem::take(&mut self.columns);
self.columns = columns
Expand Down Expand Up @@ -1447,7 +1437,7 @@ impl DataFrame {
}

pub fn _select_impl(&self, cols: &[SmartString]) -> PolarsResult<Self> {
self.select_check_duplicates(cols)?;
ensure_names_unique(cols, |s| s.as_str())?;
self._select_impl_unchecked(cols)
}

Expand Down Expand Up @@ -1493,7 +1483,7 @@ impl DataFrame {
check_duplicates: bool,
) -> PolarsResult<Self> {
if check_duplicates {
self.select_check_duplicates(cols)?;
ensure_names_unique(cols, |s| s.as_str())?;
}
let selected = self.select_series_impl_with_schema(cols, schema)?;
Ok(unsafe { DataFrame::new_no_checks(selected) })
Expand Down Expand Up @@ -1526,21 +1516,11 @@ impl DataFrame {
}

fn select_physical_impl(&self, cols: &[SmartString]) -> PolarsResult<Self> {
self.select_check_duplicates(cols)?;
ensure_names_unique(cols, |s| s.as_str())?;
let selected = self.select_series_physical_impl(cols)?;
Ok(unsafe { DataFrame::new_no_checks(selected) })
}

fn select_check_duplicates(&self, cols: &[SmartString]) -> PolarsResult<()> {
let mut names = PlHashSet::with_capacity(cols.len());
for name in cols {
if !names.insert(name.as_str()) {
polars_bail!(duplicate = name);
}
}
Ok(())
}

/// Select column(s) from this [`DataFrame`] and return them into a [`Vec`].
///
/// # Example
Expand Down Expand Up @@ -1712,16 +1692,16 @@ impl DataFrame {
/// }
/// ```
pub fn rename(&mut self, column: &str, name: &str) -> PolarsResult<&mut Self> {
if column == name {
return Ok(self);
}
polars_ensure!(
self.columns.iter().all(|c| c.name() != name),
Duplicate: "column rename attempted with already existing name \"{name}\""
);
self.select_mut(column)
.ok_or_else(|| polars_err!(col_not_found = column))
.map(|s| s.rename(name))?;
let unique_names: PlHashSet<&str> =
PlHashSet::from_iter(self.columns.iter().map(|s| s.name()));
polars_ensure!(
unique_names.len() == self.width(),
Duplicate: "duplicate column names found"
);
drop(unique_names);
Ok(self)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl Clone for DslPlan {

impl Default for DslPlan {
fn default() -> Self {
let df = DataFrame::new::<Series>(vec![]).unwrap();
let df = DataFrame::empty();
let schema = df.schema();
DslPlan::DataFrameScan {
df: Arc::new(df),
Expand Down
15 changes: 1 addition & 14 deletions crates/polars-stream/src/nodes/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,7 @@ impl ComputeNode for SelectNode {
out._add_columns(selected, &slf.schema)?;
out
} else {
// Broadcast scalars.
let max_non_unit_length = selected
.iter()
.map(|s| s.len())
.filter(|l| *l != 1)
.max()
.unwrap_or(1);
for s in &mut selected {
if s.len() != max_non_unit_length {
assert!(s.len() == 1, "got series of incompatible lengths");
*s = s.new_from_index(0, max_non_unit_length);
}
}
unsafe { DataFrame::new_no_checks(selected) }
DataFrame::new_with_broadcast(selected)?
};

let mut morsel = Morsel::new(ret, seq, source_token);
Expand Down
5 changes: 4 additions & 1 deletion docs/src/rust/user-guide/expressions/lists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
ListPrimitiveChunkedBuilder::new("Array_2", 8, 8, DataType::Int32);
col2.append_slice(&[1, 7, 3]);
col2.append_slice(&[8, 1, 0]);
let array_df = DataFrame::new([col1.finish(), col2.finish()].into())?;
let array_df = DataFrame::new(vec![
col1.finish().into_series(),
col2.finish().into_series(),
])?;

println!("{}", &array_df);
// --8<-- [end:array_df]
Expand Down