Skip to content

Commit

Permalink
fix: Check funtion input len at expansion (#17763)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 21, 2024
1 parent 9de860a commit 7888d3b
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 9 deletions.
8 changes: 6 additions & 2 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ pub fn all_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal),
options: FunctionOptions {
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::ALLOW_EMPTY_INPUTS,
..Default::default()
},
})
Expand All @@ -221,7 +223,9 @@ pub fn any_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AnyHorizontal),
options: FunctionOptions {
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::ALLOW_EMPTY_INPUTS,
..Default::default()
},
})
Expand Down
10 changes: 9 additions & 1 deletion crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,15 @@ impl Expr {

/// Drop null values.
pub fn drop_nulls(self) -> Self {
self.apply_private(FunctionExpr::DropNulls)
Expr::Function {
input: vec![self],
function: FunctionExpr::DropNulls,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_EMPTY_INPUTS,
..Default::default()
},
}
}

/// Drop NaN values.
Expand Down
15 changes: 10 additions & 5 deletions crates/polars-plan/src/plans/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,22 @@ fn prepare_excluded(
}

// functions can have col(["a", "b"]) or col(String) as inputs
fn expand_function_inputs(expr: Expr, schema: &Schema) -> Expr {
expr.map_expr(|mut e| match &mut e {
fn expand_function_inputs(expr: Expr, schema: &Schema) -> PolarsResult<Expr> {
expr.try_map_expr(|mut e| match &mut e {
Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. }
if options
.flags
.contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) =>
{
*input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap();
e
if input.is_empty() && !options.flags.contains(FunctionFlags::ALLOW_EMPTY_INPUTS) {
// Needed to visualize the error
*input = vec![Expr::Literal(LiteralValue::Null)];
polars_bail!(InvalidOperation: "expected at least 1 input in {}", e)
}
Ok(e)
},
_ => e,
_ => Ok(e),
})
}

Expand Down Expand Up @@ -648,7 +653,7 @@ pub(crate) fn rewrite_projections(
let result_offset = result.len();

// Functions can have col(["a", "b"]) or col(String) as inputs.
expr = expand_function_inputs(expr, schema);
expr = expand_function_inputs(expr, schema)?;

let mut flags = find_flags(&expr)?;
if flags.has_selector {
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ bitflags!(
/// This can lead to recursively entering the engine and sometimes deadlocks.
/// This flag must be set to handle that.
const OPTIONAL_RE_ENTRANT = 1 << 6;
/// Whether this function allows no inputs.
const ALLOW_EMPTY_INPUTS = 1 << 7;
}
);

Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/test_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,14 @@ def test_empty_list_cat_16405() -> None:
def test_empty_list_concat_16924() -> None:
df = pl.DataFrame(schema={"a": pl.Int16, "b": pl.List(pl.String)})
df.with_columns(pl.col("b").list.concat([pl.col("a").cast(pl.String)]))


def test_empty_input_expansion() -> None:
df = pl.DataFrame({"A": [1], "B": [2]})

with pytest.raises(pl.exceptions.InvalidOperationError):
(
df.select("A", "B").with_columns(
pl.col("B").sort_by(pl.struct(pl.exclude("A", "B")))
)
)
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def test_sort_by_err_9259() -> None:
def test_empty_inputs_error() -> None:
df = pl.DataFrame({"col1": [1]})
with pytest.raises(
ComputeError, match="expression: 'sum_horizontal' didn't get any inputs"
pl.exceptions.InvalidOperationError, match="expected at least 1 input"
):
df.select(pl.sum_horizontal(pl.exclude("col1")))

Expand Down

0 comments on commit 7888d3b

Please sign in to comment.