Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): block proj-pd and pred-pd on swapping rename #6303

Merged
merged 1 commit into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions polars/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ impl Schema {
.ok_or_else(|| PolarsError::NotFound(name.to_string().into()))
}

pub fn remove(&mut self, name: &str) -> Option<DataType> {
self.inner.remove(name)
}

pub fn get_full(&self, name: &str) -> Option<(usize, &String, &DataType)> {
self.inner.get_full(name)
}
Expand Down Expand Up @@ -170,8 +174,18 @@ impl Schema {
Some(())
}

pub fn with_column(&mut self, name: String, dtype: DataType) {
self.inner.insert(name, dtype);
/// Insert a new column in the [`Schema`]
///
/// If an equivalent name already exists in the schema: the name remains and
/// retains in its place in the order, its corresponding value is updated
/// with [`DataType`] and the older dtype is returned inside `Some(_)`.
///
/// If no equivalent key existed in the map: the new name-dtype pair is
/// inserted, last in order, and `None` is returned.
///
/// Computes in **O(1)** time (amortized average).
pub fn with_column(&mut self, name: String, dtype: DataType) -> Option<DataType> {
self.inner.insert(name, dtype)
}

pub fn merge(&mut self, other: Self) {
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/polars-plan/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ impl LogicalPlanBuilder {

for fld in other_schema.iter_fields() {
if schema.get(fld.name()).is_none() {
schema.with_column(fld.name, fld.dtype)
schema.with_column(fld.name, fld.dtype);
}
}
}
Expand Down Expand Up @@ -550,7 +550,7 @@ impl LogicalPlanBuilder {
if let Expr::Column(name) = e {
if let Some(DataType::List(inner)) = schema.get(name) {
let inner = *inner.clone();
schema.with_column(name.to_string(), inner)
schema.with_column(name.to_string(), inner);
}

(**name).to_owned()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ impl FunctionNode {
if let DataType::Struct(flds) = dtype {
for fld in flds {
new_schema
.with_column(fld.name().clone(), fld.data_type().clone())
.with_column(fld.name().clone(), fld.data_type().clone());
}
} else {
return Err(PolarsError::ComputeError(
format!("expected struct dtype, got: '{dtype:?}'").into(),
));
}
} else {
new_schema.with_column(name.clone(), dtype.clone())
new_schema.with_column(name.clone(), dtype.clone());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ fn update_scan_schema(
new_cols.sort_unstable_by_key(|item| item.0);
}
for item in new_cols {
new_schema.with_column(item.1.clone(), item.2.clone())
new_schema.with_column(item.1.clone(), item.2.clone());
}
Ok(new_schema)
}
Expand Down Expand Up @@ -718,7 +718,7 @@ impl ProjectionPushDown {
let other_schema = lp_arena.get(*node).schema(lp_arena);
for fld in other_schema.iter_fields() {
if new_schema.get(fld.name()).is_none() {
new_schema.with_column(fld.name, fld.dtype)
new_schema.with_column(fld.name, fld.dtype);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ pub(crate) fn det_join_schema(

for (name, dtype) in schema_left.iter() {
names.insert(name.as_str());
new_schema.with_column(name.to_string(), dtype.clone())
new_schema.with_column(name.to_string(), dtype.clone());
}

// make sure that expression are assigned to the schema
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/frame/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl<'a> LazyCsvReader<'a> {
// the dtypes set may be for the new names, so update again
if let Some(overwrite_schema) = self.schema_overwrite {
for (name, dtype) in overwrite_schema.iter() {
schema.with_column(name.clone(), dtype.clone())
schema.with_column(name.clone(), dtype.clone());
}
}

Expand Down
19 changes: 13 additions & 6 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@ impl LazyFrame {
// schema after renaming
for (old, new) in existing2.iter().zip(new2.iter()) {
let dtype = old_schema.try_get(old)?;
new_schema.with_column(new.clone(), dtype.clone());
if new_schema.with_column(new.clone(), dtype.clone()).is_none() {
new_schema.remove(old);
}
}
Ok(Arc::new(new_schema))
};
Expand All @@ -297,7 +299,12 @@ impl LazyFrame {
let columns = std::mem::take(df.get_columns_mut());
DataFrame::new(columns)
},
None,
// Don't allow optimizations. Swapping names are opaque to the optimizer
AllowedOptimizations {
projection_pushdown: false,
predicate_pushdown: false,
..Default::default()
},
Some(Arc::new(udf_schema)),
Some("RENAME_SWAPPING"),
)
Expand Down Expand Up @@ -343,7 +350,7 @@ impl LazyFrame {
cols.truncate(cols.len() - (existing.len() - removed_count));
Ok(df)
},
None,
Default::default(),
Some(Arc::new(udf_schema)),
Some("RENAME"),
)
Expand Down Expand Up @@ -1118,7 +1125,7 @@ impl LazyFrame {
pub fn map<F>(
self,
function: F,
optimizations: Option<AllowedOptimizations>,
optimizations: AllowedOptimizations,
schema: Option<Arc<dyn UdfSchema>>,
name: Option<&'static str>,
) -> LazyFrame
Expand All @@ -1130,7 +1137,7 @@ impl LazyFrame {
.get_plan_builder()
.map(
function,
optimizations.unwrap_or_default(),
optimizations,
schema,
name.unwrap_or("ANONYMOUS UDF"),
)
Expand Down Expand Up @@ -1206,7 +1213,7 @@ impl LazyFrame {
Ok(df)
}
},
Some(opt),
opt,
Some(Arc::new(udf_schema)),
Some("WITH ROW COUNT"),
)
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/tests/predicate_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ fn filter_blocked_by_map() -> PolarsResult<()> {
};
let q = df
.lazy()
.map(|df| Ok(df), Some(allowed), None, None)
.map(|df| Ok(df), allowed, None, None)
.filter(col("A").gt(lit(2i32)));

assert!(!predicate_at_scan(q.clone()));
Expand Down
5 changes: 5 additions & 0 deletions py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2736,6 +2736,11 @@ def rename(self: LDF, mapping: dict[str, str]) -> LDF:
mapping
Key value pairs that map from old name to new name.

Notes
-----
If names are swapped. E.g. 'A' points to 'B' and 'B' points to 'A', polars
will block projection and predicate pushdowns at this node.

Examples
--------
>>> df = pl.DataFrame(
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl PyDataFrame {
*dtype_ = dtype;
}
} else {
schema.with_column(name, dtype)
schema.with_column(name, dtype);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ impl PyLazyFrame {

let udf_schema =
schema.map(move |s| Arc::new(move |_: &Schema| Ok(s.clone())) as Arc<dyn UdfSchema>);
ldf.map(function, Some(opt), udf_schema, None).into()
ldf.map(function, opt, udf_schema, None).into()
}

pub fn drop_columns(&self, cols: Vec<String>) -> Self {
Expand Down
35 changes: 35 additions & 0 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,6 +1915,41 @@ def test_rename_swap() -> None:
)
assert out.frame_equal(expected)

# 6195
ldf = pl.DataFrame(
{
"weekday": [
1,
],
"priority": [
2,
],
"roundNumber": [
3,
],
"flag": [
4,
],
}
).lazy()

# Rename some columns (note: swapping two columns)
rename_dict = {
"weekday": "priority",
"priority": "weekday",
"roundNumber": "round_number",
}
ldf = ldf.rename(rename_dict)

# Select some columns
ldf = ldf.select(["priority", "weekday", "round_number"])

assert ldf.collect().to_dict(False) == {
"priority": [1],
"weekday": [2],
"round_number": [3],
}


def test_rename_same_name() -> None:
df = pl.DataFrame(
Expand Down