From b925b78fd8040f858168e439eda5042bd2a34af6 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 20 Dec 2023 18:18:56 +0100 Subject: [PATCH 01/63] replace not-impl-err (#8589) --- datafusion/physical-expr/src/array_expressions.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index d39658108337..0a7631918804 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -455,7 +455,7 @@ pub fn array_element(args: &[ArrayRef]) -> Result { let indexes = as_int64_array(&args[1])?; general_array_element::(array, indexes) } - _ => not_impl_err!( + _ => exec_err!( "array_element does not support type: {:?}", args[0].data_type() ), @@ -571,7 +571,7 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { let to_array = as_int64_array(&args[2])?; general_array_slice::(array, from_array, to_array) } - _ => not_impl_err!("array_slice does not support type: {:?}", array_data_type), + _ => exec_err!("array_slice does not support type: {:?}", array_data_type), } } @@ -1335,7 +1335,7 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { general_positions::(arr, element) } array_type => { - not_impl_err!("array_positions does not support type '{array_type:?}'.") + exec_err!("array_positions does not support type '{array_type:?}'.") } } } From 0e9c189a2e4f8f6304239d6cbe14f5114a6d0406 Mon Sep 17 00:00:00 2001 From: Tanmay Gujar Date: Wed, 20 Dec 2023 15:48:11 -0500 Subject: [PATCH 02/63] Substrait insubquery (#8363) * testing in subquery support for substrait producer * consumer fails with table not found * testing roundtrip check * pass in ctx to expr * basic test for Insubquery * fix: outer refs in consumer * fix: merge issues * minor fixes * fix: fmt and clippy CI errors * improve error msg in consumer * minor fixes --- .../substrait/src/logical_plan/consumer.rs | 151 +++++++++++++---- .../substrait/src/logical_plan/producer.rs | 155 ++++++++++++++---- .../tests/cases/roundtrip_logical_plan.rs | 18 ++ 3 files changed, 256 insertions(+), 68 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b7fee96bba1c..9931dd15aec8 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -28,7 +28,7 @@ use datafusion::logical_expr::{ }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, - Repartition, WindowFrameBound, WindowFrameUnits, + Repartition, Subquery, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -39,6 +39,7 @@ use datafusion::{ scalar::ScalarValue, }; use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, @@ -61,7 +62,7 @@ use substrait::proto::{ use substrait::proto::{FunctionArgument, SortField}; use datafusion::common::plan_err; -use datafusion::logical_expr::expr::{InList, Sort}; +use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -230,7 +231,8 @@ pub async fn from_substrait_rel( let mut exprs: Vec = vec![]; for e in &p.expressions { let x = - from_substrait_rex(e, input.clone().schema(), extensions).await?; + from_substrait_rex(ctx, e, input.clone().schema(), extensions) + .await?; // if the expression is WindowFunction, wrap in a Window relation // before returning and do not add to list of this Projection's expression list // otherwise, add expression to the Projection's expression list @@ -256,7 +258,8 @@ pub async fn from_substrait_rel( ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(condition, input.schema(), extensions).await?; + from_substrait_rex(ctx, condition, input.schema(), extensions) + .await?; input.filter(expr.as_ref().clone())?.build() } else { not_impl_err!("Filter without an condition is not valid") @@ -288,7 +291,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let sorts = - from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?; + from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + .await?; input.sort(sorts)?.build() } else { not_impl_err!("Sort without an input is not valid") @@ -306,7 +310,8 @@ pub async fn from_substrait_rel( 1 => { for e in &agg.groupings[0].grouping_expressions { let x = - from_substrait_rex(e, input.schema(), extensions).await?; + from_substrait_rex(ctx, e, input.schema(), extensions) + .await?; group_expr.push(x.as_ref().clone()); } } @@ -315,8 +320,13 @@ pub async fn from_substrait_rel( for grouping in &agg.groupings { let mut grouping_set = vec![]; for e in &grouping.grouping_expressions { - let x = from_substrait_rex(e, input.schema(), extensions) - .await?; + let x = from_substrait_rex( + ctx, + e, + input.schema(), + extensions, + ) + .await?; grouping_set.push(x.as_ref().clone()); } grouping_sets.push(grouping_set); @@ -334,7 +344,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(fil, input.schema(), extensions) + from_substrait_rex(ctx, fil, input.schema(), extensions) .await? .as_ref() .clone(), @@ -402,8 +412,8 @@ pub async fn from_substrait_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = - from_substrait_rex(expr, &in_join_schema, extensions).await?; + let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + .await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. // So we extract each part as follows: @@ -612,14 +622,16 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( + ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { - let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema, extensions) - .await?; + let expr = + from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -660,13 +672,14 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( + ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(expr, input_schema, extensions).await?; + let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; expressions.push(expression.as_ref().clone()); } Ok(expressions) @@ -674,6 +687,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substriat_func_args( + ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, extensions: &HashMap, @@ -682,7 +696,7 @@ pub async fn from_substriat_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -707,7 +721,7 @@ pub async fn from_substrait_agg_func( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -745,6 +759,7 @@ pub async fn from_substrait_agg_func( /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( + ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, extensions: &HashMap, @@ -755,13 +770,18 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Arc::new(Expr::InList(InList { expr: Box::new( - from_substrait_rex(substrait_expr, input_schema, extensions) + from_substrait_rex(ctx, substrait_expr, input_schema, extensions) .await? .as_ref() .clone(), ), - list: from_substrait_rex_vec(substrait_list, input_schema, extensions) - .await?, + list: from_substrait_rex_vec( + ctx, + substrait_list, + input_schema, + extensions, + ) + .await?, negated: false, }))) } @@ -779,6 +799,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -793,6 +814,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -803,6 +825,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( + ctx, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -816,7 +839,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(e, input_schema, extensions) + from_substrait_rex(ctx, e, input_schema, extensions) .await? .as_ref() .clone(), @@ -843,7 +866,7 @@ pub async fn from_substrait_rex( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => not_impl_err!( "Aggregated function argument non-Value type not supported" @@ -868,14 +891,14 @@ pub async fn from_substrait_rex( (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { left: Box::new( - from_substrait_rex(l, input_schema, extensions) + from_substrait_rex(ctx, l, input_schema, extensions) .await? .as_ref() .clone(), ), op, right: Box::new( - from_substrait_rex(r, input_schema, extensions) + from_substrait_rex(ctx, r, input_schema, extensions) .await? .as_ref() .clone(), @@ -888,7 +911,7 @@ pub async fn from_substrait_rex( } } ScalarFunctionType::Expr(builder) => { - builder.build(f, input_schema, extensions).await + builder.build(ctx, f, input_schema, extensions).await } } } @@ -900,6 +923,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( Box::new( from_substrait_rex( + ctx, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -921,7 +945,8 @@ pub async fn from_substrait_rex( ), }; let order_by = - from_substrait_sorts(&window.sorts, input_schema, extensions).await?; + from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + .await?; // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row @@ -934,12 +959,14 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { fun: fun?.unwrap(), args: from_substriat_func_args( + ctx, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( + ctx, &window.partitions, input_schema, extensions, @@ -953,6 +980,51 @@ pub async fn from_substrait_rex( }, }))) } + Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + Err(DataFusionError::Substrait( + "InPredicate Subquery type must have exactly one Needle expression" + .to_string(), + )) + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = + from_substrait_rel(ctx, haystack_expr, extensions) + .await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Arc::new(Expr::InSubquery(InSubquery { + expr: Box::new( + from_substrait_rex( + ctx, + needle_expr, + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + }, + negated: false, + }))) + } else { + substrait_err!("InPredicate Subquery type must have a Haystack expression") + } + } + } + _ => substrait_err!("Subquery type not implemented"), + }, + None => { + substrait_err!("Subquery experssion without SubqueryType is not allowed") + } + }, _ => not_impl_err!("unsupported rex_type"), } } @@ -1312,16 +1384,22 @@ impl BuiltinExprBuilder { pub async fn build( self, + ctx: &SessionContext, f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { match self.expr_name.as_str() { - "like" => Self::build_like_expr(false, f, input_schema, extensions).await, - "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, + "like" => { + Self::build_like_expr(ctx, false, f, input_schema, extensions).await + } + "ilike" => { + Self::build_like_expr(ctx, true, f, input_schema, extensions).await + } "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await + Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) + .await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -1330,6 +1408,7 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( + ctx: &SessionContext, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, @@ -1341,7 +1420,7 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = from_substrait_rex(expr_substrait, input_schema, extensions) + let arg = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); @@ -1365,6 +1444,7 @@ impl BuiltinExprBuilder { } async fn build_like_expr( + ctx: &SessionContext, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, @@ -1378,22 +1458,23 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + let expr = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + let pattern = + from_substrait_rex(ctx, pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let escape_char_expr = - from_substrait_rex(escape_char_substrait, input_schema, extensions) + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) .await? .as_ref() .clone(); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 50f872544298..926883251a63 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -36,12 +36,13 @@ use datafusion::common::{substrait_err, DFSchemaRef}; use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunctionDefinition, Sort, WindowFunction, + InSubquery, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ @@ -58,7 +59,8 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction, + ScalarFunction, SingularOrList, Subquery, + WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -167,7 +169,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -181,6 +183,7 @@ pub fn to_substrait_rel( LogicalPlan::Filter(filter) => { let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; let filter_expr = to_substrait_rex( + ctx, &filter.predicate, filter.input.schema(), 0, @@ -214,7 +217,9 @@ pub fn to_substrait_rel( let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info)) + .map(|e| { + substrait_sort_field(ctx, e, sort.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -228,6 +233,7 @@ pub fn to_substrait_rel( LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; let groupings = to_substrait_groupings( + ctx, &agg.group_expr, agg.input.schema(), extension_info, @@ -235,7 +241,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info)) + .map(|e| { + to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -283,6 +291,7 @@ pub fn to_substrait_rel( let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { Some(filter) => Some(to_substrait_rex( + ctx, filter, &Arc::new(in_join_schema), 0, @@ -299,6 +308,7 @@ pub fn to_substrait_rel( Operator::Eq }; let join_on = to_substrait_join_expr( + ctx, &join.on, eq_op, join.left.schema(), @@ -401,6 +411,7 @@ pub fn to_substrait_rel( let mut window_exprs = vec![]; for expr in &window.window_expr { window_exprs.push(to_substrait_rex( + ctx, expr, window.input.schema(), 0, @@ -500,6 +511,7 @@ pub fn to_substrait_rel( } fn to_substrait_join_expr( + ctx: &SessionContext, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -513,9 +525,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; // Parse right let r = to_substrait_rex( + ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -576,6 +589,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } pub fn parse_flat_grouping_exprs( + ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( @@ -585,7 +599,7 @@ pub fn parse_flat_grouping_exprs( ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -593,6 +607,7 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( + ctx: &SessionContext, exprs: &Vec, schema: &DFSchemaRef, extension_info: &mut ( @@ -608,7 +623,9 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -618,17 +635,21 @@ pub fn to_substrait_groupings( Ok(sets .iter() .rev() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, @@ -638,6 +659,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -650,13 +672,13 @@ pub fn to_substrait_agg_measure( match func_def { AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } let function_anchor = _register_function(fun.to_string(), extension_info); Ok(Measure { @@ -674,20 +696,20 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), None => None } }) } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } let function_anchor = _register_function(fun.name().to_string(), extension_info); Ok(Measure { @@ -702,7 +724,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), None => None } }) @@ -714,7 +736,7 @@ pub fn to_substrait_agg_measure( } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extension_info) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -726,6 +748,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -743,6 +766,7 @@ fn to_substrait_sort_field( }; Ok(SortField { expr: Some(to_substrait_rex( + ctx, sort.expr.deref(), schema, 0, @@ -851,6 +875,7 @@ pub fn make_binary_op_scalar_func( /// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -867,10 +892,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) .collect::>>()?; let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -903,6 +928,7 @@ pub fn to_substrait_rex( for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -937,11 +963,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -965,11 +991,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -997,8 +1023,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -1013,6 +1039,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -1025,12 +1052,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, r#if, schema, col_ref_offset, extension_info, )?), then: Some(to_substrait_rex( + ctx, then, schema, col_ref_offset, @@ -1042,6 +1071,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -1060,6 +1090,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type)?), input: Some(Box::new(to_substrait_rex( + ctx, expr, schema, col_ref_offset, @@ -1072,7 +1103,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal(value), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } Expr::WindowFunction(WindowFunction { fun, @@ -1088,6 +1119,7 @@ pub fn to_substrait_rex( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -1098,12 +1130,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1124,6 +1156,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( + ctx, *case_insensitive, *negated, expr, @@ -1133,7 +1166,50 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => { + let substrait_expr = + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + + let subquery_plan = + to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new(Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }))), + }; + if *negated { + let function_anchor = + _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } + } Expr::Not(arg) => to_substrait_unary_scalar_fn( + ctx, "not", arg, schema, @@ -1141,6 +1217,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + ctx, "is_null", arg, schema, @@ -1148,6 +1225,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_null", arg, schema, @@ -1155,6 +1233,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + ctx, "is_true", arg, schema, @@ -1162,6 +1241,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + ctx, "is_false", arg, schema, @@ -1169,6 +1249,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, "is_unknown", arg, schema, @@ -1176,6 +1257,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_true", arg, schema, @@ -1183,6 +1265,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_false", arg, schema, @@ -1190,6 +1273,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_unknown", arg, schema, @@ -1197,6 +1281,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( + ctx, "negative", arg, schema, @@ -1421,6 +1506,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( + ctx: &SessionContext, ignore_case: bool, negated: bool, expr: &Expr, @@ -1438,8 +1524,8 @@ fn make_substrait_like_expr( } else { _register_function("like".to_string(), extension_info) }; - let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; let escape_char = to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; let arguments = vec![ @@ -1669,6 +1755,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( + ctx: &SessionContext, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, @@ -1679,7 +1766,8 @@ fn to_substrait_unary_scalar_fn( ), ) -> Result { let function_anchor = _register_function(fn_name.to_string(), extension_info); - let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?; + let substrait_expr = + to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1880,6 +1968,7 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -1893,7 +1982,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 47eb5a8f73f5..d7327caee43d 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -394,6 +394,24 @@ async fn roundtrip_inlist_4() -> Result<()> { roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await } +#[tokio::test] +async fn roundtrip_inlist_5() -> Result<()> { + // on roundtrip there is an additional projection during TableScan which includes all column of the table, + // using assert_expected_plan here as a workaround + assert_expected_plan( + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]\ + \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]").await +} + #[tokio::test] async fn roundtrip_cross_join() -> Result<()> { roundtrip("SELECT * FROM data CROSS JOIN data2").await From 448e413584226fc86e3d35a2f90725bcbdf390c9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Dec 2023 15:48:46 -0500 Subject: [PATCH 03/63] Minor: port last test from parquet.rs (#8587) --- datafusion/core/tests/sql/mod.rs | 1 - datafusion/core/tests/sql/parquet.rs | 91 ------------------- .../sqllogictest/test_files/parquet.slt | 17 ++++ 3 files changed, 17 insertions(+), 92 deletions(-) delete mode 100644 datafusion/core/tests/sql/parquet.rs diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 94fc8015a78a..a3d5e32097c6 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -79,7 +79,6 @@ pub mod expr; pub mod group_by; pub mod joins; pub mod order; -pub mod parquet; pub mod parquet_schema; pub mod partitioned_csv; pub mod predicates; diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs deleted file mode 100644 index f80a28f7e4f9..000000000000 --- a/datafusion/core/tests/sql/parquet.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; - -use super::*; - -#[tokio::test] -#[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] -async fn parquet_list_columns() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "list_columns", - &format!("{testdata}/list_columns.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new_list( - "int64_list", - Field::new("item", DataType::Int64, true), - true, - ), - Field::new_list("utf8_list", Field::new("item", DataType::Utf8, true), true), - ])); - - let sql = "SELECT int64_list, utf8_list FROM list_columns"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - // int64_list utf8_list - // 0 [1, 2, 3] [abc, efg, hij] - // 1 [None, 1] None - // 2 [4] [efg, None, hij, xyz] - - assert_eq!(1, results.len()); - let batch = &results[0]; - assert_eq!(3, batch.num_rows()); - assert_eq!(2, batch.num_columns()); - assert_eq!(schema, batch.schema()); - - let int_list_array = as_list_array(batch.column(0)).unwrap(); - let utf8_list_array = as_list_array(batch.column(1)).unwrap(); - - assert_eq!( - as_primitive_array::(&int_list_array.value(0)).unwrap(), - &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3),]) - ); - - assert_eq!( - as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::from(vec![Some("abc"), Some("efg"), Some("hij"),]) - ); - - assert_eq!( - as_primitive_array::(&int_list_array.value(1)).unwrap(), - &PrimitiveArray::::from(vec![None, Some(1),]) - ); - - assert!(utf8_list_array.is_null(1)); - - assert_eq!( - as_primitive_array::(&int_list_array.value(2)).unwrap(), - &PrimitiveArray::::from(vec![Some(4),]) - ); - - let result = utf8_list_array.value(2); - let result = as_string_array(&result).unwrap(); - - assert_eq!(result.value(0), "efg"); - assert!(result.is_null(1)); - assert_eq!(result.value(2), "hij"); - assert_eq!(result.value(3), "xyz"); -} diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index bbe7f33e260c..6c3bd687700a 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -302,3 +302,20 @@ NULL # Clean up statement ok DROP TABLE single_nan; + + +statement ok +CREATE EXTERNAL TABLE list_columns +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/list_columns.parquet'; + +query ?? +SELECT int64_list, utf8_list FROM list_columns +---- +[1, 2, 3] [abc, efg, hij] +[, 1] NULL +[4] [efg, , hij, xyz] + +statement ok +DROP TABLE list_columns; From 778779f7d72c45e7583100e5ff25c504cd48042b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Dec 2023 15:49:36 -0500 Subject: [PATCH 04/63] Minor: consolidate map sqllogictest tests (#8550) * Minor: consolidate map sqllogictest tests * add plan --- datafusion/sqllogictest/src/test_context.rs | 2 +- .../sqllogictest/test_files/explain.slt | 4 ---- datafusion/sqllogictest/test_files/map.slt | 19 +++++++++++++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 91093510afec..941dcb69d2f4 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -84,7 +84,7 @@ impl TestContext { info!("Registering table with many types"); register_table_with_many_types(test_ctx.session_ctx()).await; } - "explain.slt" => { + "map.slt" => { info!("Registering table with map"); register_table_with_map(test_ctx.session_ctx()).await; } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index a51c3aed13ec..4583ef319b7f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -379,7 +379,3 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] --PlaceholderRowExec - -# Testing explain on a table with a map filter, registered in test_context.rs. -statement ok -explain select * from table_with_map where int_field > 0 diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index c3d16fca904e..7863bf445499 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -44,3 +44,22 @@ DELETE 24 query T SELECT strings['not_found'] FROM data LIMIT 1; ---- + +statement ok +drop table data; + + +# Testing explain on a table with a map filter, registered in test_context.rs. +query TT +explain select * from table_with_map where int_field > 0; +---- +logical_plan +Filter: table_with_map.int_field > Int64(0) +--TableScan: table_with_map projection=[int_field, map_field] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: int_field@0 > 0 +----MemoryExec: partitions=1, partition_sizes=[0] + +statement ok +drop table table_with_map; From 98a5a4eb1ea1277f5fe001e1c7602b37592452f1 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 20 Dec 2023 22:11:30 +0100 Subject: [PATCH 05/63] feat: support `LargeList` in `array_dims` (#8592) * support LargeList in array_dims * drop table * add argument check --- .../physical-expr/src/array_expressions.rs | 31 ++++++++++--- datafusion/sqllogictest/test_files/array.slt | 43 ++++++++++++++++++- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 0a7631918804..bdab65cab9e3 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1925,12 +1925,33 @@ pub fn array_length(args: &[ArrayRef]) -> Result { /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; + if args.len() != 1 { + return exec_err!("array_dims needs one argument"); + } + + let data = match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + _ => { + return exec_err!( + "array_dims does not support type '{:?}'", + args[0].data_type() + ); + } + }; - let data = list_array - .iter() - .map(compute_array_dims) - .collect::>>()?; let result = ListArray::from_iter_primitive::(data); Ok(Arc::new(result) as ArrayRef) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b38f73ecb8db..ca33f08de06d 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -67,6 +67,16 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE large_arrays +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + arrow_cast(column2, 'LargeList(Float64)') AS column2, + arrow_cast(column3, 'LargeList(Utf8)') AS column3 + FROM arrays +; + statement ok CREATE TABLE slices AS VALUES @@ -2820,8 +2830,7 @@ NULL 10 ## array_dims (aliases: `list_dims`) # array dims error -# TODO this is a separate bug -query error Internal error: could not cast value to arrow_array::array::list_array::GenericListArray\. +query error Execution error: array_dims does not support type 'Int64' select array_dims(1); # array_dims scalar function @@ -2830,6 +2839,11 @@ select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), ---- [3] [2, 2] [1, 1, 1, 2, 1] +query ??? +select array_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), array_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + # array_dims scalar function #2 query ?? select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); @@ -2842,12 +2856,22 @@ select array_dims(make_array()), array_dims(make_array(make_array())) ---- NULL [1, 0] +query ?? +select array_dims(arrow_cast(make_array(), 'LargeList(Null)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +NULL [1, 0] + # list_dims scalar function #4 (function alias `array_dims`) query ??? select list_dims(make_array(1, 2, 3)), list_dims(make_array([1, 2], [3, 4])), list_dims(make_array([[[[1], [2]]]])); ---- [3] [2, 2] [1, 1, 1, 2, 1] +query ??? +select list_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + # array_dims with columns query ??? select array_dims(column1), array_dims(column2), array_dims(column3) from arrays; @@ -2860,6 +2884,18 @@ NULL [3] [4] [2, 2] NULL [1] [2, 2] [3] NULL +query ??? +select array_dims(column1), array_dims(column2), array_dims(column3) from large_arrays; +---- +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [3] +NULL [3] [4] +[2, 2] NULL [1] +[2, 2] [3] NULL + + ## array_ndims (aliases: `list_ndims`) # array_ndims scalar function #1 @@ -3768,6 +3804,9 @@ drop table nested_arrays; statement ok drop table arrays; +statement ok +drop table large_arrays; + statement ok drop table slices; From bc013fc98a6c3c86cff8fe22de688cdd250b8674 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 20 Dec 2023 15:49:42 -0700 Subject: [PATCH 06/63] Fix regression in regenerating protobuf source (#8603) * Fix regression in regenerating protobuf source * update serde code --- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 10 +++++----- datafusion/proto/src/generated/prost.rs | 4 ++-- datafusion/proto/src/logical_plan/from_proto.rs | 6 +++++- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd8053c817e7..76fe449d2fa3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -409,7 +409,7 @@ message LogicalExprNode { } message Wildcard { - optional string qualifier = 1; + string qualifier = 1; } message PlaceholderNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 88310be0318a..0671757ad427 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -25797,12 +25797,12 @@ impl serde::Serialize for Wildcard { { use serde::ser::SerializeStruct; let mut len = 0; - if self.qualifier.is_some() { + if !self.qualifier.is_empty() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; - if let Some(v) = self.qualifier.as_ref() { - struct_ser.serialize_field("qualifier", v)?; + if !self.qualifier.is_empty() { + struct_ser.serialize_field("qualifier", &self.qualifier)?; } struct_ser.end() } @@ -25868,12 +25868,12 @@ impl<'de> serde::Deserialize<'de> for Wildcard { if qualifier__.is_some() { return Err(serde::de::Error::duplicate_field("qualifier")); } - qualifier__ = map_.next_value()?; + qualifier__ = Some(map_.next_value()?); } } } Ok(Wildcard { - qualifier: qualifier__, + qualifier: qualifier__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3dfd3938615f..771bd715d3c5 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -636,8 +636,8 @@ pub mod logical_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Wildcard { - #[prost(string, optional, tag = "1")] - pub qualifier: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, tag = "1")] + pub qualifier: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 193e0947d6d9..854bfda9a861 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1338,7 +1338,11 @@ pub fn parse_expr( in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { - qualifier: qualifier.clone(), + qualifier: if qualifier.is_empty() { + None + } else { + Some(qualifier.clone()) + }, }), ExprType::ScalarFunction(expr) => { let scalar_function = protobuf::ScalarFunction::try_from(expr.fun) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2997d147424d..b9987ff6c727 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1000,7 +1000,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } Expr::Wildcard { qualifier } => Self { expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.clone(), + qualifier: qualifier.clone().unwrap_or("".to_string()), })), }, Expr::ScalarSubquery(_) From 96c5b8afcda12f95ce6852102c5387021f907ca6 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Thu, 21 Dec 2023 08:27:12 -0500 Subject: [PATCH 07/63] Remove unbounded_input from FileSinkOptions (#8605) * regen protoc * remove proto flag --- .../file_format/write/orchestration.rs | 17 ++--------------- .../core/src/datasource/listing/table.rs | 9 +-------- .../core/src/datasource/physical_plan/mod.rs | 18 ------------------ datafusion/core/src/physical_planner.rs | 1 - datafusion/proto/proto/datafusion.proto | 5 ++--- datafusion/proto/src/generated/pbjson.rs | 18 ------------------ datafusion/proto/src/generated/prost.rs | 4 +--- .../proto/src/physical_plan/from_proto.rs | 1 - datafusion/proto/src/physical_plan/to_proto.rs | 1 - .../tests/cases/roundtrip_physical_plan.rs | 1 - 10 files changed, 6 insertions(+), 69 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 2ae6b70ed1c5..120e27ecf669 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -52,7 +52,6 @@ pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, mut serializer: Box, mut writer: AbortableWrite>, - unbounded_input: bool, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); @@ -71,9 +70,6 @@ pub(crate) async fn serialize_rb_stream_to_object_store( "Unknown error writing to object store".into(), ) })?; - if unbounded_input { - tokio::task::yield_now().await; - } } Err(_) => { return Err(DataFusionError::Internal( @@ -140,7 +136,6 @@ type FileWriteBundle = (Receiver, SerializerType, WriterType); pub(crate) async fn stateless_serialize_and_write_files( mut rx: Receiver, tx: tokio::sync::oneshot::Sender, - unbounded_input: bool, ) -> Result<()> { let mut row_count = 0; // tracks if any writers encountered an error triggering the need to abort @@ -153,13 +148,7 @@ pub(crate) async fn stateless_serialize_and_write_files( let mut join_set = JoinSet::new(); while let Some((data_rx, serializer, writer)) = rx.recv().await { join_set.spawn(async move { - serialize_rb_stream_to_object_store( - data_rx, - serializer, - writer, - unbounded_input, - ) - .await + serialize_rb_stream_to_object_store(data_rx, serializer, writer).await }); } let mut finished_writers = Vec::new(); @@ -241,7 +230,6 @@ pub(crate) async fn stateless_multipart_put( let single_file_output = config.single_file_output; let base_output_path = &config.table_paths[0]; - let unbounded_input = config.unbounded_input; let part_cols = if !config.table_partition_cols.is_empty() { Some(config.table_partition_cols.clone()) } else { @@ -266,8 +254,7 @@ pub(crate) async fn stateless_multipart_put( let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(rb_buffer_size / 2); let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel(); let write_coordinater_task = tokio::spawn(async move { - stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input) - .await + stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await }); while let Some((location, rb_stream)) = file_stream_rx.recv().await { let serializer = get_serializer(); diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 4c13d9d443ca..21d43dcd56db 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -38,7 +38,7 @@ use crate::datasource::{ }, get_statistics_with_limit, listing::ListingTableUrl, - physical_plan::{is_plan_streaming, FileScanConfig, FileSinkConfig}, + physical_plan::{FileScanConfig, FileSinkConfig}, TableProvider, TableType, }; use crate::{ @@ -790,13 +790,6 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - // A plan can produce finite number of rows even if it has unbounded sources, like LIMIT - // queries. Thus, we can check if the plan is streaming to ensure file sink input is - // unbounded. When `unbounded_input` flag is `true` for sink, we occasionally call `yield_now` - // to consume data at the input. When `unbounded_input` flag is `false` (e.g non-streaming data), - // all of the data at the input is sink after execution finishes. See discussion for rationale: - // https://github.com/apache/arrow-datafusion/pull/7610#issuecomment-1728979918 - unbounded_input: is_plan_streaming(&input)?, single_file_output: self.options.single_file, overwrite, file_type_writer_options, diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 9d1c373aee7c..4a6ebeab09e1 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -69,7 +69,6 @@ use arrow::{ use datafusion_common::{file_options::FileTypeWriterOptions, plan_err}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_plan::ExecutionPlan; use log::debug; use object_store::path::Path; @@ -93,8 +92,6 @@ pub struct FileSinkConfig { /// regardless of input partitioning. Otherwise, each table path is assumed to be a directory /// to which each output partition is written to its own output file. pub single_file_output: bool, - /// If input is unbounded, tokio tasks need to yield to not block execution forever - pub unbounded_input: bool, /// Controls whether existing data should be overwritten by this sink pub overwrite: bool, /// Contains settings specific to writing a given FileType, e.g. parquet max_row_group_size @@ -510,21 +507,6 @@ fn get_projected_output_ordering( all_orderings } -// Get output (un)boundedness information for the given `plan`. -pub(crate) fn is_plan_streaming(plan: &Arc) -> Result { - let result = if plan.children().is_empty() { - plan.unbounded_output(&[]) - } else { - let children_unbounded_output = plan - .children() - .iter() - .map(is_plan_streaming) - .collect::>>(); - plan.unbounded_output(&children_unbounded_output?) - }; - result -} - #[cfg(test)] mod tests { use arrow_array::cast::AsArray; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e5816eb49ebb..31d50be10f70 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -593,7 +593,6 @@ impl DefaultPhysicalPlanner { file_groups: vec![], output_schema: Arc::new(schema), table_partition_cols: vec![], - unbounded_input: false, single_file_output: *single_file_output, overwrite: false, file_type_writer_options diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 76fe449d2fa3..cc802ee95710 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1201,9 +1201,8 @@ message FileSinkConfig { Schema output_schema = 4; repeated PartitionColumn table_partition_cols = 5; bool single_file_output = 7; - bool unbounded_input = 8; - bool overwrite = 9; - FileTypeWriterOptions file_type_writer_options = 10; + bool overwrite = 8; + FileTypeWriterOptions file_type_writer_options = 9; } message JsonSink { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0671757ad427..fb3a3ad91d06 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7500,9 +7500,6 @@ impl serde::Serialize for FileSinkConfig { if self.single_file_output { len += 1; } - if self.unbounded_input { - len += 1; - } if self.overwrite { len += 1; } @@ -7528,9 +7525,6 @@ impl serde::Serialize for FileSinkConfig { if self.single_file_output { struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; } - if self.unbounded_input { - struct_ser.serialize_field("unboundedInput", &self.unbounded_input)?; - } if self.overwrite { struct_ser.serialize_field("overwrite", &self.overwrite)?; } @@ -7559,8 +7553,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "tablePartitionCols", "single_file_output", "singleFileOutput", - "unbounded_input", - "unboundedInput", "overwrite", "file_type_writer_options", "fileTypeWriterOptions", @@ -7574,7 +7566,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { OutputSchema, TablePartitionCols, SingleFileOutput, - UnboundedInput, Overwrite, FileTypeWriterOptions, } @@ -7604,7 +7595,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), - "unboundedInput" | "unbounded_input" => Ok(GeneratedField::UnboundedInput), "overwrite" => Ok(GeneratedField::Overwrite), "fileTypeWriterOptions" | "file_type_writer_options" => Ok(GeneratedField::FileTypeWriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -7632,7 +7622,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut output_schema__ = None; let mut table_partition_cols__ = None; let mut single_file_output__ = None; - let mut unbounded_input__ = None; let mut overwrite__ = None; let mut file_type_writer_options__ = None; while let Some(k) = map_.next_key()? { @@ -7673,12 +7662,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } single_file_output__ = Some(map_.next_value()?); } - GeneratedField::UnboundedInput => { - if unbounded_input__.is_some() { - return Err(serde::de::Error::duplicate_field("unboundedInput")); - } - unbounded_input__ = Some(map_.next_value()?); - } GeneratedField::Overwrite => { if overwrite__.is_some() { return Err(serde::de::Error::duplicate_field("overwrite")); @@ -7700,7 +7683,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { output_schema: output_schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), single_file_output: single_file_output__.unwrap_or_default(), - unbounded_input: unbounded_input__.unwrap_or_default(), overwrite: overwrite__.unwrap_or_default(), file_type_writer_options: file_type_writer_options__, }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 771bd715d3c5..9030e90a24c8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1635,10 +1635,8 @@ pub struct FileSinkConfig { #[prost(bool, tag = "7")] pub single_file_output: bool, #[prost(bool, tag = "8")] - pub unbounded_input: bool, - #[prost(bool, tag = "9")] pub overwrite: bool, - #[prost(message, optional, tag = "10")] + #[prost(message, optional, tag = "9")] pub file_type_writer_options: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 5c0ef615cacd..65f9f139a87b 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -739,7 +739,6 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { output_schema: Arc::new(convert_required!(conf.output_schema)?), table_partition_cols, single_file_output: conf.single_file_output, - unbounded_input: conf.unbounded_input, overwrite: conf.overwrite, file_type_writer_options: convert_required!(conf.file_type_writer_options)?, }) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ea00b726b9d6..e9cdb34cf1b9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -846,7 +846,6 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { output_schema: Some(conf.output_schema.as_ref().try_into()?), table_partition_cols, single_file_output: conf.single_file_output, - unbounded_input: conf.unbounded_input, overwrite: conf.overwrite, file_type_writer_options: Some(file_type_writer_options.try_into()?), }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9a9827f2a090..2eb04ab6cbab 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -733,7 +733,6 @@ fn roundtrip_json_sink() -> Result<()> { output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], single_file_output: true, - unbounded_input: false, overwrite: true, file_type_writer_options: FileTypeWriterOptions::JSON(JsonWriterOptions::new( CompressionTypeVariant::UNCOMPRESSED, From df806bd314df9c2a8087fe1422337bce25dc8614 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 21 Dec 2023 09:41:51 -0800 Subject: [PATCH 08/63] Add `arrow_err!` macros, optional backtrace to ArrowError (#8586) * Introducing `arrow_err!` macros --- datafusion-cli/Cargo.lock | 80 ++++++++--------- datafusion-cli/Cargo.toml | 2 +- datafusion/common/src/error.rs | 85 +++++++++++++------ datafusion/common/src/scalar.rs | 9 +- datafusion/common/src/utils.rs | 10 +-- .../avro_to_arrow/arrow_array_reader.rs | 5 +- .../src/datasource/listing_table_factory.rs | 4 +- datafusion/core/src/datasource/memory.rs | 2 +- .../physical_plan/parquet/row_filter.rs | 4 +- .../tests/user_defined/user_defined_plan.rs | 3 +- .../simplify_expressions/expr_simplifier.rs | 42 +++++---- .../physical-expr/src/aggregate/first_last.rs | 4 +- .../aggregate/groups_accumulator/adapter.rs | 7 +- .../physical-expr/src/expressions/binary.rs | 15 ++-- .../physical-expr/src/regex_expressions.rs | 6 +- .../physical-expr/src/window/lead_lag.rs | 8 +- .../src/joins/stream_join_utils.rs | 6 +- datafusion/physical-plan/src/joins/utils.rs | 14 ++- .../physical-plan/src/repartition/mod.rs | 9 +- .../src/windows/bounded_window_agg_exec.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 7 +- datafusion/sqllogictest/test_files/math.slt | 37 ++++---- 22 files changed, 191 insertions(+), 172 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 19ad6709362d..ac05ddf10a73 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -384,7 +384,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1069,12 +1069,12 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e366bff8cd32dd8754b0991fb66b279dc48f598c3a18914852a6673deef583" +checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1576,7 +1576,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1781,9 +1781,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.27" +version = "0.14.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" dependencies = [ "bytes", "futures-channel", @@ -1796,7 +1796,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2", "tokio", "tower-service", "tracing", @@ -2496,7 +2496,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -2715,9 +2715,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.22" +version = "0.11.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" +checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" dependencies = [ "base64", "bytes", @@ -3020,7 +3020,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3106,16 +3106,6 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" -[[package]] -name = "socket2" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "socket2" version = "0.5.5" @@ -3196,7 +3186,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3218,9 +3208,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.40" +version = "2.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" dependencies = [ "proc-macro2", "quote", @@ -3284,22 +3274,22 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +checksum = "f11c217e1416d6f036b870f14e0413d480dbf28edbee1f877abaf0206af43bb7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3315,9 +3305,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" +checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" dependencies = [ "deranged", "powerfmt", @@ -3334,9 +3324,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ad70d68dba9e1f8aceda7aa6711965dfec1cac869f311a51bd08b3a2ccbce20" +checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" dependencies = [ "time-core", ] @@ -3378,7 +3368,7 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", - "socket2 0.5.5", + "socket2", "tokio-macros", "windows-sys 0.48.0", ] @@ -3391,7 +3381,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3488,7 +3478,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3533,7 +3523,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3687,7 +3677,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-shared", ] @@ -3721,7 +3711,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3970,22 +3960,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.30" +version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "306dca4455518f1f31635ec308b6b3e4eb1b11758cefafc782827d0aa7acb5c7" +checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.30" +version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be912bf68235a88fbefd1b73415cb218405958d1655b2ece9035a19920bdf6ba" +checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 1bf24808fb90..f57097683698 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -35,6 +35,7 @@ aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } +datafusion-common = { path = "../datafusion/common" } dirs = "4.0.0" env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } @@ -49,6 +50,5 @@ url = "2.2" [dev-dependencies] assert_cmd = "2.0" ctor = "0.2.0" -datafusion-common = { path = "../datafusion/common" } predicates = "3.0" rstest = "0.17" diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 56b52bd73f9b..515acc6d1c47 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -47,7 +47,8 @@ pub type GenericError = Box; #[derive(Debug)] pub enum DataFusionError { /// Error returned by arrow. - ArrowError(ArrowError), + /// 2nd argument is for optional backtrace + ArrowError(ArrowError, Option), /// Wraps an error from the Parquet crate #[cfg(feature = "parquet")] ParquetError(ParquetError), @@ -60,7 +61,8 @@ pub enum DataFusionError { /// Error associated to I/O operations and associated traits. IoError(io::Error), /// Error returned when SQL is syntactically incorrect. - SQL(ParserError), + /// 2nd argument is for optional backtrace + SQL(ParserError, Option), /// Error returned on a branch that we know it is possible /// but to which we still have no implementation for. /// Often, these errors are tracked in our issue tracker. @@ -223,14 +225,14 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e) + DataFusionError::ArrowError(e, None) } } impl From for ArrowError { fn from(e: DataFusionError) -> Self { match e { - DataFusionError::ArrowError(e) => e, + DataFusionError::ArrowError(e, _) => e, DataFusionError::External(e) => ArrowError::ExternalError(e), other => ArrowError::ExternalError(Box::new(other)), } @@ -267,7 +269,7 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ParserError) -> Self { - DataFusionError::SQL(e) + DataFusionError::SQL(e, None) } } @@ -280,8 +282,9 @@ impl From for DataFusionError { impl Display for DataFusionError { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match *self { - DataFusionError::ArrowError(ref desc) => { - write!(f, "Arrow error: {desc}") + DataFusionError::ArrowError(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "Arrow error: {desc}{backtrace}") } #[cfg(feature = "parquet")] DataFusionError::ParquetError(ref desc) => { @@ -294,8 +297,9 @@ impl Display for DataFusionError { DataFusionError::IoError(ref desc) => { write!(f, "IO error: {desc}") } - DataFusionError::SQL(ref desc) => { - write!(f, "SQL error: {desc:?}") + DataFusionError::SQL(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "SQL error: {desc:?}{backtrace}") } DataFusionError::Configuration(ref desc) => { write!(f, "Invalid or Unsupported Configuration: {desc}") @@ -339,7 +343,7 @@ impl Display for DataFusionError { impl Error for DataFusionError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - DataFusionError::ArrowError(e) => Some(e), + DataFusionError::ArrowError(e, _) => Some(e), #[cfg(feature = "parquet")] DataFusionError::ParquetError(e) => Some(e), #[cfg(feature = "avro")] @@ -347,7 +351,7 @@ impl Error for DataFusionError { #[cfg(feature = "object_store")] DataFusionError::ObjectStore(e) => Some(e), DataFusionError::IoError(e) => Some(e), - DataFusionError::SQL(e) => Some(e), + DataFusionError::SQL(e, _) => Some(e), DataFusionError::NotImplemented(_) => None, DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, @@ -505,32 +509,57 @@ macro_rules! make_error { }; } -// Exposes a macro to create `DataFusionError::Plan` +// Exposes a macro to create `DataFusionError::Plan` with optional backtrace make_error!(plan_err, plan_datafusion_err, Plan); -// Exposes a macro to create `DataFusionError::Internal` +// Exposes a macro to create `DataFusionError::Internal` with optional backtrace make_error!(internal_err, internal_datafusion_err, Internal); -// Exposes a macro to create `DataFusionError::NotImplemented` +// Exposes a macro to create `DataFusionError::NotImplemented` with optional backtrace make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); -// Exposes a macro to create `DataFusionError::Execution` +// Exposes a macro to create `DataFusionError::Execution` with optional backtrace make_error!(exec_err, exec_datafusion_err, Execution); -// Exposes a macro to create `DataFusionError::Substrait` +// Exposes a macro to create `DataFusionError::Substrait` with optional backtrace make_error!(substrait_err, substrait_datafusion_err, Substrait); -// Exposes a macro to create `DataFusionError::SQL` +// Exposes a macro to create `DataFusionError::SQL` with optional backtrace +#[macro_export] +macro_rules! sql_datafusion_err { + ($ERR:expr) => { + DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())) + }; +} + +// Exposes a macro to create `Err(DataFusionError::SQL)` with optional backtrace #[macro_export] macro_rules! sql_err { ($ERR:expr) => { - Err(DataFusionError::SQL($ERR)) + Err(datafusion_common::sql_datafusion_err!($ERR)) + }; +} + +// Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace +#[macro_export] +macro_rules! arrow_datafusion_err { + ($ERR:expr) => { + DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())) + }; +} + +// Exposes a macro to create `Err(DataFusionError::ArrowError)` with optional backtrace +#[macro_export] +macro_rules! arrow_err { + ($ERR:expr) => { + Err(datafusion_common::arrow_datafusion_err!($ERR)) }; } // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths pub use exec_err as _exec_err; +pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; pub use plan_err as _plan_err; @@ -600,9 +629,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - DataFusionError::ResourcesExhausted("foo".to_string()), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( + "foo".to_string(), + ))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); @@ -621,11 +653,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( - "foo".to_string(), - ))), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), + )))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index d730fbf89b72..48878aa9bd99 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -24,6 +24,7 @@ use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, as_fixed_size_binary_array, as_fixed_size_list_array, as_struct_array, @@ -1654,11 +1655,11 @@ impl ScalarValue { match value { Some(val) => Decimal128Array::from(vec![val; size]) .with_precision_and_scale(precision, scale) - .map_err(DataFusionError::ArrowError), + .map_err(|e| arrow_datafusion_err!(e)), None => { let mut builder = Decimal128Array::builder(size) .with_precision_and_scale(precision, scale) - .map_err(DataFusionError::ArrowError)?; + .map_err(|e| arrow_datafusion_err!(e))?; builder.append_nulls(size); Ok(builder.finish()) } @@ -1675,7 +1676,7 @@ impl ScalarValue { .take(size) .collect::() .with_precision_and_scale(precision, scale) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } /// Converts `Vec` where each element has type corresponding to @@ -1882,7 +1883,7 @@ impl ScalarValue { .take(size) .collect::>(); arrow::compute::concat(arrays.as_slice()) - .map_err(DataFusionError::ArrowError)? + .map_err(|e| arrow_datafusion_err!(e))? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 2d38ca21829b..cfdef309a4ee 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -17,8 +17,8 @@ //! This module provides the bisect function, which implements binary search. -use crate::error::_internal_err; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::error::{_internal_datafusion_err, _internal_err}; +use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute; @@ -95,7 +95,7 @@ pub fn get_record_batch_at_indices( new_columns, &RecordBatchOptions::new().with_row_count(Some(indices.len())), ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } /// This function compares two tuples depending on the given sort options. @@ -117,7 +117,7 @@ pub fn compare_rows( lhs.partial_cmp(rhs) } .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) + _internal_datafusion_err!("Column array shouldn't be empty") })?, (true, true, _) => continue, }; @@ -291,7 +291,7 @@ pub fn get_arrayref_at_indices( indices, None, // None: no index check ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect() } diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index 855a8d0dbf40..a16c1ae3333f 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -45,6 +45,7 @@ use arrow::array::{BinaryArray, FixedSizeBinaryArray, GenericListArray}; use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; +use datafusion_common::arrow_err; use num_traits::NumCast; use std::collections::BTreeMap; use std::io::Read; @@ -86,9 +87,9 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { } Ok(lookup) } - _ => Err(DataFusionError::ArrowError(SchemaError( + _ => arrow_err!(SchemaError( "expected avro schema to be a record".to_string(), - ))), + )), } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 7c859ee988d5..68c97bbb7806 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -36,7 +36,7 @@ use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::file_options::{FileTypeWriterOptions, StatementOptions}; -use datafusion_common::{plan_err, DataFusionError, FileType}; +use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, FileType}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -114,7 +114,7 @@ impl TableProviderFactory for ListingTableFactory { .map(|col| { schema .field_with_name(col) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()? .into_iter() diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 7c044b29366d..7c61cc536860 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -423,7 +423,7 @@ mod tests { .scan(&session_ctx.state(), Some(&projection), &[], None) .await { - Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => { + Err(DataFusionError::ArrowError(ArrowError::SchemaError(e), _)) => { assert_eq!( "\"project index 4 out of bounds, max field 3\"", format!("{e:?}") diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 5fe0a0a13a73..151ab5f657b1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -21,7 +21,7 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; use std::collections::BTreeSet; @@ -243,7 +243,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } Err(e) => { // If the column is not in the table schema, should throw the error - Err(DataFusionError::ArrowError(e)) + arrow_err!(e) } }; } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index d4a8842c0a7a..29708c4422ca 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -91,6 +91,7 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion_common::arrow_datafusion_err; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -99,7 +100,7 @@ async fn exec_sql(ctx: &mut SessionContext, sql: &str) -> Result { let df = ctx.sql(sql).await?; let batches = df.collect().await?; pretty_format_batches(&batches) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(|d| d.to_string()) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e2fbd5e927a1..5a300e2ff246 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -29,11 +29,11 @@ use crate::simplify_expressions::SimplifyInfo; use arrow::{ array::new_null_array, datatypes::{DataType, Field, Schema}, - error::ArrowError, record_batch::RecordBatch, }; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, + plan_err, tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ @@ -792,7 +792,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Divide, right, }) if is_null(&right) => *right, - // A / 0 -> DivideByZero Error if A is not null and not floating + // A / 0 -> Divide by zero error if A is not null and not floating // (float / 0 -> inf | -inf | NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -802,7 +802,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); + return plan_err!("Divide by zero"); } // @@ -832,7 +832,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { lit(0) } - // A % 0 --> DivideByZero Error (if A is not floating and not null) + // A % 0 --> Divide by zero Error (if A is not floating and not null) // A % 0 --> NAN (if A is floating and not null) Expr::BinaryExpr(BinaryExpr { left, @@ -843,9 +843,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { DataType::Float32 => lit(f32::NAN), DataType::Float64 => lit(f64::NAN), _ => { - return Err(DataFusionError::ArrowError( - ArrowError::DivideByZero, - )); + return plan_err!("Divide by zero"); } } } @@ -1315,7 +1313,9 @@ mod tests { array::{ArrayRef, Int32Array}, datatypes::{DataType, Field, Schema}, }; - use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; + use datafusion_common::{ + assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema, + }; use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::{ execution_props::ExecutionProps, functions::make_scalar_function, @@ -1771,25 +1771,23 @@ mod tests { #[test] fn test_simplify_divide_zero_by_zero() { - // 0 / 0 -> DivideByZero + // 0 / 0 -> Divide by zero let expr = lit(0) / lit(0); let err = try_simplify(expr).unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_divide_by_zero() { // A / 0 -> DivideByZeroError let expr = col("c2_non_null") / lit(0); - - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] @@ -2209,12 +2207,12 @@ mod tests { } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_modulo_by_zero_non_null() { let expr = col("c2_non_null") % lit(0); - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 5e2012bdbb67..c009881d8918 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -31,7 +31,7 @@ use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression @@ -541,7 +541,7 @@ fn filter_states_according_to_is_set( ) -> Result> { states .iter() - .map(|state| compute::filter(state, flags).map_err(DataFusionError::ArrowError)) + .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e))) .collect::>>() } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index cf980f4c3f16..c6fd17a69b39 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -25,7 +25,8 @@ use arrow::{ }; use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; use datafusion_common::{ - utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, + arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::Accumulator; @@ -372,7 +373,7 @@ fn get_filter_at_indices( ) }) .transpose() - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } // Copied from physical-plan @@ -394,7 +395,7 @@ pub(crate) fn slice_and_maybe_filter( sliced_arrays .iter() .map(|array| { - compute::filter(array, filter_array).map_err(DataFusionError::ArrowError) + compute::filter(array, filter_array).map_err(|e| arrow_datafusion_err!(e)) }) .collect() } else { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 9c7fdd2e814b..c17081398cb8 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -629,8 +629,7 @@ mod tests { use arrow::datatypes::{ ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef, }; - use arrow_schema::ArrowError; - use datafusion_common::Result; + use datafusion_common::{plan_datafusion_err, Result}; use datafusion_expr::type_coercion::binary::get_input_types; /// Performs a binary operation, applying any type coercion necessary @@ -3608,10 +3607,9 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); // decimal let schema = Arc::new(Schema::new(vec![ @@ -3633,10 +3631,7 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + assert!(matches!(err, ref _expected), "{err}"); Ok(()) } diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 41cd01949595..7bafed072b61 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -26,7 +26,7 @@ use arrow::array::{ OffsetSizeTrait, }; use arrow::compute; -use datafusion_common::plan_err; +use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; @@ -58,7 +58,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) + compute::regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) } 3 => { let values = as_generic_string_array::(&args[0])?; @@ -69,7 +69,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { Some(f) if f.iter().any(|s| s == Some("g")) => { plan_err!("regexp_match() does not support the \"global\" option") }, - _ => compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), + _ => compute::regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), } } other => internal_err!( diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index d22660d41ebd..7ee736ce9caa 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -23,7 +23,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; -use datafusion_common::ScalarValue; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::PartitionEvaluator; use std::any::Any; @@ -142,7 +142,7 @@ fn create_empty_array( .transpose()? .unwrap_or_else(|| new_null_array(data_type, size)); if array.data_type() != data_type { - cast(&array, data_type).map_err(DataFusionError::ArrowError) + cast(&array, data_type).map_err(|e| arrow_datafusion_err!(e)) } else { Ok(array) } @@ -172,10 +172,10 @@ fn shift_with_default_value( // Concatenate both arrays, add nulls after if shift > 0 else before if offset > 0 { concat(&[default_values.as_ref(), slice.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } else { concat(&[slice.as_ref(), default_values.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } } } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 64a976a1e39f..50b1618a35dd 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -33,7 +33,9 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use async_trait::async_trait; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, +}; use datafusion_execution::SendableRecordBatchStream; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; @@ -595,7 +597,7 @@ pub fn combine_two_batches( (Some(left_batch), Some(right_batch)) => { // If both batches are present, concatenate them: concat_batches(output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(Some) } (None, None) => { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index eae65ce9c26b..c902ba85f271 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1370,7 +1370,7 @@ mod tests { use arrow::error::{ArrowError, Result as ArrowResult}; use arrow_schema::SortOptions; - use datafusion_common::ScalarValue; + use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { let left = left @@ -1406,9 +1406,7 @@ mod tests { #[tokio::test] async fn check_error_nesting() { let once_fut = OnceFut::<()>::new(async { - Err(DataFusionError::ArrowError(ArrowError::CsvError( - "some error".to_string(), - ))) + arrow_err!(ArrowError::CsvError("some error".to_string())) }); struct TestFut(OnceFut<()>); @@ -1432,10 +1430,10 @@ mod tests { let wrapped_err = DataFusionError::from(arrow_err_from_fut); let root_err = wrapped_err.find_root(); - assert!(matches!( - root_err, - DataFusionError::ArrowError(ArrowError::CsvError(_)) - )) + let _expected = + arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned())); + + assert!(matches!(root_err, _expected)) } #[test] diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 769dc5e0e197..07693f747fee 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -34,7 +34,7 @@ use log::trace; use parking_lot::Mutex; use tokio::task::JoinHandle; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; @@ -200,7 +200,7 @@ impl BatchPartitioner { .iter() .map(|c| { arrow::compute::take(c.as_ref(), &indices, None) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; @@ -1414,9 +1414,8 @@ mod tests { // pull partitions for i in 0..exec.partitioning.partition_count() { let mut stream = exec.execute(i, task_ctx.clone())?; - let err = DataFusionError::ArrowError( - stream.next().await.unwrap().unwrap_err().into(), - ); + let err = + arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into()); let err = err.find_root(); assert!( matches!(err, DataFusionError::ResourcesExhausted(_)), diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 431a43bc6055..0871ec0d7ff3 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -51,7 +51,7 @@ use datafusion_common::utils::{ evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, get_record_batch_at_indices, get_row_at_idx, }; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; @@ -499,7 +499,7 @@ impl PartitionSearcher for LinearSearch { .iter() .map(|items| { concat(&items.iter().map(|e| e.as_ref()).collect::>()) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; // We should emit columns according to row index ordering. diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 854bfda9a861..c582e92dc11c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -36,8 +36,9 @@ use arrow::{ }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - internal_err, plan_datafusion_err, Column, Constraint, Constraints, DFField, - DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, + arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, + Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + Result, ScalarValue, }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ @@ -717,7 +718,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { None, &message.version(), ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); match value { diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index ee1e345f946a..0fa7ff9c2051 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -293,53 +293,52 @@ select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_non_nullable_int ---- 0 0 0 0 0 0 0 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8/0 FROM test_non_nullable_integer - -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8%0 FROM test_non_nullable_integer statement ok @@ -557,10 +556,10 @@ SELECT c1*0 FROM test_non_nullable_decimal ---- 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_decimal -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_decimal statement ok From fd121d3e29404a243a3c18c67c40fa7132ed9ed2 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Fri, 22 Dec 2023 02:00:25 -0500 Subject: [PATCH 09/63] Add examples of DataFrame::write* methods without S3 dependency (#8606) --- datafusion-examples/README.md | 3 +- .../examples/dataframe_output.rs | 76 +++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 datafusion-examples/examples/dataframe_output.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 305422ccd0be..057cdd475273 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -47,7 +47,8 @@ cargo run --example csv_sql - [`catalog.rs`](examples/external_dependency/catalog.rs): Register the table into a custom catalog - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file -- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 +- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 +- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and anaylze `Expr`s diff --git a/datafusion-examples/examples/dataframe_output.rs b/datafusion-examples/examples/dataframe_output.rs new file mode 100644 index 000000000000..c773384dfcd5 --- /dev/null +++ b/datafusion-examples/examples/dataframe_output.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{dataframe::DataFrameWriteOptions, prelude::*}; +use datafusion_common::{parsers::CompressionTypeVariant, DataFusionError}; + +/// This example demonstrates the various methods to write out a DataFrame to local storage. +/// See datafusion-examples/examples/external_dependency/dataframe-to-s3.rs for an example +/// using a remote object store. +#[tokio::main] +async fn main() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); + + // Ensure the column names and types match the target table + df = df.with_column_renamed("column1", "tablecol1").unwrap(); + + ctx.sql( + "create external table + test(tablecol1 varchar) + stored as parquet + location './datafusion-examples/test_table/'", + ) + .await? + .collect() + .await?; + + // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). + // The behavior of write_table depends on the TableProvider's implementation + // of the insert_into method. + df.clone() + .write_table("test", DataFrameWriteOptions::new()) + .await?; + + df.clone() + .write_parquet( + "./datafusion-examples/test_parquet/", + DataFrameWriteOptions::new(), + None, + ) + .await?; + + df.clone() + .write_csv( + "./datafusion-examples/test_csv/", + // DataFrameWriteOptions contains options which control how data is written + // such as compression codec + DataFrameWriteOptions::new().with_compression(CompressionTypeVariant::GZIP), + None, + ) + .await?; + + df.clone() + .write_json( + "./datafusion-examples/test_json/", + DataFrameWriteOptions::new(), + ) + .await?; + + Ok(()) +} From 0ff5305db6b03128282d31afac69fa727e1fe7c4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 22 Dec 2023 04:14:45 -0700 Subject: [PATCH 10/63] Implement logical plan serde for CopyTo (#8618) * Implement logical plan serde for CopyTo * add link to issue * clippy * remove debug logging --- datafusion/proto/proto/datafusion.proto | 21 + datafusion/proto/src/generated/pbjson.rs | 395 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 43 +- datafusion/proto/src/logical_plan/mod.rs | 86 +++- .../tests/cases/roundtrip_logical_plan.rs | 68 ++- 5 files changed, 603 insertions(+), 10 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index cc802ee95710..05f0b6434368 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -74,6 +74,7 @@ message LogicalPlanNode { PrepareNode prepare = 26; DropViewNode drop_view = 27; DistinctOnNode distinct_on = 28; + CopyToNode copy_to = 29; } } @@ -317,6 +318,26 @@ message DistinctOnNode { LogicalPlanNode input = 4; } +message CopyToNode { + LogicalPlanNode input = 1; + string output_url = 2; + bool single_file_output = 3; + oneof CopyOptions { + SQLOptions sql_options = 4; + FileTypeWriterOptions writer_options = 5; + } + string file_type = 6; +} + +message SQLOptions { + repeated SQLOption option = 1; +} + +message SQLOption { + string key = 1; + string value = 2; +} + message UnionNode { repeated LogicalPlanNode inputs = 1; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index fb3a3ad91d06..0fdeab0a40f6 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -3704,6 +3704,188 @@ impl<'de> serde::Deserialize<'de> for Constraints { deserializer.deserialize_struct("datafusion.Constraints", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CopyToNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.output_url.is_empty() { + len += 1; + } + if self.single_file_output { + len += 1; + } + if !self.file_type.is_empty() { + len += 1; + } + if self.copy_options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.output_url.is_empty() { + struct_ser.serialize_field("outputUrl", &self.output_url)?; + } + if self.single_file_output { + struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; + } + if !self.file_type.is_empty() { + struct_ser.serialize_field("fileType", &self.file_type)?; + } + if let Some(v) = self.copy_options.as_ref() { + match v { + copy_to_node::CopyOptions::SqlOptions(v) => { + struct_ser.serialize_field("sqlOptions", v)?; + } + copy_to_node::CopyOptions::WriterOptions(v) => { + struct_ser.serialize_field("writerOptions", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CopyToNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "output_url", + "outputUrl", + "single_file_output", + "singleFileOutput", + "file_type", + "fileType", + "sql_options", + "sqlOptions", + "writer_options", + "writerOptions", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + OutputUrl, + SingleFileOutput, + FileType, + SqlOptions, + WriterOptions, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), + "fileType" | "file_type" => Ok(GeneratedField::FileType), + "sqlOptions" | "sql_options" => Ok(GeneratedField::SqlOptions), + "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CopyToNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CopyToNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut output_url__ = None; + let mut single_file_output__ = None; + let mut file_type__ = None; + let mut copy_options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::OutputUrl => { + if output_url__.is_some() { + return Err(serde::de::Error::duplicate_field("outputUrl")); + } + output_url__ = Some(map_.next_value()?); + } + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); + } + single_file_output__ = Some(map_.next_value()?); + } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = Some(map_.next_value()?); + } + GeneratedField::SqlOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("sqlOptions")); + } + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::SqlOptions) +; + } + GeneratedField::WriterOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("writerOptions")); + } + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::WriterOptions) +; + } + } + } + Ok(CopyToNode { + input: input__, + output_url: output_url__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), + copy_options: copy_options__, + }) + } + } + deserializer.deserialize_struct("datafusion.CopyToNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CreateCatalogNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -13336,6 +13518,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::DistinctOn(v) => { struct_ser.serialize_field("distinctOn", v)?; } + logical_plan_node::LogicalPlanType::CopyTo(v) => { + struct_ser.serialize_field("copyTo", v)?; + } } } struct_ser.end() @@ -13387,6 +13572,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "dropView", "distinct_on", "distinctOn", + "copy_to", + "copyTo", ]; #[allow(clippy::enum_variant_names)] @@ -13418,6 +13605,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { Prepare, DropView, DistinctOn, + CopyTo, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13466,6 +13654,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "prepare" => Ok(GeneratedField::Prepare), "dropView" | "drop_view" => Ok(GeneratedField::DropView), "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), + "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13675,6 +13864,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("distinctOn")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) +; + } + GeneratedField::CopyTo => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("copyTo")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CopyTo) ; } } @@ -20742,6 +20938,205 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SqlOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.key.is_empty() { + len += 1; + } + if !self.value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOption", len)?; + if !self.key.is_empty() { + struct_ser.serialize_field("key", &self.key)?; + } + if !self.value.is_empty() { + struct_ser.serialize_field("value", &self.value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SqlOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SqlOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SQLOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = Some(map_.next_value()?); + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = Some(map_.next_value()?); + } + } + } + Ok(SqlOption { + key: key__.unwrap_or_default(), + value: value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SQLOption", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SqlOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.option.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOptions", len)?; + if !self.option.is_empty() { + struct_ser.serialize_field("option", &self.option)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SqlOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "option", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Option, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "option" => Ok(GeneratedField::Option), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SqlOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SQLOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut option__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Option => { + if option__.is_some() { + return Err(serde::de::Error::duplicate_field("option")); + } + option__ = Some(map_.next_value()?); + } + } + } + Ok(SqlOptions { + option: option__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SQLOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarDictionaryValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9030e90a24c8..e44355859d65 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -38,7 +38,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub logical_plan_type: ::core::option::Option, } @@ -101,6 +101,8 @@ pub mod logical_plan_node { DropView(super::DropViewNode), #[prost(message, tag = "28")] DistinctOn(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + CopyTo(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -502,6 +504,45 @@ pub struct DistinctOnNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CopyToNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(string, tag = "2")] + pub output_url: ::prost::alloc::string::String, + #[prost(bool, tag = "3")] + pub single_file_output: bool, + #[prost(string, tag = "6")] + pub file_type: ::prost::alloc::string::String, + #[prost(oneof = "copy_to_node::CopyOptions", tags = "4, 5")] + pub copy_options: ::core::option::Option, +} +/// Nested message and enum types in `CopyToNode`. +pub mod copy_to_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CopyOptions { + #[prost(message, tag = "4")] + SqlOptions(super::SqlOptions), + #[prost(message, tag = "5")] + WriterOptions(super::FileTypeWriterOptions), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOptions { + #[prost(message, repeated, tag = "1")] + pub option: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOption { + #[prost(string, tag = "1")] + pub key: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 948228d87d46..e03b3ffa7b84 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; -use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; +use crate::protobuf::{ + copy_to_node, CustomTableScanNode, LogicalExprNodeCollection, SqlOption, +}; use crate::{ convert_required, protobuf::{ @@ -44,12 +46,13 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; -use datafusion_common::plan_datafusion_err; use datafusion_common::{ - context, internal_err, not_impl_err, parsers::CompressionTypeVariant, - DataFusionError, OwnedTableReference, Result, + context, file_options::StatementOptions, internal_err, not_impl_err, + parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, FileType, + OwnedTableReference, Result, }; use datafusion_expr::{ + dml, logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, @@ -59,6 +62,7 @@ use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; +use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; @@ -823,6 +827,36 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Arc::new(convert_required!(dropview.schema)?), }), )), + LogicalPlanType::CopyTo(copy) => { + let input: LogicalPlan = + into_logical_plan!(copy.input, ctx, extension_codec)?; + + let copy_options = match ©.copy_options { + Some(copy_to_node::CopyOptions::SqlOptions(opt)) => { + let options = opt.option.iter().map(|o| (o.key.clone(), o.value.clone())).collect(); + CopyOptions::SQLOptions(StatementOptions::from( + &options, + )) + } + Some(copy_to_node::CopyOptions::WriterOptions(_)) => { + return Err(proto_error( + "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", + )) + } + other => return Err(proto_error(format!( + "LogicalPlan serde is not yet implemented for CopyTo with CopyOptions {other:?}", + ))) + }; + Ok(datafusion_expr::LogicalPlan::Copy( + datafusion_expr::dml::CopyTo { + input: Arc::new(input), + output_url: copy.output_url.clone(), + file_format: FileType::from_str(©.file_type)?, + single_file_output: copy.single_file_output, + copy_options, + }, + )) + } } } @@ -1534,9 +1568,47 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Dml(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for Dml", )), - LogicalPlan::Copy(_) => Err(proto_error( - "LogicalPlan serde is not yet implemented for Copy", - )), + LogicalPlan::Copy(dml::CopyTo { + input, + output_url, + single_file_output, + file_format, + copy_options, + }) => { + let input = protobuf::LogicalPlanNode::try_from_logical_plan( + input, + extension_codec, + )?; + + let copy_options_proto: Option = match copy_options { + CopyOptions::SQLOptions(opt) => { + let options: Vec = opt.clone().into_inner().iter().map(|(k, v)| SqlOption { + key: k.to_string(), + value: v.to_string(), + }).collect(); + Some(copy_to_node::CopyOptions::SqlOptions(protobuf::SqlOptions { + option: options + })) + } + CopyOptions::WriterOptions(_) => { + return Err(proto_error( + "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", + )) + } + }; + + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( + protobuf::CopyToNode { + input: Some(Box::new(input)), + single_file_output: *single_file_output, + output_url: output_url.to_string(), + file_type: file_format.to_string(), + copy_options: copy_options_proto, + }, + ))), + }) + } LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 8e15b5d0d480..9798b06f4724 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,12 +31,16 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; -use datafusion_common::Result; -use datafusion_common::{internal_err, not_impl_err, plan_err}; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::file_options::StatementOptions; +use datafusion_common::{internal_err, not_impl_err, plan_err, FileTypeWriterOptions}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; +use datafusion_common::{FileType, Result}; +use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, Sort, @@ -301,6 +305,66 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let mut options = HashMap::new(); + options.insert("foo".to_string(), "bar".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::SQLOptions(StatementOptions::from(&options)), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +#[ignore] // see https://github.com/apache/arrow-datafusion/issues/8619 +async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_created_by("DataFusion Test".to_string()) + .build(); + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(writer_properties)), + )), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +async fn create_csv_scan(ctx: &SessionContext) -> Result { + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + + let input = ctx.table("t1").await?.into_optimized_plan()?; + Ok(input) +} + #[tokio::test] async fn roundtrip_logical_plan_distinct_on() -> Result<()> { let ctx = SessionContext::new(); From 55121d8e48d99178a72a5dbaa773f1fbf4a2e059 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 22 Dec 2023 06:15:13 -0500 Subject: [PATCH 11/63] Fix InListExpr to return the correct number of rows (#8601) * Fix InListExpr to return the correct number of rows * Reduce repetition --- .../physical-expr/src/expressions/in_list.rs | 57 +++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 625b01ec9a7e..1a1634081c38 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -349,17 +349,18 @@ impl PhysicalExpr for InListExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { + let num_rows = batch.num_rows(); let value = self.expr.evaluate(batch)?; let r = match &self.static_filter { - Some(f) => f.contains(value.into_array(1)?.as_ref(), self.negated)?, + Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?, None => { - let value = value.into_array(batch.num_rows())?; + let value = value.into_array(num_rows)?; let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( - BooleanArray::new(BooleanBuffer::new_unset(batch.num_rows()), None), + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |result, expr| -> Result { Ok(or_kleene( &result, - &eq(&value, &expr?.into_array(batch.num_rows())?)?, + &eq(&value, &expr?.into_array(num_rows)?)?, )?) }, )?; @@ -1267,4 +1268,52 @@ mod tests { Ok(()) } + + #[test] + fn in_list_no_cols() -> Result<()> { + // test logic when the in_list expression doesn't have any columns + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(1), Some(2), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))]; + + // 1 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(1))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(true), Some(true), Some(true)], + expr, + &schema + ); + + // 2 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(2))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(false), Some(false), Some(false)], + expr, + &schema + ); + + // NULL IN (1, 6) + let expr = lit(ScalarValue::Int32(None)); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![None, None, None], + expr, + &schema + ); + + Ok(()) + } } From 39e9f41a21e8e2ffac39feabd13d6aa7eda5f213 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Fri, 22 Dec 2023 06:56:27 -0500 Subject: [PATCH 12/63] Remove ListingTable single_file option (#8604) * remove listingtable single_file option * prettier --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/datasource/listing/table.rs | 12 +----------- .../core/src/datasource/listing_table_factory.rs | 9 ++------- docs/source/user-guide/sql/write_options.md | 15 +++------------ 3 files changed, 6 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 21d43dcd56db..a7af1bf1be28 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -246,9 +246,6 @@ pub struct ListingOptions { /// multiple equivalent orderings, the outer `Vec` will have a /// single element. pub file_sort_order: Vec>, - /// This setting when true indicates that the table is backed by a single file. - /// Any inserts to the table may only append to this existing file. - pub single_file: bool, /// This setting holds file format specific options which should be used /// when inserting into this table. pub file_type_write_options: Option, @@ -269,7 +266,6 @@ impl ListingOptions { collect_stat: true, target_partitions: 1, file_sort_order: vec![], - single_file: false, file_type_write_options: None, } } @@ -421,12 +417,6 @@ impl ListingOptions { self } - /// Configure if this table is backed by a sigle file - pub fn with_single_file(mut self, single_file: bool) -> Self { - self.single_file = single_file; - self - } - /// Configure file format specific writing options. pub fn with_write_options( mut self, @@ -790,7 +780,7 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - single_file_output: self.options.single_file, + single_file_output: false, overwrite, file_type_writer_options, }; diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 68c97bbb7806..e8ffece320d7 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -135,12 +135,8 @@ impl TableProviderFactory for ListingTableFactory { let mut statement_options = StatementOptions::from(&cmd.options); - // Extract ListingTable specific options if present or set default - let single_file = statement_options - .take_bool_option("single_file")? - .unwrap_or(false); - - // Backwards compatibility (#8547) + // Backwards compatibility (#8547), discard deprecated options + statement_options.take_bool_option("single_file")?; if let Some(s) = statement_options.take_str_option("insert_mode") { if !s.eq_ignore_ascii_case("append_new_files") { return plan_err!("Unknown or unsupported insert mode {s}. Only append_new_files supported"); @@ -195,7 +191,6 @@ impl TableProviderFactory for ListingTableFactory { .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) .with_file_sort_order(cmd.order_exprs.clone()) - .with_single_file(single_file) .with_write_options(file_type_writer_options); let resolved_schema = match provided_schema { diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md index 94adee960996..470591afafff 100644 --- a/docs/source/user-guide/sql/write_options.md +++ b/docs/source/user-guide/sql/write_options.md @@ -42,12 +42,11 @@ WITH HEADER ROW DELIMITER ';' LOCATION '/test/location/my_csv_table/' OPTIONS( -CREATE_LOCAL_PATH 'true', NULL_VALUE 'NAN' ); ``` -When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. CREATE_LOCAL_PATH is a special option that indicates if DataFusion should create local file paths when writing new files if they do not already exist. This option is useful if you wish to create an external table from scratch, using only DataFusion SQL statements. Finally, NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. Finally, options can be passed when running a `COPY` command. @@ -70,17 +69,9 @@ In this example, we write the entirety of `source_table` out to a folder of parq The following special options are specific to the `COPY` command. | Option | Description | Default Value | -| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | --- | | SINGLE_FILE_OUTPUT | If true, COPY query will write output to a single file. Otherwise, multiple files will be written to a directory in parallel. | true | -| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | - -### CREATE EXTERNAL TABLE Specific Options - -The following special options are specific to creating an external table. - -| Option | Description | Default Value | -| ----------- | --------------------------------------------------------------------------------------------------------------------- | ------------- | -| SINGLE_FILE | If true, indicates that this external table is backed by a single file. INSERT INTO queries will append to this file. | false | +| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | | ### JSON Format Specific Options From ef34af8877d25cd84006806b355127179e2d4c89 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 22 Dec 2023 13:36:01 +0100 Subject: [PATCH 13/63] support LargeList in array_remove (#8595) --- .../physical-expr/src/array_expressions.rs | 114 ++++++-- datafusion/sqllogictest/test_files/array.slt | 269 ++++++++++++++++++ 2 files changed, 365 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index bdab65cab9e3..4dfc157e53c7 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -100,6 +100,14 @@ fn compare_element_to_list( row_index: usize, eq: bool, ) -> Result { + if list_array_row.data_type() != element_array.data_type() { + return exec_err!( + "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", + list_array_row.data_type(), + element_array.data_type() + ); + } + let indices = UInt32Array::from(vec![row_index as u32]); let element_array_row = arrow::compute::take(element_array, &indices, None)?; @@ -126,6 +134,26 @@ fn compare_element_to_list( }) .collect::() } + DataType::LargeList(_) => { + // compare each element of the from array + let element_array_row_inner = + as_large_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_large_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) + } + }) + }) + .collect::() + } _ => { let element_arr = Scalar::new(element_array_row); // use not_distinct so we can compare NULL @@ -1511,14 +1539,14 @@ pub fn array_remove_n(args: &[ArrayRef]) -> Result { /// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) /// ) /// ``` -fn general_replace( - list_array: &ListArray, +fn general_replace( + list_array: &GenericListArray, from_array: &ArrayRef, to_array: &ArrayRef, arr_n: Vec, ) -> Result { // Build up the offsets for the final output array - let mut offsets: Vec = vec![0]; + let mut offsets: Vec = vec![O::usize_as(0)]; let values = list_array.values(); let original_data = values.to_data(); let to_data = to_array.to_data(); @@ -1540,8 +1568,8 @@ fn general_replace( continue; } - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + let start = offset_window[0]; + let end = offset_window[1]; let list_array_row = list_array.value(row_index); @@ -1550,43 +1578,56 @@ fn general_replace( let eq_array = compare_element_to_list(&list_array_row, &from_array, row_index, true)?; - let original_idx = 0; - let replace_idx = 1; + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); let n = arr_n[row_index]; let mut counter = 0; // All elements are false, no need to replace, just copy original data if eq_array.false_count() == eq_array.len() { - mutable.extend(original_idx, start, end); - offsets.push(offsets[row_index] + (end - start) as i32); + mutable.extend( + original_idx.to_usize().unwrap(), + start.to_usize().unwrap(), + end.to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (end - start)); valid.append(true); continue; } for (i, to_replace) in eq_array.iter().enumerate() { + let i = O::usize_as(i); if let Some(true) = to_replace { - mutable.extend(replace_idx, row_index, row_index + 1); + mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); counter += 1; if counter == n { // copy original data for any matches past n - mutable.extend(original_idx, start + i + 1, end); + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + end.to_usize().unwrap(), + ); break; } } else { // copy original data for false / null matches - mutable.extend(original_idx, start + i, start + i + 1); + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + ); } } - offsets.push(offsets[row_index] + (end - start) as i32); + offsets.push(offsets[row_index] + (end - start)); valid.append(true); } let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::new(offsets.into()), + OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), Some(NullBuffer::new(valid.finish())), )?)) @@ -1595,19 +1636,56 @@ fn general_replace( pub fn array_replace(args: &[ArrayRef]) -> Result { // replace at most one occurence for each element let arr_n = vec![1; args[0].len()]; - general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => exec_err!("array_replace does not support type '{array_type:?}'."), + } } pub fn array_replace_n(args: &[ArrayRef]) -> Result { // replace the specified number of occurences let arr_n = as_int64_array(&args[3])?.values().to_vec(); - general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_n does not support type '{array_type:?}'.") + } + } } pub fn array_replace_all(args: &[ArrayRef]) -> Result { // replace all occurrences (up to "i64::MAX") let arr_n = vec![i64::MAX; args[0].len()]; - general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_all does not support type '{array_type:?}'.") + } + } } macro_rules! to_string { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ca33f08de06d..283f2d67b7a0 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -298,6 +298,17 @@ AS VALUES (make_array(10, 11, 12, 10, 11, 12, 10, 11, 12, 10), 10, 13, 10) ; +statement ok +CREATE TABLE large_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 + FROM arrays_with_repeating_elements +; + statement ok CREATE TABLE nested_arrays_with_repeating_elements AS VALUES @@ -307,6 +318,17 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; +statement ok +CREATE TABLE large_nested_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + column2, + column3, + column4 + FROM nested_arrays_with_repeating_elements +; + query error select [1, true, null] @@ -2010,6 +2032,14 @@ select ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + # array_replace scalar function #2 (element is list) query ?? select @@ -2026,6 +2056,21 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + # list_replace scalar function #3 (function alias `list_replace`) query ??? select list_replace( @@ -2035,6 +2080,14 @@ select list_replace( ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +query ??? +select list_replace( + arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + # array_replace scalar function with columns #1 query ? select array_replace(column1, column2, column3) from arrays_with_repeating_elements; @@ -2044,6 +2097,14 @@ select array_replace(column1, column2, column3) from arrays_with_repeating_eleme [10, 7, 7, 8, 7, 9, 7, 8, 7, 7] [13, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ? +select array_replace(column1, column2, column3) from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[7, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[13, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_replace scalar function with columns #2 (element is list) query ? select array_replace(column1, column2, column3) from nested_arrays_with_repeating_elements; @@ -2053,6 +2114,14 @@ select array_replace(column1, column2, column3) from nested_arrays_with_repeatin [[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ? +select array_replace(column1, column2, column3) from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + # array_replace scalar function with columns and scalars #1 query ??? select @@ -2066,6 +2135,18 @@ from arrays_with_repeating_elements; [1, 2, 2, 4, 5, 4, 4, 10, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 7, 7, 8, 7, 9, 7, 8, 7, 7] [1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from large_arrays_with_repeating_elements; +---- +[1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_replace scalar function with columns and scalars #2 (element is list) query ??? select @@ -2084,6 +2165,23 @@ from nested_arrays_with_repeating_elements; [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ??? +select + array_replace( + arrow_cast(make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]),'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + ## array_replace_n (aliases: `list_replace_n`) # array_replace_n scalar function #1 @@ -2095,6 +2193,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + # array_replace_n scalar function #2 (element is list) query ?? select @@ -2113,6 +2219,23 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace_n( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4], + 2 + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + # list_replace_n scalar function #3 (function alias `array_replace_n`) query ??? select @@ -2122,6 +2245,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + list_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + list_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + list_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + # array_replace_n scalar function with columns #1 query ? select @@ -2133,6 +2264,16 @@ from arrays_with_repeating_elements; [10, 10, 10, 8, 10, 9, 10, 8, 7, 7] [13, 11, 12, 13, 11, 12, 13, 11, 12, 13] +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 2, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 10, 10, 8, 10, 9, 10, 8, 7, 7] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + # array_replace_n scalar function with columns #2 (element is list) query ? select @@ -2144,6 +2285,17 @@ from nested_arrays_with_repeating_elements; [[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + + # array_replace_n scalar function with columns and scalars #1 query ???? select @@ -2158,6 +2310,19 @@ from arrays_with_repeating_elements; [1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 7, 7] [10, 10, 7, 8, 7, 9, 7, 8, 7, 7] [1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] [13, 11, 12, 13, 11, 12, 10, 11, 12, 10] +query ???? +select + array_replace_n(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from large_arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 7, 7] [10, 10, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] [13, 11, 12, 13, 11, 12, 10, 11, 12, 10] + # array_replace_n scalar function with columns and scalars #2 (element is list) query ???? select @@ -2178,6 +2343,25 @@ from nested_arrays_with_repeating_elements; [[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[28, 29, 30], [28, 29, 30], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ???? +select + array_replace_n( + arrow_cast(make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), 'LargeList(List(Int64))'), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from large_nested_arrays_with_repeating_elements; +---- +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[28, 29, 30], [28, 29, 30], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + ## array_replace_all (aliases: `list_replace_all`) # array_replace_all scalar function #1 @@ -2189,6 +2373,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + # array_replace_all scalar function #2 (element is list) query ?? select @@ -2205,6 +2397,21 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + # list_replace_all scalar function #3 (function alias `array_replace_all`) query ??? select @@ -2214,6 +2421,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +query ??? +select + list_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + # array_replace_all scalar function with columns #1 query ? select @@ -2225,6 +2440,16 @@ from arrays_with_repeating_elements; [10, 10, 10, 8, 10, 9, 10, 8, 10, 10] [13, 11, 12, 13, 11, 12, 13, 11, 12, 13] +query ? +select + array_replace_all(column1, column2, column3) +from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 7, 7] +[10, 10, 10, 8, 10, 9, 10, 8, 10, 10] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + # array_replace_all scalar function with columns #2 (element is list) query ? select @@ -2236,6 +2461,16 @@ from nested_arrays_with_repeating_elements; [[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [28, 29, 30], [28, 29, 30]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] +query ? +select + array_replace_all(column1, column2, column3) +from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [28, 29, 30], [28, 29, 30]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + # array_replace_all scalar function with columns and scalars #1 query ??? select @@ -2249,6 +2484,18 @@ from arrays_with_repeating_elements; [1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 4, 4] [1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from large_arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] + # array_replace_all scalar function with columns and scalars #2 (element is list) query ??? select @@ -2266,6 +2513,22 @@ from nested_arrays_with_repeating_elements; [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] +query ??? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), 'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] + # array_replace with null handling statement ok @@ -3870,8 +4133,14 @@ drop table arrays_range; statement ok drop table arrays_with_repeating_elements; +statement ok +drop table large_arrays_with_repeating_elements; + statement ok drop table nested_arrays_with_repeating_elements; +statement ok +drop table large_nested_arrays_with_repeating_elements; + statement ok drop table flatten_table; From 0e62fa4df924f8657e43a97ca7aa8c6ca48bc08f Mon Sep 17 00:00:00 2001 From: Tomoaki Kawada Date: Fri, 22 Dec 2023 21:47:14 +0900 Subject: [PATCH 14/63] Rename `ParamValues::{LIST -> List,MAP -> Map}` (#8611) * Rename `ParamValues::{LIST -> List,MAP -> Map}` * Reformat the doc comments of `ParamValues::*` --- datafusion/common/src/param_value.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 253c312b66d5..1b6195c0d0bc 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -23,17 +23,17 @@ use std::collections::HashMap; /// The parameter value corresponding to the placeholder #[derive(Debug, Clone)] pub enum ParamValues { - /// for positional query parameters, like select * from test where a > $1 and b = $2 - LIST(Vec), - /// for named query parameters, like select * from test where a > $foo and b = $goo - MAP(HashMap), + /// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2` + List(Vec), + /// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo` + Map(HashMap), } impl ParamValues { /// Verify parameter list length and type pub fn verify(&self, expect: &Vec) -> Result<()> { match self { - ParamValues::LIST(list) => { + ParamValues::List(list) => { // Verify if the number of params matches the number of values if expect.len() != list.len() { return _plan_err!( @@ -57,7 +57,7 @@ impl ParamValues { } Ok(()) } - ParamValues::MAP(_) => { + ParamValues::Map(_) => { // If it is a named query, variables can be reused, // but the lengths are not necessarily equal Ok(()) @@ -71,7 +71,7 @@ impl ParamValues { data_type: &Option, ) -> Result { match self { - ParamValues::LIST(list) => { + ParamValues::List(list) => { if id.is_empty() || id == "$0" { return _plan_err!("Empty placeholder id"); } @@ -97,7 +97,7 @@ impl ParamValues { } Ok(value.clone()) } - ParamValues::MAP(map) => { + ParamValues::Map(map) => { // convert name (in format $a, $b, ..) to mapped values (a, b, ..) let name = &id[1..]; // value at the name position in param_values should be the value for the placeholder @@ -122,7 +122,7 @@ impl ParamValues { impl From> for ParamValues { fn from(value: Vec) -> Self { - Self::LIST(value) + Self::List(value) } } @@ -133,7 +133,7 @@ where fn from(value: Vec<(K, ScalarValue)>) -> Self { let value: HashMap = value.into_iter().map(|(k, v)| (k.into(), v)).collect(); - Self::MAP(value) + Self::Map(value) } } @@ -144,6 +144,6 @@ where fn from(value: HashMap) -> Self { let value: HashMap = value.into_iter().map(|(k, v)| (k.into(), v)).collect(); - Self::MAP(value) + Self::Map(value) } } From 26a488d6ae0f45b33d2566b8b97d4f82a2e80fa3 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 23 Dec 2023 02:18:53 +0800 Subject: [PATCH 15/63] Support binary temporal coercion for Date64 and Timestamp types --- datafusion/expr/src/type_coercion/binary.rs | 6 ++++++ datafusion/sqllogictest/test_files/timestamps.slt | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index dd9449198796..1b62c1bc05c1 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -785,6 +785,12 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Interval(MonthDayNano)), (Date64, Date32) | (Date32, Date64) => Some(Date64), + (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => { + Some(Timestamp(Nanosecond, None)) + } + (Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => { + Some(Timestamp(Nanosecond, None)) + } (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => { Some(Timestamp(Nanosecond, None)) } diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index f956d59b1da0..2b3b4bf2e45b 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1888,3 +1888,12 @@ true true true true true true #SELECT to_timestamp(-62125747200), to_timestamp(1926632005177), -62125747200::timestamp, 1926632005177::timestamp, cast(-62125747200 as timestamp), cast(1926632005177 as timestamp) #---- #0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 + +########## +## Test binary temporal coercion for Date and Timestamp +########## + +query B +select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(Nanosecond, None)'); +---- +false From ba46434f839d612be01ee0d00e0d826475ce5f10 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 23 Dec 2023 03:28:47 +0800 Subject: [PATCH 16/63] Add new configuration item `listing_table_ignore_subdirectory` (#8565) * init * test * add config * rename * doc * fix doc * add sqllogictests & rename * fmt & fix test * clippy * test read partition table * simplify testing * simplify testing --- datafusion/common/src/config.rs | 5 +++ .../core/src/datasource/listing/helpers.rs | 4 +- datafusion/core/src/datasource/listing/url.rs | 26 ++++++++++--- .../core/src/execution/context/parquet.rs | 9 ++++- .../test_files/information_schema.slt | 2 + .../sqllogictest/test_files/parquet.slt | 38 ++++++++++++++++++- docs/source/user-guide/configs.md | 1 + 7 files changed, 75 insertions(+), 10 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 03fb5ea320a0..dedce74ff40d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -273,6 +273,11 @@ config_namespace! { /// memory consumption pub max_buffered_batches_per_output_file: usize, default = 2 + /// When scanning file paths, whether to ignore subdirectory files, + /// ignored by default (true), when reading a partitioned table, + /// `listing_table_ignore_subdirectory` is always equal to false, even if set to true + pub listing_table_ignore_subdirectory: bool, default = true + } } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index be74afa1f4d6..68de55e1a410 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -375,10 +375,10 @@ pub async fn pruned_partition_list<'a>( store.list(Some(&partition.path)).try_collect().await? } }; - let files = files.into_iter().filter(move |o| { let extension_match = o.location.as_ref().ends_with(file_extension); - let glob_match = table_path.contains(&o.location); + // here need to scan subdirectories(`listing_table_ignore_subdirectory` = false) + let glob_match = table_path.contains(&o.location, false); extension_match && glob_match }); diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 3ca7864f7f9e..766dee7de901 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -20,6 +20,7 @@ use std::fs; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; +use datafusion_optimizer::OptimizerConfig; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use glob::Pattern; @@ -184,14 +185,27 @@ impl ListingTableUrl { } /// Returns `true` if `path` matches this [`ListingTableUrl`] - pub fn contains(&self, path: &Path) -> bool { + pub fn contains(&self, path: &Path, ignore_subdirectory: bool) -> bool { match self.strip_prefix(path) { Some(mut segments) => match &self.glob { Some(glob) => { - let stripped = segments.join("/"); - glob.matches(&stripped) + if ignore_subdirectory { + segments + .next() + .map_or(false, |file_name| glob.matches(file_name)) + } else { + let stripped = segments.join("/"); + glob.matches(&stripped) + } + } + None => { + if ignore_subdirectory { + let has_subdirectory = segments.collect::>().len() > 1; + !has_subdirectory + } else { + true + } } - None => true, }, None => false, } @@ -223,6 +237,8 @@ impl ListingTableUrl { store: &'a dyn ObjectStore, file_extension: &'a str, ) -> Result>> { + let exec_options = &ctx.options().execution; + let ignore_subdirectory = exec_options.listing_table_ignore_subdirectory; // If the prefix is a file, use a head request, otherwise list let list = match self.is_collection() { true => match ctx.runtime_env().cache_manager.get_list_files_cache() { @@ -246,7 +262,7 @@ impl ListingTableUrl { .try_filter(move |meta| { let path = &meta.location; let extension_match = path.as_ref().ends_with(file_extension); - let glob_match = self.contains(path); + let glob_match = self.contains(path, ignore_subdirectory); futures::future::ready(extension_match && glob_match) }) .map_err(DataFusionError::ObjectStore) diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 5d649d3e6df8..7825d9b88297 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -80,6 +80,7 @@ mod tests { use crate::dataframe::DataFrameWriteOptions; use crate::parquet::basic::Compression; use crate::test_util::parquet_test_data; + use datafusion_execution::config::SessionConfig; use tempfile::tempdir; use super::*; @@ -103,8 +104,12 @@ mod tests { #[tokio::test] async fn read_with_glob_path_issue_2465() -> Result<()> { - let ctx = SessionContext::new(); - + let config = + SessionConfig::from_string_hash_map(std::collections::HashMap::from([( + "datafusion.execution.listing_table_ignore_subdirectory".to_owned(), + "false".to_owned(), + )]))?; + let ctx = SessionContext::new_with_config(config); let df = ctx .read_parquet( // it was reported that when a path contains // (two consecutive separator) no files were found diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 5c6bf6e2dac1..36876beb1447 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -150,6 +150,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false +datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 @@ -224,6 +225,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold f datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.listing_table_ignore_subdirectory true When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index 6c3bd687700a..0f26c14f0017 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -276,6 +276,39 @@ LIMIT 10; 0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) 0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +# Test config listing_table_ignore_subdirectory: + +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +statement ok +CREATE EXTERNAL TABLE listing_table +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table/*.parquet'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table; +---- +12 + # Clean up statement ok DROP TABLE timestamp_with_tz; @@ -303,7 +336,6 @@ NULL statement ok DROP TABLE single_nan; - statement ok CREATE EXTERNAL TABLE list_columns STORED AS PARQUET @@ -319,3 +351,7 @@ SELECT int64_list, utf8_list FROM list_columns statement ok DROP TABLE list_columns; + +# Clean up +statement ok +DROP TABLE listing_table; diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 6fb5cc4ca870..1f7fa7760b94 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -82,6 +82,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | | datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | | datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From e4674929a1d17b2c2a80b8588fe61664606d9d63 Mon Sep 17 00:00:00 2001 From: Tomoaki Kawada Date: Sat, 23 Dec 2023 05:59:41 +0900 Subject: [PATCH 17/63] Optimize the parameter types of `ParamValues`'s methods (#8613) * Take `&str` instead of `&String` in `ParamValue::get_placeholders_with_values` * Take `Option<&DataType>` instead of `&Option` in `ParamValue::get_placeholders_with_values` * Take `&[_]` instead of `&Vec<_>` in `ParamValues::verify` --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/param_value.rs | 10 +++++----- datafusion/expr/src/logical_plan/plan.rs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 1b6195c0d0bc..004c1371d1ae 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -31,7 +31,7 @@ pub enum ParamValues { impl ParamValues { /// Verify parameter list length and type - pub fn verify(&self, expect: &Vec) -> Result<()> { + pub fn verify(&self, expect: &[DataType]) -> Result<()> { match self { ParamValues::List(list) => { // Verify if the number of params matches the number of values @@ -67,8 +67,8 @@ impl ParamValues { pub fn get_placeholders_with_values( &self, - id: &String, - data_type: &Option, + id: &str, + data_type: Option<&DataType>, ) -> Result { match self { ParamValues::List(list) => { @@ -88,7 +88,7 @@ impl ParamValues { )) })?; // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { + if Some(&value.data_type()) != data_type { return _internal_err!( "Placeholder value type mismatch: expected {:?}, got {:?}", data_type, @@ -107,7 +107,7 @@ impl ParamValues { )) })?; // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { + if Some(&value.data_type()) != data_type { return _internal_err!( "Placeholder value type mismatch: expected {:?}, got {:?}", data_type, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1f3711407a14..50f4a6b76e18 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1250,8 +1250,8 @@ impl LogicalPlan { expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { - let value = - param_values.get_placeholders_with_values(id, data_type)?; + let value = param_values + .get_placeholders_with_values(id, data_type.as_ref())?; // Replace the placeholder with the value Ok(Transformed::Yes(Expr::Literal(value))) } From 03c2ef46f2d88fb015ee305ab67df6d930b780e2 Mon Sep 17 00:00:00 2001 From: Tomoaki Kawada Date: Sat, 23 Dec 2023 06:20:05 +0900 Subject: [PATCH 18/63] Don't panic on zero placeholder in `ParamValues::get_placeholders_with_values` (#8615) It correctly rejected `$0` but not the other ones that are parsed equally (e.g., `$000`). Co-authored-by: Andrew Lamb --- datafusion/common/src/param_value.rs | 17 ++++++++++------- datafusion/expr/src/logical_plan/plan.rs | 13 +++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 004c1371d1ae..3fe2ba99ab83 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -72,17 +72,20 @@ impl ParamValues { ) -> Result { match self { ParamValues::List(list) => { - if id.is_empty() || id == "$0" { + if id.is_empty() { return _plan_err!("Empty placeholder id"); } // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) - })? - 1; + let idx = id[1..] + .parse::() + .map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {e}" + )) + })? + .checked_sub(1); // value at the idx-th position in param_values should be the value for the placeholder - let value = list.get(idx).ok_or_else(|| { + let value = idx.and_then(|idx| list.get(idx)).ok_or_else(|| { DataFusionError::Internal(format!( "No value found for placeholder with id {id}" )) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 50f4a6b76e18..9b0f441ef902 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3099,6 +3099,19 @@ digraph { .build() .unwrap(); + plan.replace_params_with_values(¶m_values.clone().into()) + .expect_err("unexpectedly succeeded to replace an invalid placeholder"); + + // test $00 placeholder + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .filter(col("id").eq(placeholder("$00"))) + .unwrap() + .build() + .unwrap(); + plan.replace_params_with_values(¶m_values.into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } From df2e1e2587340c513743b965f9aef301c4a2a859 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Sat, 23 Dec 2023 13:09:50 +0100 Subject: [PATCH 19/63] Fix #8507: Non-null sub-field on nullable struct-field has wrong nullity (#8623) * added test * added guard clause * rename schema fields * clippy --------- Co-authored-by: mlanhenke --- datafusion/expr/src/expr_schema.rs | 32 ++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e5b0185d90e0..ba21d09f0619 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -277,6 +277,13 @@ impl ExprSchemable for Expr { "Wildcard expressions are not valid in a logical query plan" ), Expr::GetIndexedField(GetIndexedField { expr, field }) => { + // If schema is nested, check if parent is nullable + // if it is, return early + if let Expr::Column(col) = expr.as_ref() { + if input_schema.nullable(col)? { + return Ok(true); + } + } field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { @@ -411,8 +418,8 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result {{ @@ -548,6 +555,27 @@ mod tests { assert_eq!(&meta, expr.to_field(&schema).unwrap().metadata()); } + #[test] + fn test_nested_schema_nullability() { + let fields = DFField::new( + Some(TableReference::Bare { + table: "table_name".into(), + }), + "parent", + DataType::Struct(Fields::from(vec![Field::new( + "child", + DataType::Int64, + false, + )])), + true, + ); + + let schema = DFSchema::new_with_metadata(vec![fields], HashMap::new()).unwrap(); + + let expr = col("parent").field("child"); + assert!(expr.nullable(&schema).unwrap()); + } + #[derive(Debug)] struct MockExprSchema { nullable: bool, From 8524d58e303b65597eeebc41c75025a6f0822793 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 23 Dec 2023 07:10:56 -0500 Subject: [PATCH 20/63] Implement `contained` API in PruningPredicate (#8440) * Implement `contains` API in PruningPredicate * Apply suggestions from code review Co-authored-by: Nga Tran * Add comment to len(), fix fmt * rename BoolVecBuilder::append* to BoolVecBuilder::combine* --------- Co-authored-by: Nga Tran --- .../physical_plan/parquet/page_filter.rs | 11 +- .../physical_plan/parquet/row_groups.rs | 9 + .../core/src/physical_optimizer/pruning.rs | 1073 +++++++++++++---- 3 files changed, 857 insertions(+), 236 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 42bfef35996e..f6310c49bcd6 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -23,7 +23,7 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; @@ -37,6 +37,7 @@ use parquet::{ }, format::PageLocation, }; +use std::collections::HashSet; use std::sync::Arc; use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; @@ -554,4 +555,12 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { ))), } } + + fn contained( + &self, + _column: &datafusion_common::Column, + _values: &HashSet, + ) -> Option { + None + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 7c3f7d9384ab..09e4907c9437 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -16,6 +16,7 @@ // under the License. use arrow::{array::ArrayRef, datatypes::Schema}; +use arrow_array::BooleanArray; use arrow_schema::FieldRef; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; @@ -340,6 +341,14 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { let scalar = ScalarValue::UInt64(Some(c.statistics()?.null_count())); scalar.to_array().ok() } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index b2ba7596db8d..79e084d7b7f1 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -35,12 +35,13 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::{downcast_value, plan_datafusion_err, ScalarValue}; +use arrow_array::cast::AsArray; use datafusion_common::{ internal_err, plan_err, tree_node::{Transformed, TreeNode}, }; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_common::{plan_datafusion_err, ScalarValue}; +use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; @@ -93,6 +94,30 @@ pub trait PruningStatistics { /// /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; + + /// Returns an array where each row represents information known about + /// the `values` contained in a column. + /// + /// This API is designed to be used along with [`LiteralGuarantee`] to prove + /// that predicates can not possibly evaluate to `true` and thus prune + /// containers. For example, Parquet Bloom Filters can prove that values are + /// not present. + /// + /// The returned array has one row for each container, with the following + /// meanings: + /// * `true` if the values in `column` ONLY contain values from `values` + /// * `false` if the values in `column` are NOT ANY of `values` + /// * `null` if the neither of the above holds or is unknown. + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option; } /// Evaluates filter expressions on statistics such as min/max values and null @@ -142,12 +167,17 @@ pub trait PruningStatistics { pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated schema: SchemaRef, - /// Actual pruning predicate (rewritten in terms of column min/max statistics) + /// A min/max pruning predicate (rewritten in terms of column min/max + /// values, which are supplied by statistics) predicate_expr: Arc, - /// The statistics required to evaluate this predicate - required_columns: RequiredStatColumns, - /// Original physical predicate from which this predicate expr is derived (required for serialization) + /// Description of which statistics are required to evaluate `predicate_expr` + required_columns: RequiredColumns, + /// Original physical predicate from which this predicate expr is derived + /// (required for serialization) orig_expr: Arc, + /// [`LiteralGuarantee`]s that are used to try and prove a predicate can not + /// possibly evaluate to `true`. + literal_guarantees: Vec, } impl PruningPredicate { @@ -172,14 +202,18 @@ impl PruningPredicate { /// `(column_min / 2) <= 4 && 4 <= (column_max / 2))` pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { // build predicate expression once - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); let predicate_expr = build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + + let literal_guarantees = LiteralGuarantee::analyze(&expr); + Ok(Self { schema, predicate_expr, required_columns, orig_expr: expr, + literal_guarantees, }) } @@ -198,40 +232,47 @@ impl PruningPredicate { /// /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier pub fn prune(&self, statistics: &S) -> Result> { + let mut builder = BoolVecBuilder::new(statistics.num_containers()); + + // Try to prove the predicate can't be true for the containers based on + // literal guarantees + for literal_guarantee in &self.literal_guarantees { + let LiteralGuarantee { + column, + guarantee, + literals, + } = literal_guarantee; + if let Some(results) = statistics.contained(column, literals) { + match guarantee { + // `In` means the values in the column must be one of the + // values in the set for the predicate to evaluate to true. + // If `contained` returns false, that means the column is + // not any of the values so we can prune the container + Guarantee::In => builder.combine_array(&results), + // `NotIn` means the values in the column must must not be + // any of the values in the set for the predicate to + // evaluate to true. If contained returns true, it means the + // column is only in the set of values so we can prune the + // container + Guarantee::NotIn => { + builder.combine_array(&arrow::compute::not(&results)?) + } + } + } + } + + // Next, try to prove the predicate can't be true for the containers based + // on min/max values + // build a RecordBatch that contains the min/max values in the - // appropriate statistics columns + // appropriate statistics columns for the min/max predicate let statistics_batch = build_statistics_record_batch(statistics, &self.required_columns)?; - // Evaluate the pruning predicate on that record batch. - // - // Use true when the result of evaluating a predicate - // expression on a row group is null (aka `None`). Null can - // arise when the statistics are unknown or some calculation - // in the predicate means we don't know for sure if the row - // group can be filtered out or not. To maintain correctness - // the row group must be kept and thus `true` is returned. - match self.predicate_expr.evaluate(&statistics_batch)? { - ColumnarValue::Array(array) => { - let predicate_array = downcast_value!(array, BooleanArray); + // Evaluate the pruning predicate on that record batch and append any results to the builder + builder.combine_value(self.predicate_expr.evaluate(&statistics_batch)?); - Ok(predicate_array - .into_iter() - .map(|x| x.unwrap_or(true)) // None -> true per comments above - .collect::>()) - } - // result was a column - ColumnarValue::Scalar(ScalarValue::Boolean(v)) => { - let v = v.unwrap_or(true); // None -> true per comments above - Ok(vec![v; statistics.num_containers()]) - } - other => { - internal_err!( - "Unexpected result of pruning predicate evaluation. Expected Boolean array \ - or scalar but got {other:?}" - ) - } - } + Ok(builder.build()) } /// Return a reference to the input schema @@ -254,9 +295,91 @@ impl PruningPredicate { is_always_true(&self.predicate_expr) } - pub(crate) fn required_columns(&self) -> &RequiredStatColumns { + pub(crate) fn required_columns(&self) -> &RequiredColumns { &self.required_columns } + + /// Names of the columns that are known to be / not be in a set + /// of literals (constants). These are the columns the that may be passed to + /// [`PruningStatistics::contained`] during pruning. + /// + /// This is useful to avoid fetching statistics for columns that will not be + /// used in the predicate. For example, it can be used to avoid reading + /// uneeded bloom filters (a non trivial operation). + pub fn literal_columns(&self) -> Vec { + let mut seen = HashSet::new(); + self.literal_guarantees + .iter() + .map(|e| &e.column.name) + // avoid duplicates + .filter(|name| seen.insert(*name)) + .map(|s| s.to_string()) + .collect() + } +} + +/// Builds the return `Vec` for [`PruningPredicate::prune`]. +#[derive(Debug)] +struct BoolVecBuilder { + /// One element per container. Each element is + /// * `true`: if the container has row that may pass the predicate + /// * `false`: if the container has rows that DEFINITELY DO NOT pass the predicate + inner: Vec, +} + +impl BoolVecBuilder { + /// Create a new `BoolVecBuilder` with `num_containers` elements + fn new(num_containers: usize) -> Self { + Self { + // assume by default all containers may pass the predicate + inner: vec![true; num_containers], + } + } + + /// Combines result `array` for a conjunct (e.g. `AND` clause) of a + /// predicate into the currently in progress array. + /// + /// Each `array` element is: + /// * `true`: container has row that may pass the predicate + /// * `false`: all container rows DEFINITELY DO NOT pass the predicate + /// * `null`: container may or may not have rows that pass the predicate + fn combine_array(&mut self, array: &BooleanArray) { + assert_eq!(array.len(), self.inner.len()); + for (cur, new) in self.inner.iter_mut().zip(array.iter()) { + // `false` for this conjunct means we know for sure no rows could + // pass the predicate and thus we set the corresponding container + // location to false. + if let Some(false) = new { + *cur = false; + } + } + } + + /// Combines the results in the [`ColumnarValue`] to the currently in + /// progress array, following the same rules as [`Self::combine_array`]. + /// + /// # Panics + /// If `value` is not boolean + fn combine_value(&mut self, value: ColumnarValue) { + match value { + ColumnarValue::Array(array) => { + self.combine_array(array.as_boolean()); + } + ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => { + // False means all containers can not pass the predicate + self.inner = vec![false; self.inner.len()]; + } + _ => { + // Null or true means the rows in container may pass this + // conjunct so we can't prune any containers based on that + } + } + } + + /// Convert this builder into a Vec of bools + fn build(self) -> Vec { + self.inner + } } fn is_always_true(expr: &Arc) -> bool { @@ -276,21 +399,21 @@ fn is_always_true(expr: &Arc) -> bool { /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredStatColumns { +pub(crate) struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for - /// pruning predicate evaluation + /// pruning predicate evaluation (e.g. `min_value` or `max_value`) columns: Vec<(phys_expr::Column, StatisticsType, Field)>, } -impl RequiredStatColumns { +impl RequiredColumns { fn new() -> Self { Self::default() } - /// Returns number of unique columns. + /// Returns number of unique columns pub(crate) fn n_columns(&self) -> usize { self.iter() .map(|(c, _s, _f)| c) @@ -344,11 +467,10 @@ impl RequiredStatColumns { // only add statistics column if not previously added if need_to_insert { - let stat_field = Field::new( - stat_column.name(), - field.data_type().clone(), - field.is_nullable(), - ); + // may be null if statistics are not present + let nullable = true; + let stat_field = + Field::new(stat_column.name(), field.data_type().clone(), nullable); self.columns.push((column.clone(), stat_type, stat_field)); } rewrite_column_expr(column_expr.clone(), column, &stat_column) @@ -391,7 +513,7 @@ impl RequiredStatColumns { } } -impl From> for RequiredStatColumns { +impl From> for RequiredColumns { fn from(columns: Vec<(phys_expr::Column, StatisticsType, Field)>) -> Self { Self { columns } } @@ -424,7 +546,7 @@ impl From> for RequiredStatColum /// ``` fn build_statistics_record_batch( statistics: &S, - required_columns: &RequiredStatColumns, + required_columns: &RequiredColumns, ) -> Result { let mut fields = Vec::::new(); let mut arrays = Vec::::new(); @@ -480,7 +602,7 @@ struct PruningExpressionBuilder<'a> { op: Operator, scalar_expr: Arc, field: &'a Field, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, } impl<'a> PruningExpressionBuilder<'a> { @@ -489,7 +611,7 @@ impl<'a> PruningExpressionBuilder<'a> { right: &'a Arc, op: Operator, schema: &'a Schema, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression let left_columns = collect_columns(left); @@ -704,7 +826,7 @@ fn reverse_operator(op: Operator) -> Result { fn build_single_column_expr( column: &phys_expr::Column, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, is_not: bool, // if true, treat as !col ) -> Option> { let field = schema.field_with_name(column.name()).ok()?; @@ -745,7 +867,7 @@ fn build_single_column_expr( fn build_is_null_column_expr( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Option> { if let Some(col) = expr.as_any().downcast_ref::() { let field = schema.field_with_name(col.name()).ok()?; @@ -775,7 +897,7 @@ fn build_is_null_column_expr( fn build_predicate_expression( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { // Returned for unsupported expressions. Such expressions are // converted to TRUE. @@ -984,7 +1106,7 @@ mod tests { use std::collections::HashMap; use std::ops::{Not, Rem}; - #[derive(Debug)] + #[derive(Debug, Default)] /// Mock statistic provider for tests /// /// Each row represents the statistics for a "container" (which @@ -993,95 +1115,142 @@ mod tests { /// /// Note All `ArrayRefs` must be the same size. struct ContainerStats { - min: ArrayRef, - max: ArrayRef, + min: Option, + max: Option, /// Optional values null_counts: Option, + /// Optional known values (e.g. mimic a bloom filter) + /// (value, contained) + /// If present, all BooleanArrays must be the same size as min/max + contained: Vec<(HashSet, BooleanArray)>, } impl ContainerStats { + fn new() -> Self { + Default::default() + } fn new_decimal128( min: impl IntoIterator>, max: impl IntoIterator>, precision: u8, scale: i8, ) -> Self { - Self { - min: Arc::new( + Self::new() + .with_min(Arc::new( min.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - max: Arc::new( + )) + .with_max(Arc::new( max.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - null_counts: None, - } + )) } fn new_i64( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_i32( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_utf8<'a>( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_bool( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn min(&self) -> Option { - Some(self.min.clone()) + self.min.clone() } fn max(&self) -> Option { - Some(self.max.clone()) + self.max.clone() } fn null_counts(&self) -> Option { self.null_counts.clone() } + /// return an iterator over all arrays in this statistics + fn arrays(&self) -> Vec { + let contained_arrays = self + .contained + .iter() + .map(|(_values, contained)| Arc::new(contained.clone()) as ArrayRef); + + [ + self.min.as_ref().cloned(), + self.max.as_ref().cloned(), + self.null_counts.as_ref().cloned(), + ] + .into_iter() + .flatten() + .chain(contained_arrays) + .collect() + } + + /// Returns the number of containers represented by this statistics This + /// picks the length of the first array as all arrays must have the same + /// length (which is verified by `assert_invariants`). fn len(&self) -> usize { - assert_eq!(self.min.len(), self.max.len()); - self.min.len() + // pick the first non zero length + self.arrays().iter().map(|a| a.len()).next().unwrap_or(0) + } + + /// Ensure that the lengths of all arrays are consistent + fn assert_invariants(&self) { + let mut prev_len = None; + + for len in self.arrays().iter().map(|a| a.len()) { + // Get a length, if we don't already have one + match prev_len { + None => { + prev_len = Some(len); + } + Some(prev_len) => { + assert_eq!(prev_len, len); + } + } + } + } + + /// Add min values + fn with_min(mut self, min: ArrayRef) -> Self { + self.min = Some(min); + self + } + + /// Add max values + fn with_max(mut self, max: ArrayRef) -> Self { + self.max = Some(max); + self } /// Add null counts. There must be the same number of null counts as @@ -1090,14 +1259,36 @@ mod tests { mut self, counts: impl IntoIterator>, ) -> Self { - // take stats out and update them let null_counts: ArrayRef = Arc::new(counts.into_iter().collect::()); - assert_eq!(null_counts.len(), self.len()); + self.assert_invariants(); self.null_counts = Some(null_counts); self } + + /// Add contained information. + pub fn with_contained( + mut self, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let contained: BooleanArray = contained.into_iter().collect(); + let values: HashSet<_> = values.into_iter().collect(); + + self.contained.push((values, contained)); + self.assert_invariants(); + self + } + + /// get any contained information for the specified values + fn contained(&self, find_values: &HashSet) -> Option { + // find the one with the matching values + self.contained + .iter() + .find(|(values, _contained)| values == find_values) + .map(|(_values, contained)| contained.clone()) + } } #[derive(Debug, Default)] @@ -1135,13 +1326,34 @@ mod tests { let container_stats = self .stats .remove(&col) - .expect("Can not find stats for column") + .unwrap_or_default() .with_null_counts(counts); // put stats back in self.stats.insert(col, container_stats); self } + + /// Add contained information for the specified columm. + fn with_contained( + mut self, + name: impl Into, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .unwrap_or_default() + .with_contained(values, contained); + + // put stats back in + self.stats.insert(col, container_stats); + self + } } impl PruningStatistics for TestStatistics { @@ -1173,6 +1385,16 @@ mod tests { .map(|container_stats| container_stats.null_counts()) .unwrap_or(None) } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + self.stats + .get(column) + .and_then(|container_stats| container_stats.contained(values)) + } } /// Returns the specified min/max container values @@ -1198,12 +1420,20 @@ mod tests { fn null_counts(&self, _column: &Column) -> Option { None } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } } #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min - let required_columns = RequiredStatColumns::from(vec![ + let required_columns = RequiredColumns::from(vec![ // min of original column s1, named s1_min ( phys_expr::Column::new("s1", 1), @@ -1275,7 +1505,7 @@ mod tests { // which is what Parquet does // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new( @@ -1307,7 +1537,7 @@ mod tests { #[test] fn test_build_statistics_no_required_stats() { - let required_columns = RequiredStatColumns::new(); + let required_columns = RequiredColumns::new(); let statistics = OneContainerStats { min_values: Some(Arc::new(Int64Array::from(vec![Some(10)]))), @@ -1325,7 +1555,7 @@ mod tests { // Test requesting a Utf8 column when the stats return some other type // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new("s1_min", DataType::Utf8, true), @@ -1354,7 +1584,7 @@ mod tests { #[test] fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s1", 3), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), @@ -1385,20 +1615,14 @@ mod tests { // test column on the left let expr = col("c1").eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1411,20 +1635,14 @@ mod tests { // test column on the left let expr = col("c1").not_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).not_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1437,20 +1655,14 @@ mod tests { // test column on the left let expr = col("c1").gt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1463,19 +1675,13 @@ mod tests { // test column on the left let expr = col("c1").gt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1488,20 +1694,14 @@ mod tests { // test column on the left let expr = col("c1").lt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1514,19 +1714,13 @@ mod tests { // test column on the left let expr = col("c1").lt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1542,11 +1736,8 @@ mod tests { // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); let expected_expr = "c1_min@0 < 1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1561,11 +1752,8 @@ mod tests { // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 = 0 expression let expr = col("c1").lt(lit(1)).or(col("c2").rem(lit(2)).eq(lit(0))); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1577,11 +1765,8 @@ mod tests { let expected_expr = "true"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1593,11 +1778,8 @@ mod tests { let expected_expr = "NOT c1_min@0 AND c1_max@1"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1609,11 +1791,8 @@ mod tests { let expected_expr = "c1_min@0 OR c1_max@1"; let expr = col("c1"); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1627,11 +1806,8 @@ mod tests { // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated let expr = col("c1").lt(lit(true)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1643,7 +1819,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ]); - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); // c1 < 1 and (c2 = 2 or c2 = 3) let expr = col("c1") .lt(lit(1)) @@ -1659,7 +1835,7 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::Min, - c1_min_field + c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 2 should add c2_min and c2_max @@ -1669,7 +1845,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Min, - c2_min_field + c2_min_field.with_nullable(true) // could be nullable if stats are not present ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); @@ -1678,7 +1854,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Max, - c2_max_field + c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 3 shouldn't add any new statistics fields @@ -1700,11 +1876,8 @@ mod tests { false, )); let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1719,11 +1892,8 @@ mod tests { // test c1 in() let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1744,11 +1914,8 @@ mod tests { let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ AND (c1_min@0 != 2 OR 2 != c1_max@1) \ AND (c1_min@0 != 3 OR 3 != c1_max@1)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1762,20 +1929,14 @@ mod tests { // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; @@ -1783,21 +1944,15 @@ mod tests { // test column on the left let expr = try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1817,11 +1972,8 @@ mod tests { false, )); let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList(InList::new( @@ -1837,11 +1989,8 @@ mod tests { "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -2484,10 +2633,464 @@ mod tests { // TODO: add other negative test for other case and op } + #[test] + fn prune_with_contained_one_column() { + let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); + + // Model having information like a bloom filter for s1 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0 known to only contain "foo"", + Some(true), + // container 1 known to not contain "foo" + Some(false), + // container 2 unknown about "foo" + None, + // container 3 known to only contain "foo" + Some(true), + // container 4 known to not contain "foo" + Some(false), + // container 5 unknown about "foo" + None, + // container 6 known to only contain "foo" + Some(true), + // container 7 known to not contain "foo" + Some(false), + // container 8 unknown about "foo" + None, + ], + ) + .with_contained( + "s1", + [ScalarValue::from("bar")], + [ + // containers 0,1,2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "bar" + None, + None, + None, + ], + ) + .with_contained( + // the way the tests are setup, this data is + // consulted if the "foo" and "bar" are being checked at the same time + "s1", + [ScalarValue::from("foo"), ScalarValue::from("bar")], + [ + // container 0,1,2 unknown about ("foo, "bar") + None, + None, + None, + // container 3,4,5 known to contain only either "foo" and "bar" + Some(true), + Some(true), + Some(true), + // container 6,7,8 known to contain neither "foo" and "bar" + Some(false), + Some(false), + Some(false), + ], + ); + + // s1 = 'foo' + prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers ('false) where we know foo is not present + vec![true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("bar")), + &schema, + &statistics, + // rule out containers where we know bar is not present + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 = 'baz' (unknown value) + prune_with_expr( + col("s1").eq(lit("baz")), + &schema, + &statistics, + // can't rule out anything + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // logically this predicate can't possibly be true (the column can't + // take on both values) but we could rule it out if the stats tell + // us that both values are not present + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // can rule out containers that we know contain neither foo nor bar + vec![true, true, true, true, true, true, false, false, false], + ); + + // s1 = 'foo' OR s1 = 'baz' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can't rule out anything container + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' OR s1 = 'baz' + prune_with_expr( + col("s1") + .eq(lit("foo")) + .or(col("s1").eq(lit("bar"))) + .or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can rule out any containers based on knowledge of s1 and `foo`, + // `bar` and (`foo`, `bar`) + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo + prune_with_expr( + col("s1").not_eq(lit("foo")), + &schema, + &statistics, + // rule out containers we know for sure only contain foo + vec![false, true, true, false, true, true, false, true, true], + ); + + // s1 != bar + prune_with_expr( + col("s1").not_eq(lit("bar")), + &schema, + &statistics, + // rule out when we know for sure s1 has the value bar + vec![false, false, false, true, true, true, true, true, true], + ); + + // s1 != foo AND s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s1 does not have either 'foo' or 'bar' + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 != foo AND s1 != bar AND s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))) + .and(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // can't rule out any container based on knowledge of s1,s2 + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar OR s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))) + .or(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + vec![true, true, true, true, true, true, true, true, true], + ); + } + + #[test] + fn prune_with_contained_two_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("s1", DataType::Utf8, true), + Field::new("s2", DataType::Utf8, true), + ])); + + // Model having information like bloom filters for s1 and s2 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0, s1 known to only contain "foo"", + Some(true), + // container 1, s1 known to not contain "foo" + Some(false), + // container 2, s1 unknown about "foo" + None, + // container 3, s1 known to only contain "foo" + Some(true), + // container 4, s1 known to not contain "foo" + Some(false), + // container 5, s1 unknown about "foo" + None, + // container 6, s1 known to only contain "foo" + Some(true), + // container 7, s1 known to not contain "foo" + Some(false), + // container 8, s1 unknown about "foo" + None, + ], + ) + .with_contained( + "s2", // for column s2 + [ScalarValue::from("bar")], + [ + // containers 0,1,2 s2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 s2 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 s2 unknown about "bar" + None, + None, + None, + ], + ); + + // s1 = 'foo' + prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers where we know s1 is not present + vec![true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'foo' OR s2 = 'bar' + let expr = col("s1").eq(lit("foo")).or(col("s2").eq(lit("bar"))); + prune_with_expr( + expr, + &schema, + &statistics, + // can't rule out any container (would need to prove that s1 != foo AND s2 != bar) + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // can only rule out container where we know either: + // 1. s1 doesn't have the value 'foo` or + // 2. s2 has only the value of 'bar' + vec![false, false, false, true, false, true, true, false, true], + ); + + // s1 != 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // Can rule out any container where we know either + // 1. s1 has only the value 'foo' + // 2. s2 has only the value 'bar' + vec![false, false, false, false, true, true, false, true, true], + ); + + // s1 != 'foo' AND (s2 = 'bar' OR s2 = 'baz') + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").eq(lit("bar")).or(col("s2").eq(lit("baz")))), + &schema, + &statistics, + // Can rule out any container where we know s1 has only the value + // 'foo'. Can't use knowledge of s2 and bar to rule out anything + vec![false, true, true, false, true, true, false, true, true], + ); + + // s1 like '%foo%bar%' + prune_with_expr( + col("s1").like(lit("foo%bar%")), + &schema, + &statistics, + // cant rule out anything with information we know + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 like '%foo%bar%' AND s2 = 'bar' + prune_with_expr( + col("s1") + .like(lit("foo%bar%")) + .and(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s2 does not have the value 'bar' + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 like '%foo%bar%' OR s2 = 'bar' + prune_with_expr( + col("s1").like(lit("foo%bar%")).or(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can't rule out anything (we would have to prove that both the + // like and the equality must be false) + vec![true, true, true, true, true, true, true, true, true], + ); + } + + #[test] + fn prune_with_range_and_contained() { + // Setup mimics range information for i, a bloom filter for s + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), + ])); + + let statistics = TestStatistics::new() + .with( + "i", + ContainerStats::new_i32( + // Container 0, 3, 6: [-5 to 5] + // Container 1, 4, 7: [10 to 20] + // Container 2, 5, 9: unknown + vec![ + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + ], // min + vec![ + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + ], // max + ), + ) + // Add contained information about the s and "foo" + .with_contained( + "s", + [ScalarValue::from("foo")], + [ + // container 0,1,2 known to only contain "foo" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "foo" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "foo" + None, + None, + None, + ], + ); + + // i = 0 and s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").eq(lit("foo"))), + &schema, + &statistics, + // Can rule out container where we know that either: + // 1. 0 is outside the min/max range of i + // 1. s does not contain foo + // (range is false, and contained is false) + vec![true, false, true, false, false, false, true, false, true], + ); + + // i = 0 and s != 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").not_eq(lit("foo"))), + &schema, + &statistics, + // Can rule out containers where either: + // 1. 0 is outside the min/max range of i + // 2. s only contains foo + vec![false, false, false, true, false, true, true, false, true], + ); + + // i = 0 OR s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).or(col("s").eq(lit("foo"))), + &schema, + &statistics, + // in theory could rule out containers if we had min/max values for + // s as well. But in this case we don't so we can't rule out anything + vec![true, true, true, true, true, true, true, true, true], + ); + } + + /// prunes the specified expr with the specified schema and statistics, and + /// ensures it returns expected. + /// + /// `expected` is a vector of bools, where true means the row group should + /// be kept, and false means it should be pruned. + /// + // TODO refactor other tests to use this to reduce boiler plate + fn prune_with_expr( + expr: Expr, + schema: &SchemaRef, + statistics: &TestStatistics, + expected: Vec, + ) { + println!("Pruning with expr: {}", expr); + let expr = logical2physical(&expr, schema); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(statistics).unwrap(); + assert_eq!(result, expected); + } + fn test_build_predicate_expression( expr: &Expr, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); build_predicate_expression(&expr, schema, required_columns) From bf43bb2eed304369c078637bc84d1b842c24b399 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 23 Dec 2023 09:25:07 -0700 Subject: [PATCH 21/63] Add partial serde support for ParquetWriterOptions (#8627) * Add serde support for ParquetWriterOptions * save progress * test passes * Improve test * Refactor and add link to follow on issue * remove duplicate code * clippy * Regen * remove comments from proto file * change proto types from i32 to u32 pre feedback on PR * change to u64 --- datafusion/proto/proto/datafusion.proto | 15 + datafusion/proto/src/generated/pbjson.rs | 321 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 28 +- datafusion/proto/src/logical_plan/mod.rs | 146 ++++++-- .../proto/src/physical_plan/from_proto.rs | 7 + .../tests/cases/roundtrip_logical_plan.rs | 41 ++- 6 files changed, 524 insertions(+), 34 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 05f0b6434368..d02fc8e91b41 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1206,6 +1206,7 @@ message PartitionColumn { message FileTypeWriterOptions { oneof FileType { JsonWriterOptions json_options = 1; + ParquetWriterOptions parquet_options = 2; } } @@ -1213,6 +1214,20 @@ message JsonWriterOptions { CompressionTypeVariant compression = 1; } +message ParquetWriterOptions { + WriterProperties writer_properties = 1; +} + +message WriterProperties { + uint64 data_page_size_limit = 1; + uint64 dictionary_page_size_limit = 2; + uint64 data_page_row_count_limit = 3; + uint64 write_batch_size = 4; + uint64 max_row_group_size = 5; + string writer_version = 6; + string created_by = 7; +} + message FileSinkConfig { reserved 6; // writer_mode diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0fdeab0a40f6..f860b1f1e6a0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7890,6 +7890,9 @@ impl serde::Serialize for FileTypeWriterOptions { file_type_writer_options::FileType::JsonOptions(v) => { struct_ser.serialize_field("jsonOptions", v)?; } + file_type_writer_options::FileType::ParquetOptions(v) => { + struct_ser.serialize_field("parquetOptions", v)?; + } } } struct_ser.end() @@ -7904,11 +7907,14 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { const FIELDS: &[&str] = &[ "json_options", "jsonOptions", + "parquet_options", + "parquetOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { JsonOptions, + ParquetOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7931,6 +7937,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { { match value { "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), + "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7958,6 +7965,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { return Err(serde::de::Error::duplicate_field("jsonOptions")); } file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::JsonOptions) +; + } + GeneratedField::ParquetOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ParquetOptions) ; } } @@ -15171,6 +15185,98 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ParquetWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.writer_properties.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetWriterOptions", len)?; + if let Some(v) = self.writer_properties.as_ref() { + struct_ser.serialize_field("writerProperties", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "writer_properties", + "writerProperties", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + WriterProperties, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "writerProperties" | "writer_properties" => Ok(GeneratedField::WriterProperties), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut writer_properties__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::WriterProperties => { + if writer_properties__.is_some() { + return Err(serde::de::Error::duplicate_field("writerProperties")); + } + writer_properties__ = map_.next_value()?; + } + } + } + Ok(ParquetWriterOptions { + writer_properties: writer_properties__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PartialTableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -27144,3 +27250,218 @@ impl<'de> serde::Deserialize<'de> for WindowNode { deserializer.deserialize_struct("datafusion.WindowNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for WriterProperties { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.data_page_size_limit != 0 { + len += 1; + } + if self.dictionary_page_size_limit != 0 { + len += 1; + } + if self.data_page_row_count_limit != 0 { + len += 1; + } + if self.write_batch_size != 0 { + len += 1; + } + if self.max_row_group_size != 0 { + len += 1; + } + if !self.writer_version.is_empty() { + len += 1; + } + if !self.created_by.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.WriterProperties", len)?; + if self.data_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageSizeLimit", ToString::to_string(&self.data_page_size_limit).as_str())?; + } + if self.dictionary_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; + } + if self.data_page_row_count_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; + } + if self.write_batch_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; + } + if self.max_row_group_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("maxRowGroupSize", ToString::to_string(&self.max_row_group_size).as_str())?; + } + if !self.writer_version.is_empty() { + struct_ser.serialize_field("writerVersion", &self.writer_version)?; + } + if !self.created_by.is_empty() { + struct_ser.serialize_field("createdBy", &self.created_by)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for WriterProperties { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "data_page_size_limit", + "dataPageSizeLimit", + "dictionary_page_size_limit", + "dictionaryPageSizeLimit", + "data_page_row_count_limit", + "dataPageRowCountLimit", + "write_batch_size", + "writeBatchSize", + "max_row_group_size", + "maxRowGroupSize", + "writer_version", + "writerVersion", + "created_by", + "createdBy", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DataPageSizeLimit, + DictionaryPageSizeLimit, + DataPageRowCountLimit, + WriteBatchSize, + MaxRowGroupSize, + WriterVersion, + CreatedBy, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dataPageSizeLimit" | "data_page_size_limit" => Ok(GeneratedField::DataPageSizeLimit), + "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), + "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), + "writeBatchSize" | "write_batch_size" => Ok(GeneratedField::WriteBatchSize), + "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), + "writerVersion" | "writer_version" => Ok(GeneratedField::WriterVersion), + "createdBy" | "created_by" => Ok(GeneratedField::CreatedBy), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = WriterProperties; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.WriterProperties") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut data_page_size_limit__ = None; + let mut dictionary_page_size_limit__ = None; + let mut data_page_row_count_limit__ = None; + let mut write_batch_size__ = None; + let mut max_row_group_size__ = None; + let mut writer_version__ = None; + let mut created_by__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::DataPageSizeLimit => { + if data_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageSizeLimit")); + } + data_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictionaryPageSizeLimit => { + if dictionary_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); + } + dictionary_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DataPageRowCountLimit => { + if data_page_row_count_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageRowCountLimit")); + } + data_page_row_count_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriteBatchSize => { + if write_batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("writeBatchSize")); + } + write_batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::MaxRowGroupSize => { + if max_row_group_size__.is_some() { + return Err(serde::de::Error::duplicate_field("maxRowGroupSize")); + } + max_row_group_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriterVersion => { + if writer_version__.is_some() { + return Err(serde::de::Error::duplicate_field("writerVersion")); + } + writer_version__ = Some(map_.next_value()?); + } + GeneratedField::CreatedBy => { + if created_by__.is_some() { + return Err(serde::de::Error::duplicate_field("createdBy")); + } + created_by__ = Some(map_.next_value()?); + } + } + } + Ok(WriterProperties { + data_page_size_limit: data_page_size_limit__.unwrap_or_default(), + dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), + data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), + write_batch_size: write_batch_size__.unwrap_or_default(), + max_row_group_size: max_row_group_size__.unwrap_or_default(), + writer_version: writer_version__.unwrap_or_default(), + created_by: created_by__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.WriterProperties", FIELDS, GeneratedVisitor) + } +} diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e44355859d65..459d5a965cd3 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1642,7 +1642,7 @@ pub struct PartitionColumn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileTypeWriterOptions { - #[prost(oneof = "file_type_writer_options::FileType", tags = "1")] + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2")] pub file_type: ::core::option::Option, } /// Nested message and enum types in `FileTypeWriterOptions`. @@ -1652,6 +1652,8 @@ pub mod file_type_writer_options { pub enum FileType { #[prost(message, tag = "1")] JsonOptions(super::JsonWriterOptions), + #[prost(message, tag = "2")] + ParquetOptions(super::ParquetWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1662,6 +1664,30 @@ pub struct JsonWriterOptions { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetWriterOptions { + #[prost(message, optional, tag = "1")] + pub writer_properties: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WriterProperties { + #[prost(uint64, tag = "1")] + pub data_page_size_limit: u64, + #[prost(uint64, tag = "2")] + pub dictionary_page_size_limit: u64, + #[prost(uint64, tag = "3")] + pub data_page_row_count_limit: u64, + #[prost(uint64, tag = "4")] + pub write_batch_size: u64, + #[prost(uint64, tag = "5")] + pub max_row_group_size: u64, + #[prost(string, tag = "6")] + pub writer_version: ::prost::alloc::string::String, + #[prost(string, tag = "7")] + pub created_by: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct FileSinkConfig { #[prost(string, tag = "1")] pub object_store_url: ::prost::alloc::string::String, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e03b3ffa7b84..d137a41fa19b 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -23,7 +23,8 @@ use std::sync::Arc; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - copy_to_node, CustomTableScanNode, LogicalExprNodeCollection, SqlOption, + copy_to_node, file_type_writer_options, CustomTableScanNode, + LogicalExprNodeCollection, SqlOption, }; use crate::{ convert_required, @@ -49,7 +50,7 @@ use datafusion::{ use datafusion_common::{ context, file_options::StatementOptions, internal_err, not_impl_err, parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, FileType, - OwnedTableReference, Result, + FileTypeWriterOptions, OwnedTableReference, Result, }; use datafusion_expr::{ dml, @@ -62,6 +63,8 @@ use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; @@ -833,19 +836,48 @@ impl AsLogicalPlan for LogicalPlanNode { let copy_options = match ©.copy_options { Some(copy_to_node::CopyOptions::SqlOptions(opt)) => { - let options = opt.option.iter().map(|o| (o.key.clone(), o.value.clone())).collect(); - CopyOptions::SQLOptions(StatementOptions::from( - &options, - )) + let options = opt + .option + .iter() + .map(|o| (o.key.clone(), o.value.clone())) + .collect(); + CopyOptions::SQLOptions(StatementOptions::from(&options)) } - Some(copy_to_node::CopyOptions::WriterOptions(_)) => { - return Err(proto_error( - "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", - )) + Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { + match &opt.file_type { + Some(ft) => match ft { + file_type_writer_options::FileType::ParquetOptions( + writer_options, + ) => { + let writer_properties = + match &writer_options.writer_properties { + Some(serialized_writer_options) => { + writer_properties_from_proto( + serialized_writer_options, + )? + } + _ => WriterProperties::default(), + }; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(writer_properties), + ), + )) + } + _ => { + return Err(proto_error( + "WriterOptions unsupported file_type", + )) + } + }, + None => { + return Err(proto_error( + "WriterOptions missing file_type", + )) + } + } } - other => return Err(proto_error(format!( - "LogicalPlan serde is not yet implemented for CopyTo with CopyOptions {other:?}", - ))) + None => return Err(proto_error("CopyTo missing CopyOptions")), }; Ok(datafusion_expr::LogicalPlan::Copy( datafusion_expr::dml::CopyTo { @@ -1580,22 +1612,48 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec, )?; - let copy_options_proto: Option = match copy_options { - CopyOptions::SQLOptions(opt) => { - let options: Vec = opt.clone().into_inner().iter().map(|(k, v)| SqlOption { - key: k.to_string(), - value: v.to_string(), - }).collect(); - Some(copy_to_node::CopyOptions::SqlOptions(protobuf::SqlOptions { - option: options - })) - } - CopyOptions::WriterOptions(_) => { - return Err(proto_error( - "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", - )) - } - }; + let copy_options_proto: Option = + match copy_options { + CopyOptions::SQLOptions(opt) => { + let options: Vec = opt + .clone() + .into_inner() + .iter() + .map(|(k, v)| SqlOption { + key: k.to_string(), + value: v.to_string(), + }) + .collect(); + Some(copy_to_node::CopyOptions::SqlOptions( + protobuf::SqlOptions { option: options }, + )) + } + CopyOptions::WriterOptions(opt) => { + match opt.as_ref() { + FileTypeWriterOptions::Parquet(parquet_opts) => { + let parquet_writer_options = + protobuf::ParquetWriterOptions { + writer_properties: Some( + writer_properties_to_proto( + &parquet_opts.writer_options, + ), + ), + }; + let parquet_options = file_type_writer_options::FileType::ParquetOptions(parquet_writer_options); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(parquet_options), + }, + )) + } + _ => { + return Err(proto_error( + "Unsupported FileTypeWriterOptions in CopyTo", + )) + } + } + } + }; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( @@ -1615,3 +1673,33 @@ impl AsLogicalPlan for LogicalPlanNode { } } } + +pub(crate) fn writer_properties_to_proto( + props: &WriterProperties, +) -> protobuf::WriterProperties { + protobuf::WriterProperties { + data_page_size_limit: props.data_page_size_limit() as u64, + dictionary_page_size_limit: props.dictionary_page_size_limit() as u64, + data_page_row_count_limit: props.data_page_row_count_limit() as u64, + write_batch_size: props.write_batch_size() as u64, + max_row_group_size: props.max_row_group_size() as u64, + writer_version: format!("{:?}", props.writer_version()), + created_by: props.created_by().to_string(), + } +} + +pub(crate) fn writer_properties_from_proto( + props: &protobuf::WriterProperties, +) -> Result { + let writer_version = WriterVersion::from_str(&props.writer_version) + .map_err(|e| proto_error(e.to_string()))?; + Ok(WriterProperties::builder() + .set_created_by(props.created_by.clone()) + .set_writer_version(writer_version) + .set_dictionary_page_size_limit(props.dictionary_page_size_limit as usize) + .set_data_page_row_count_limit(props.data_page_row_count_limit as usize) + .set_data_page_size_limit(props.data_page_size_limit as usize) + .set_write_batch_size(props.write_batch_size as usize) + .set_max_row_group_size(props.max_row_group_size as usize) + .build()) +} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 65f9f139a87b..824eb60a5715 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -40,6 +40,7 @@ use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ @@ -52,6 +53,7 @@ use crate::logical_plan; use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; +use crate::logical_plan::writer_properties_from_proto; use chrono::{TimeZone, Utc}; use object_store::path::Path; use object_store::ObjectMeta; @@ -769,6 +771,11 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( Self::JSON(JsonWriterOptions::new(opts.compression().into())), ), + protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { + let props = opt.writer_properties.clone().unwrap_or_default(); + let writer_properties = writer_properties_from_proto(&props)?; + Ok(Self::Parquet(ParquetWriterOptions::new(writer_properties))) + } } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9798b06f4724..3eeae01a643e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,7 +31,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::parquet::file::properties::WriterProperties; +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -330,7 +330,6 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { } #[tokio::test] -#[ignore] // see https://github.com/apache/arrow-datafusion/issues/8619 async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { let ctx = SessionContext::new(); @@ -339,11 +338,17 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { let writer_properties = WriterProperties::builder() .set_bloom_filter_enabled(true) .set_created_by("DataFusion Test".to_string()) + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_write_batch_size(111) + .set_data_page_size_limit(222) + .set_data_page_row_count_limit(333) + .set_dictionary_page_size_limit(444) + .set_max_row_group_size(555) .build(); let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), - output_url: "test.csv".to_string(), - file_format: FileType::CSV, + output_url: "test.parquet".to_string(), + file_format: FileType::PARQUET, single_file_output: true, copy_options: CopyOptions::WriterOptions(Box::new( FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(writer_properties)), @@ -354,6 +359,34 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.parquet", copy_to.output_url); + assert_eq!(FileType::PARQUET, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::Parquet(p) => { + let props = &p.writer_options; + assert_eq!("DataFusion Test", props.created_by()); + assert_eq!( + "PARQUET_2_0", + format!("{:?}", props.writer_version()) + ); + assert_eq!(111, props.write_batch_size()); + assert_eq!(222, props.data_page_size_limit()); + assert_eq!(333, props.data_page_row_count_limit()); + assert_eq!(444, props.dictionary_page_size_limit()); + assert_eq!(555, props.max_row_group_size()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + Ok(()) } From 7443f30fc020cca05af74e22d2b5f42ebfe9604e Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sat, 23 Dec 2023 19:39:03 +0100 Subject: [PATCH 22/63] add arguments length check (#8622) --- .../physical-expr/src/array_expressions.rs | 110 +++++++++++++++++- 1 file changed, 107 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 4dfc157e53c7..3ee99d7e8e55 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -472,6 +472,10 @@ where /// For example: /// > array_element(\[1, 2, 3], 2) -> 2 pub fn array_element(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_element needs two arguments"); + } + match &args[0].data_type() { DataType::List(_) => { let array = as_list_array(&args[0])?; @@ -585,6 +589,10 @@ pub fn array_except(args: &[ArrayRef]) -> Result { /// /// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_slice needs three arguments"); + } + let array_data_type = args[0].data_type(); match array_data_type { DataType::List(_) => { @@ -736,6 +744,10 @@ where /// array_pop_back SQL function pub fn array_pop_back(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_pop_back needs one argument"); + } + let list_array = as_list_array(&args[0])?; let from_array = Int64Array::from(vec![1; list_array.len()]); let to_array = Int64Array::from( @@ -885,6 +897,10 @@ pub fn array_pop_front(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_append expects two arguments"); + } + let list_array = as_list_array(&args[0])?; let element_array = &args[1]; @@ -911,6 +927,10 @@ pub fn array_append(args: &[ArrayRef]) -> Result { /// Array_sort SQL function pub fn array_sort(args: &[ArrayRef]) -> Result { + if args.is_empty() || args.len() > 3 { + return exec_err!("array_sort expects one to three arguments"); + } + let sort_option = match args.len() { 1 => None, 2 => { @@ -990,6 +1010,10 @@ fn order_nulls_first(modifier: &str) -> Result { /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_prepend expects two arguments"); + } + let list_array = as_list_array(&args[1])?; let element_array = &args[0]; @@ -1110,6 +1134,10 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_concat/Array_cat SQL function pub fn array_concat(args: &[ArrayRef]) -> Result { + if args.is_empty() { + return exec_err!("array_concat expects at least one arguments"); + } + let mut new_args = vec![]; for arg in args { let ndim = list_ndims(arg.data_type()); @@ -1126,6 +1154,10 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_empty expects one argument"); + } + if as_null_array(&args[0]).is_ok() { // Make sure to return Boolean type. return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); @@ -1150,6 +1182,10 @@ fn array_empty_dispatch(array: &ArrayRef) -> Result Result { + if args.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + let element = &args[0]; let count_array = as_int64_array(&args[1])?; @@ -1285,6 +1321,10 @@ fn general_list_repeat( /// Array_position SQL function pub fn array_position(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_position expects two or three arguments"); + } + let list_array = as_list_array(&args[0])?; let element_array = &args[1]; @@ -1349,6 +1389,10 @@ fn general_position( /// Array_positions SQL function pub fn array_positions(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_positions expects two arguments"); + } + let element = &args[1]; match &args[0].data_type() { @@ -1508,16 +1552,28 @@ fn array_remove_internal( } pub fn array_remove_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove_all expects two arguments"); + } + let arr_n = vec![i64::MAX; args[0].len()]; array_remove_internal(&args[0], &args[1], arr_n) } pub fn array_remove(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove expects two arguments"); + } + let arr_n = vec![1; args[0].len()]; array_remove_internal(&args[0], &args[1], arr_n) } pub fn array_remove_n(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_remove_n expects three arguments"); + } + let arr_n = as_int64_array(&args[2])?.values().to_vec(); array_remove_internal(&args[0], &args[1], arr_n) } @@ -1634,6 +1690,10 @@ fn general_replace( } pub fn array_replace(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace expects three arguments"); + } + // replace at most one occurence for each element let arr_n = vec![1; args[0].len()]; let array = &args[0]; @@ -1651,6 +1711,10 @@ pub fn array_replace(args: &[ArrayRef]) -> Result { } pub fn array_replace_n(args: &[ArrayRef]) -> Result { + if args.len() != 4 { + return exec_err!("array_replace_n expects four arguments"); + } + // replace the specified number of occurences let arr_n = as_int64_array(&args[3])?.values().to_vec(); let array = &args[0]; @@ -1670,6 +1734,10 @@ pub fn array_replace_n(args: &[ArrayRef]) -> Result { } pub fn array_replace_all(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace_all expects three arguments"); + } + // replace all occurrences (up to "i64::MAX") let arr_n = vec![i64::MAX; args[0].len()]; let array = &args[0]; @@ -1760,7 +1828,7 @@ fn union_generic_lists( /// Array_union SQL function pub fn array_union(args: &[ArrayRef]) -> Result { if args.len() != 2 { - return exec_err!("array_union needs two arguments"); + return exec_err!("array_union needs 2 arguments"); } let array1 = &args[0]; let array2 = &args[1]; @@ -1802,6 +1870,10 @@ pub fn array_union(args: &[ArrayRef]) -> Result { /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + let arr = &args[0]; let delimiters = as_string_array(&args[1])?; @@ -1911,6 +1983,10 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { /// Cardinality SQL function pub fn cardinality(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } + let list_array = as_list_array(&args[0])?.clone(); let result = list_array @@ -1967,6 +2043,10 @@ fn flatten_internal( /// Flatten SQL function pub fn flatten(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("flatten expects one argument"); + } + let flattened_array = flatten_internal(&args[0], None)?; Ok(Arc::new(flattened_array) as ArrayRef) } @@ -1991,6 +2071,10 @@ fn array_length_dispatch(array: &[ArrayRef]) -> Result Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!("array_length expects one or two arguments"); + } + match &args[0].data_type() { DataType::List(_) => array_length_dispatch::(args), DataType::LargeList(_) => array_length_dispatch::(args), @@ -2037,6 +2121,10 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); + } + if let Some(list_array) = args[0].as_list_opt::() { let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); @@ -2127,6 +2215,10 @@ fn general_array_has_dispatch( /// Array_has SQL function pub fn array_has(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2142,6 +2234,10 @@ pub fn array_has(args: &[ArrayRef]) -> Result { /// Array_has_any SQL function pub fn array_has_any(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_any needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2157,6 +2253,10 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_all needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2261,7 +2361,9 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { - assert_eq!(args.len(), 2); + if args.len() != 2 { + return exec_err!("array_intersect needs two arguments"); + } let first_array = &args[0]; let second_array = &args[1]; @@ -2364,7 +2466,9 @@ pub fn general_array_distinct( /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] pub fn array_distinct(args: &[ArrayRef]) -> Result { - assert_eq!(args.len(), 1); + if args.len() != 1 { + return exec_err!("array_distinct needs one argument"); + } // handle null if args[0].data_type() == &DataType::Null { From 69e5382aaac8dff6b163de68abc8a46f8780791a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 23 Dec 2023 13:41:59 -0500 Subject: [PATCH 23/63] Improve DataFrame functional tests (#8630) --- datafusion/core/src/dataframe/mod.rs | 220 ++++++++++----------------- 1 file changed, 82 insertions(+), 138 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 4b8a9c5b7d79..2ae4a7c21a9c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1356,15 +1356,30 @@ mod tests { use arrow::array::{self, Int32Array}; use arrow::datatypes::DataType; - use datafusion_common::{Constraint, Constraints, ScalarValue}; + use datafusion_common::{Constraint, Constraints}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, - BinaryExpr, BuiltInWindowFunction, Operator, ScalarFunctionImplementation, - Volatility, WindowFrame, WindowFunction, + BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, + WindowFunction, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::get_plan_string; + // Get string representation of the plan + async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) { + let physical_plan = df + .clone() + .create_physical_plan() + .await + .expect("Error creating physical plan"); + + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + } + pub fn table_with_constraints() -> Arc { let dual_schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -1587,47 +1602,36 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); - let col_name = Expr::Column(datafusion_common::Column { - relation: None, - name: "name".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; let df = df.aggregate(group_expr, aggr_expr)?; - // expr list contains id, name - let expr_list = vec![col_id, col_name]; - let df = df.select(expr_list)?; - let physical_plan = df.clone().create_physical_plan().await?; - let expected = vec![ - "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - // Since id and name are functionally dependant, we can use name among expression - // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + // Since id and name are functionally dependant, we can use name among + // expression even if it is not part of the group by expression and can + // select "name" column even though it wasn't explicitly grouped + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + let df_results = df.collect().await?; #[rustfmt::skip] - assert_batches_sorted_eq!( - ["+----+------+", + assert_batches_sorted_eq!([ + "+----+------+", "| id | name |", "+----+------+", "| 1 | a |", - "+----+------+",], + "+----+------+" + ], &df_results ); @@ -1640,57 +1644,31 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); - let col_name = Expr::Column(datafusion_common::Column { - relation: None, - name: "name".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; let df = df.aggregate(group_expr, aggr_expr)?; - let condition1 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_id.clone()), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), - )); - let condition2 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_name), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))), - )); - // Predicate refers to id, and name fields - let predicate = Expr::BinaryExpr(BinaryExpr::new( - Box::new(condition1), - Operator::And, - Box::new(condition2), - )); + // Predicate refers to id, and name fields: + // id = 1 AND name = 'a' + let predicate = col("id").eq(lit(1i32)).and(col("name").eq(lit("a"))); let df = df.filter(predicate)?; - let physical_plan = df.clone().create_physical_plan().await?; - - let expected = vec![ + assert_physical_plan( + &df, + vec![ "CoalesceBatchesExec: target_batch_size=8192", " FilterExec: id@0 = 1 AND name@1 = a", " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + ], + ) + .await; // Since id and name are functionally dependant, we can use name among expression // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + let df_results = df.collect().await?; #[rustfmt::skip] assert_batches_sorted_eq!( @@ -1711,53 +1689,35 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); - let col_name = Expr::Column(datafusion_common::Column { - relation: None, - name: "name".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; // group by id, let df = df.aggregate(group_expr, aggr_expr)?; - let condition1 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_id.clone()), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), - )); // Predicate refers to id field - let predicate = condition1; - // id=0 + // id = 1 + let predicate = col("id").eq(lit(1i32)); let df = df.filter(predicate)?; // Select expression refers to id, and name columns. // id, name - let df = df.select(vec![col_id.clone(), col_name.clone()])?; - let physical_plan = df.clone().create_physical_plan().await?; - - let expected = vec![ + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ "CoalesceBatchesExec: target_batch_size=8192", " FilterExec: id@0 = 1", " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + ], + ) + .await; // Since id and name are functionally dependant, we can use name among expression // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + let df_results = df.collect().await?; #[rustfmt::skip] assert_batches_sorted_eq!( @@ -1778,51 +1738,35 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; - // group by id, let df = df.aggregate(group_expr, aggr_expr)?; - let condition1 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_id.clone()), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), - )); // Predicate refers to id field - let predicate = condition1; - // id=1 + // id = 1 + let predicate = col("id").eq(lit(1i32)); let df = df.filter(predicate)?; // Select expression refers to id column. // id - let df = df.select(vec![col_id.clone()])?; - let physical_plan = df.clone().create_physical_plan().await?; + let df = df.select(vec![col("id")])?; // In this case aggregate shouldn't be expanded, since these // columns are not used. - let expected = vec![ - "CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: id@0 = 1", - " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; - // Since id and name are functionally dependant, we can use name among expression - // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + let df_results = df.collect().await?; #[rustfmt::skip] assert_batches_sorted_eq!( From 72af0ffdf00247e5383adcdbe3dada7ca85d9172 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 24 Dec 2023 00:17:05 -0800 Subject: [PATCH 24/63] Improve regexp_match performance by avoiding cloning Regex (#8631) * Improve regexp_match performance by avoiding cloning Regex * Update datafusion/physical-expr/src/regex_expressions.rs Co-authored-by: Andrew Lamb * Removing clone of Regex in regexp_replace --------- Co-authored-by: Andrew Lamb --- .../physical-expr/src/regex_expressions.rs | 96 +++++++++++++++++-- 1 file changed, 87 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 7bafed072b61..b778fd86c24b 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -25,7 +25,8 @@ use arrow::array::{ new_null_array, Array, ArrayDataBuilder, ArrayRef, BufferBuilder, GenericStringArray, OffsetSizeTrait, }; -use arrow::compute; +use arrow_array::builder::{GenericStringBuilder, ListBuilder}; +use arrow_schema::ArrowError; use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, @@ -58,7 +59,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - compute::regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) + _regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) } 3 => { let values = as_generic_string_array::(&args[0])?; @@ -69,7 +70,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { Some(f) if f.iter().any(|s| s == Some("g")) => { plan_err!("regexp_match() does not support the \"global\" option") }, - _ => compute::regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), + _ => _regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), } } other => internal_err!( @@ -78,6 +79,83 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { } } +/// TODO: Remove this once it is included in arrow-rs new release. +/// +fn _regexp_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> std::result::Result { + let mut patterns: std::collections::HashMap = + std::collections::HashMap::new(); + let builder: GenericStringBuilder = + GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(value) => format!("(?{value}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.insert(pattern.clone(), re); + patterns.get(&pattern).unwrap() + } + }; + match re.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + } + } + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + Ok(Arc::new(list_builder.finish())) +} + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { @@ -116,12 +194,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern) { Ok(re) => { - patterns.insert(pattern.to_string(), re.clone()); - Ok(re) + patterns.insert(pattern.to_string(), re); + Ok(patterns.get(pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } @@ -162,12 +240,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(&pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern.as_str()) { Ok(re) => { - patterns.insert(pattern, re.clone()); - Ok(re) + patterns.insert(pattern.clone(), re); + Ok(patterns.get(&pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } From 6b433a839948c406a41128186e81572ec1fff689 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 24 Dec 2023 07:37:38 -0500 Subject: [PATCH 25/63] Minor: improve `listing_table_ignore_subdirectory` config documentation (#8634) * Minor: improve `listing_table_ignore_subdirectory` config documentation * update slt --- datafusion/common/src/config.rs | 8 ++++---- datafusion/sqllogictest/test_files/information_schema.slt | 2 +- docs/source/user-guide/configs.md | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index dedce74ff40d..5b1325ec06ee 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -273,11 +273,11 @@ config_namespace! { /// memory consumption pub max_buffered_batches_per_output_file: usize, default = 2 - /// When scanning file paths, whether to ignore subdirectory files, - /// ignored by default (true), when reading a partitioned table, - /// `listing_table_ignore_subdirectory` is always equal to false, even if set to true + /// Should sub directories be ignored when scanning directories for data + /// files. Defaults to true (ignores subdirectories), consistent with + /// Hive. Note that this setting does not affect reading partitioned + /// tables (e.g. `/table/year=2021/month=01/data.parquet`). pub listing_table_ignore_subdirectory: bool, default = true - } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 36876beb1447..1b5ad86546a3 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -225,7 +225,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold f datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files -datafusion.execution.listing_table_ignore_subdirectory true When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true +datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 1f7fa7760b94..0a5c221c5034 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -82,7 +82,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | | datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | | datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | -| datafusion.execution.listing_table_ignore_subdirectory | true | When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From d5704f75fc28f88632518ef9a808c9cda38dc162 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Sun, 24 Dec 2023 07:46:26 -0500 Subject: [PATCH 26/63] Support Writing Arrow files (#8608) * write arrow files * update datafusion-cli lock * fix toml formatting * Update insert_to_external.slt Co-authored-by: Andrew Lamb * add ticket tracking arrow options * default to lz4 compression * update datafusion-cli lock * cargo update --------- Co-authored-by: Andrew Lamb --- Cargo.toml | 28 +-- datafusion-cli/Cargo.lock | 56 ++--- datafusion/core/Cargo.toml | 1 + .../core/src/datasource/file_format/arrow.rs | 207 +++++++++++++++++- .../src/datasource/file_format/parquet.rs | 34 +-- .../src/datasource/file_format/write/mod.rs | 33 ++- datafusion/sqllogictest/test_files/copy.slt | 56 +++++ .../test_files/insert_to_external.slt | 39 ++++ 8 files changed, 368 insertions(+), 86 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 023dc6c6fc4f..a698fbf471f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,24 +17,7 @@ [workspace] exclude = ["datafusion-cli"] -members = [ - "datafusion/common", - "datafusion/core", - "datafusion/expr", - "datafusion/execution", - "datafusion/optimizer", - "datafusion/physical-expr", - "datafusion/physical-plan", - "datafusion/proto", - "datafusion/proto/gen", - "datafusion/sql", - "datafusion/sqllogictest", - "datafusion/substrait", - "datafusion/wasmtest", - "datafusion-examples", - "docs", - "test-utils", - "benchmarks", +members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks", ] resolver = "2" @@ -53,24 +36,26 @@ arrow = { version = "49.0.0", features = ["prettyprint"] } arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } arrow-buffer = { version = "49.0.0", default-features = false } arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "49.0.0", default-features = false, features=["lz4"] } arrow-ord = { version = "49.0.0", default-features = false } arrow-schema = { version = "49.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "0.4.1" bytes = "1.4" +chrono = { version = "0.4.31", default-features = false } ctor = "0.2.0" +dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "34.0.0" } datafusion-common = { path = "datafusion/common", version = "34.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } datafusion-expr = { path = "datafusion/expr", version = "34.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "34.0.0" } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "34.0.0" } datafusion-physical-plan = { path = "datafusion/physical-plan", version = "34.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } datafusion-proto = { path = "datafusion/proto", version = "34.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "34.0.0" } datafusion-substrait = { path = "datafusion/substrait", version = "34.0.0" } -dashmap = "5.4.0" doc-comment = "0.3" env_logger = "0.10" futures = "0.3" @@ -88,7 +73,6 @@ serde_json = "1" sqlparser = { version = "0.40.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" -chrono = { version = "0.4.31", default-features = false } url = "2.2" [profile.release] diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ac05ddf10a73..9f75013c86dc 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -255,6 +255,7 @@ dependencies = [ "arrow-data", "arrow-schema", "flatbuffers", + "lz4_flex", ] [[package]] @@ -378,13 +379,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.74" +version = "0.1.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1074,7 +1075,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1104,6 +1105,7 @@ dependencies = [ "apache-avro", "arrow", "arrow-array", + "arrow-ipc", "arrow-schema", "async-compression", "async-trait", @@ -1576,7 +1578,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -2496,7 +2498,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -2513,9 +2515,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" [[package]] name = "powerfmt" @@ -2586,9 +2588,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" dependencies = [ "unicode-ident", ] @@ -3020,7 +3022,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3186,7 +3188,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3208,9 +3210,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.41" +version = "2.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" +checksum = "5b7d0a2c048d661a1a59fcd7355baa232f7ed34e0ee4df2eef3c1c1c0d3852d8" dependencies = [ "proc-macro2", "quote", @@ -3289,7 +3291,7 @@ checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3357,9 +3359,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ "backtrace", "bytes", @@ -3381,7 +3383,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3478,7 +3480,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3523,7 +3525,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3677,7 +3679,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", "wasm-bindgen-shared", ] @@ -3711,7 +3713,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3960,22 +3962,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.31" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.31" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0ee83e756745..9de6a7f7d6a0 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -55,6 +55,7 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-ipc = { workspace = true } arrow-schema = { workspace = true } async-compression = { version = "0.4.0", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true } async-trait = { workspace = true } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 07c96bdae1b4..7d393d9129dd 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -21,10 +21,13 @@ use std::any::Any; use std::borrow::Cow; +use std::fmt::{self, Debug}; use std::sync::Arc; use crate::datasource::file_format::FileFormat; -use crate::datasource::physical_plan::{ArrowExec, FileScanConfig}; +use crate::datasource::physical_plan::{ + ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, +}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; @@ -32,16 +35,28 @@ use crate::physical_plan::ExecutionPlan; use arrow::ipc::convert::fb_to_schema; use arrow::ipc::reader::FileReader; use arrow::ipc::root_as_message; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_ipc::CompressionType; use arrow_schema::{ArrowError, Schema, SchemaRef}; use bytes::Bytes; -use datafusion_common::{FileType, Statistics}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use async_trait::async_trait; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; use futures::stream::BoxStream; use futures::StreamExt; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; + +use super::file_compression_type::FileCompressionType; +use super::write::demux::start_demuxer_task; +use super::write::{create_writer, SharedBuffer}; /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] @@ -97,11 +112,197 @@ impl FileFormat for ArrowFormat { Ok(Arc::new(exec)) } + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + if conf.overwrite { + return not_impl_err!("Overwrites are not implemented yet for Arrow format"); + } + + let sink_schema = conf.output_schema().clone(); + let sink = Arc::new(ArrowFileSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) + } + fn file_type(&self) -> FileType { FileType::ARROW } } +/// Implements [`DataSink`] for writing to arrow_ipc files +struct ArrowFileSink { + config: FileSinkConfig, +} + +impl ArrowFileSink { + fn new(config: FileSinkConfig) -> Self { + Self { config } + } + + /// Converts table schema to writer schema, which may differ in the case + /// of hive style partitioning where some columns are removed from the + /// underlying files. + fn get_writer_schema(&self) -> Arc { + if !self.config.table_partition_cols.is_empty() { + let schema = self.config.output_schema(); + let partition_names: Vec<_> = self + .config + .table_partition_cols + .iter() + .map(|(s, _)| s) + .collect(); + Arc::new(Schema::new( + schema + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + )) + } else { + self.config.output_schema().clone() + } + } +} + +impl Debug for ArrowFileSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ArrowFileSink").finish() + } +} + +impl DisplayAs for ArrowFileSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ArrowFileSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } + } +} + +#[async_trait] +impl DataSink for ArrowFileSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + // No props are supported yet, but can be by updating FileTypeWriterOptions + // to populate this struct and use those options to initialize the arrow_ipc::writer::FileWriter + // https://github.com/apache/arrow-datafusion/issues/8635 + let _arrow_props = self.config.file_type_writer_options.try_into_arrow()?; + + let object_store = context + .runtime_env() + .object_store(&self.config.object_store_url)?; + + let part_col = if !self.config.table_partition_cols.is_empty() { + Some(self.config.table_partition_cols.clone()) + } else { + None + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_col, + self.config.table_paths[0].clone(), + "arrow".into(), + self.config.single_file_output, + ); + + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + + let ipc_options = + IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + while let Some((path, mut rx)) = file_stream_rx.recv().await { + let shared_buffer = SharedBuffer::new(1048576); + let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( + shared_buffer.clone(), + &self.get_writer_schema(), + ipc_options.clone(), + )?; + let mut object_store_writer = create_writer( + FileCompressionType::UNCOMPRESSED, + &path, + object_store.clone(), + ) + .await?; + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + arrow_writer.write(&batch)?; + let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); + if buff_to_flush.len() > 1024000 { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + arrow_writer.finish()?; + let final_buff = shared_buffer.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + Ok(row_count) + }); + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + match demux_task.await { + Ok(r) => r?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + Ok(row_count as u64) + } +} + const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 9db320fb9da4..0c813b6ccbf0 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -29,7 +29,6 @@ use parquet::file::writer::SerializedFileWriter; use std::any::Any; use std::fmt; use std::fmt::Debug; -use std::io::Write; use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -56,7 +55,7 @@ use parquet::file::properties::WriterProperties; use parquet::file::statistics::Statistics as ParquetStatistics; use super::write::demux::start_demuxer_task; -use super::write::{create_writer, AbortableWrite}; +use super::write::{create_writer, AbortableWrite, SharedBuffer}; use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, @@ -1101,37 +1100,6 @@ async fn output_single_parquet_file_parallelized( Ok(row_count) } -/// A buffer with interior mutability shared by the SerializedFileWriter and -/// ObjectStore writer -#[derive(Clone)] -struct SharedBuffer { - /// The inner buffer for reading and writing - /// - /// The lock is used to obtain internal mutability, so no worry about the - /// lock contention. - buffer: Arc>>, -} - -impl SharedBuffer { - pub fn new(capacity: usize) -> Self { - Self { - buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), - } - } -} - -impl Write for SharedBuffer { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::write(&mut *buffer, buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::flush(&mut *buffer) - } -} - #[cfg(test)] pub(crate) mod test_util { use super::*; diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index cfcdbd8c464e..68fe81ce91fa 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -18,7 +18,7 @@ //! Module containing helper methods/traits related to enabling //! write support for the various file formats -use std::io::Error; +use std::io::{Error, Write}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -43,6 +43,37 @@ use tokio::io::AsyncWrite; pub(crate) mod demux; pub(crate) mod orchestration; +/// A buffer with interior mutability shared by the SerializedFileWriter and +/// ObjectStore writer +#[derive(Clone)] +pub(crate) struct SharedBuffer { + /// The inner buffer for reading and writing + /// + /// The lock is used to obtain internal mutability, so no worry about the + /// lock contention. + pub(crate) buffer: Arc>>, +} + +impl SharedBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), + } + } +} + +impl Write for SharedBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::write(&mut *buffer, buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::flush(&mut *buffer) + } +} + /// Stores data needed during abortion of MultiPart writers #[derive(Clone)] pub(crate) struct MultiPart { diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 02ab33083315..89b23917884c 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -230,6 +230,62 @@ select * from validate_csv_with_options; 1;Foo 2;Bar +# Copy from table to single arrow file +query IT +COPY source_table to 'test_files/scratch/copy/table.arrow'; +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file +STORED AS arrow +LOCATION 'test_files/scratch/copy/table.arrow'; + +query IT +select * from validate_arrow_file; +---- +1 Foo +2 Bar + +# Copy from dict encoded values to single arrow file +query T? +COPY (values +('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) +to 'test_files/scratch/copy/table_dict.arrow'; +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file_dict +STORED AS arrow +LOCATION 'test_files/scratch/copy/table_dict.arrow'; + +query T? +select * from validate_arrow_file_dict; +---- +c foo +d bar + + +# Copy from table to folder of json +query IT +COPY source_table to 'test_files/scratch/copy/table_arrow' (format arrow, single_file_output false); +---- +2 + +# Validate json output +statement ok +CREATE EXTERNAL TABLE validate_arrow STORED AS arrow LOCATION 'test_files/scratch/copy/table_arrow'; + +query IT +select * from validate_arrow; +---- +1 Foo +2 Bar + + # Error cases: # Copy from table with options diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index cdaf0bb64339..e73778ad44e5 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -76,6 +76,45 @@ select * from dictionary_encoded_parquet_partitioned order by (a); a foo b bar +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_partitioned( + a varchar, + b varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' +PARTITIONED BY (b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +insert into dictionary_encoded_arrow_partitioned +select * from dictionary_encoded_values +---- +2 + +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_test_readback( + a varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/b=bar/' +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query T +select * from dictionary_encoded_arrow_test_readback; +---- +b + +# https://github.com/apache/arrow-datafusion/issues/7816 +query error DataFusion error: Arrow error: Schema error: project index 1 out of bounds, max field 1 +select * from dictionary_encoded_arrow_partitioned order by (a); + # test_insert_into statement ok From 3698693fab040dfb077edaf763b6935e9f42ea06 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 25 Dec 2023 10:43:52 +0300 Subject: [PATCH 27/63] Filter pushdown into cross join (#8626) * Initial commit * Simplifications * Review * Review Part 2 * More idiomatic Rust --------- Co-authored-by: Mehmet Ozan Kabak --- .../optimizer/src/eliminate_cross_join.rs | 128 ++++++++++-------- datafusion/optimizer/src/push_down_filter.rs | 89 ++++++++---- datafusion/sqllogictest/test_files/joins.slt | 17 +++ 3 files changed, 152 insertions(+), 82 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index cf9a59d6b892..7c866950a622 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::sync::Arc; use crate::{utils, OptimizerConfig, OptimizerRule}; + use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ @@ -47,81 +48,93 @@ impl EliminateCrossJoin { /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately /// This fix helps to improve the performance of TPCH Q19. issue#78 -/// impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { + let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; + let mut all_inputs: Vec = vec![]; + let parent_predicate = match plan { LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref().clone(); - - let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; - let mut all_inputs: Vec = vec![]; - let did_flat_successfully = match &input { + let input = filter.input.as_ref(); + match input { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. }) - | LogicalPlan::CrossJoin(_) => try_flatten_join_inputs( - &input, - &mut possible_join_keys, - &mut all_inputs, - )?, + | LogicalPlan::CrossJoin(_) => { + if !try_flatten_join_inputs( + input, + &mut possible_join_keys, + &mut all_inputs, + )? { + return Ok(None); + } + extract_possible_join_keys( + &filter.predicate, + &mut possible_join_keys, + )?; + Some(&filter.predicate) + } _ => { return utils::optimize_children(self, plan, config); } - }; - - if !did_flat_successfully { + } + } + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) => { + if !try_flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + )? { return Ok(None); } + None + } + _ => return utils::optimize_children(self, plan, config), + }; - let predicate = &filter.predicate; - // join keys are handled locally - let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new(); - - extract_possible_join_keys(predicate, &mut possible_join_keys)?; + // Join keys are handled locally: + let mut all_join_keys = HashSet::<(Expr, Expr)>::new(); + let mut left = all_inputs.remove(0); + while !all_inputs.is_empty() { + left = find_inner_join( + &left, + &mut all_inputs, + &mut possible_join_keys, + &mut all_join_keys, + )?; + } - let mut left = all_inputs.remove(0); - while !all_inputs.is_empty() { - left = find_inner_join( - &left, - &mut all_inputs, - &mut possible_join_keys, - &mut all_join_keys, - )?; - } + left = utils::optimize_children(self, &left, config)?.unwrap_or(left); - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + if plan.schema() != left.schema() { + left = LogicalPlan::Projection(Projection::new_from_schema( + Arc::new(left), + plan.schema().clone(), + )); + } - if plan.schema() != left.schema() { - left = LogicalPlan::Projection(Projection::new_from_schema( - Arc::new(left.clone()), - plan.schema().clone(), - )); - } + let Some(predicate) = parent_predicate else { + return Ok(Some(left)); + }; - // if there are no join keys then do nothing. - if all_join_keys.is_empty() { - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(left), - )?))) - } else { - // remove join expressions from filter - match remove_join_expressions(predicate, &all_join_keys)? { - Some(filter_expr) => Ok(Some(LogicalPlan::Filter( - Filter::try_new(filter_expr, Arc::new(left))?, - ))), - _ => Ok(Some(left)), - } - } + // If there are no join keys then do nothing: + if all_join_keys.is_empty() { + Filter::try_new(predicate.clone(), Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))) + } else { + // Remove join expressions from filter: + match remove_join_expressions(predicate, &all_join_keys)? { + Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))), + _ => Ok(Some(left)), } - - _ => utils::optimize_children(self, plan, config), } } @@ -325,17 +338,16 @@ fn remove_join_expressions( #[cfg(test)] mod tests { + use super::*; + use crate::optimizer::OptimizerContext; + use crate::test::*; + use datafusion_expr::{ binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; - use crate::optimizer::OptimizerContext; - use crate::test::*; - - use super::*; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { let rule = EliminateCrossJoin::new(); let optimized_plan = rule diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4bea17500acc..4eed39a08941 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -15,25 +15,29 @@ //! [`PushDownFilter`] Moves filters so they are applied as early as possible in //! the plan. +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{ - internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, + internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, DataFusionError, + JoinConstraint, Result, }; use datafusion_expr::expr::Alias; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::logical_plan::{ + CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, +}; use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; -use datafusion_expr::Volatility; use datafusion_expr::{ - and, - expr_rewriter::replace_col, - logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, - or, BinaryExpr, Expr, Filter, Operator, ScalarFunctionDefinition, - TableProviderFilterPushDown, + and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, + ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility, }; + use itertools::Itertools; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. @@ -848,17 +852,23 @@ impl OptimizerRule for PushDownFilter { None => return Ok(None), } } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { + LogicalPlan::CrossJoin(cross_join) => { let predicates = split_conjunction_owned(filter.predicate.clone()); - push_down_all_join( + let join = convert_cross_join_to_inner_join(cross_join.clone())?; + let join_plan = LogicalPlan::Join(join); + let inputs = join_plan.inputs(); + let left = inputs[0]; + let right = inputs[1]; + let plan = push_down_all_join( predicates, vec![], - &filter.input, + &join_plan, left, right, vec![], - false, - )? + true, + )?; + convert_to_cross_join_if_beneficial(plan)? } LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -955,6 +965,36 @@ impl PushDownFilter { } } +/// Convert cross join to join by pushing down filter predicate to the join condition +fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { + let CrossJoin { left, right, .. } = cross_join; + let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; + // predicate is given + Ok(Join { + left, + right, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + on: vec![], + filter: None, + schema: DFSchemaRef::new(join_schema), + null_equals_null: true, + }) +} + +/// Converts the inner join with empty equality predicate and empty filter condition to the cross join +fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { + if let LogicalPlan::Join(join) = &plan { + // Can be converted back to cross join + if join.on.is_empty() && join.filter.is_none() { + return LogicalPlanBuilder::from(join.left.as_ref().clone()) + .cross_join(join.right.as_ref().clone())? + .build(); + } + } + Ok(plan) +} + /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, @@ -1026,13 +1066,16 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { + use std::fmt::{Debug, Formatter}; + use std::sync::Arc; + use super::*; use crate::optimizer::Optimizer; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::test::*; use crate::OptimizerContext; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use async_trait::async_trait; use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ @@ -1040,8 +1083,8 @@ mod tests { BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType, UserDefinedLogicalNodeCore, }; - use std::fmt::{Debug, Formatter}; - use std::sync::Arc; + + use async_trait::async_trait; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( @@ -2665,14 +2708,12 @@ Projection: a, b .cross_join(right)? .filter(filter)? .build()?; - let expected = "\ - Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n CrossJoin:\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + \n Projection: test.a, test.b, test.c\ + \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ + \n Projection: test1.a AS d, test1.a AS e\ + \n TableScan: test1"; assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 1ad17fbb8c91..eee213811f44 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3466,6 +3466,23 @@ SortPreservingMergeExec: [a@0 ASC] ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true +query TT +EXPLAIN SELECT * +FROM annotated_data as l, annotated_data as r +WHERE l.a > r.a +---- +logical_plan +Inner Join: Filter: l.a > r.a +--SubqueryAlias: l +----TableScan: annotated_data projection=[a0, a, b, c, d] +--SubqueryAlias: r +----TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1 +--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + #### # Config teardown #### From 18c75669e18929ca095c47af4ebf285b14d2c814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Mon, 25 Dec 2023 23:12:51 +0300 Subject: [PATCH 28/63] [MINOR] Remove duplicate test utility and move one utility function for better organization (#8652) * Code rearrange * Update stream_join_utils.rs --- .../src/joins/stream_join_utils.rs | 156 +++++++++++------- .../src/joins/symmetric_hash_join.rs | 11 +- datafusion/physical-plan/src/joins/utils.rs | 90 +--------- 3 files changed, 104 insertions(+), 153 deletions(-) diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 50b1618a35dd..9a4c98927683 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -25,23 +25,25 @@ use std::usize; use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{handle_async_state, handle_state, metrics}; +use crate::{handle_async_state, handle_state, metrics, ExecutionPlan}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; -use async_trait::async_trait; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, + arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, + ScalarValue, }; use datafusion_execution::SendableRecordBatchStream; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use async_trait::async_trait; use futures::{ready, FutureExt, StreamExt}; use hashbrown::raw::RawTable; use hashbrown::HashSet; @@ -175,7 +177,7 @@ impl PruningJoinHashMap { prune_length: usize, deleting_offset: u64, shrink_factor: usize, - ) -> Result<()> { + ) { // Remove elements from the list based on the pruning length. self.next.drain(0..prune_length); @@ -198,11 +200,10 @@ impl PruningJoinHashMap { // Shrink the map if necessary. self.shrink_if_necessary(shrink_factor); - Ok(()) } } -pub fn check_filter_expr_contains_sort_information( +fn check_filter_expr_contains_sort_information( expr: &Arc, reference: &Arc, ) -> bool { @@ -227,7 +228,7 @@ pub fn map_origin_col_to_filter_col( side: &JoinSide, ) -> Result> { let filter_schema = filter.schema(); - let mut col_to_col_map: HashMap = HashMap::new(); + let mut col_to_col_map = HashMap::::new(); for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { if index.side.eq(side) { // Get the main field from column index: @@ -581,7 +582,7 @@ where // get the semi index (0..prune_length) .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect::>() + .collect() } pub fn combine_two_batches( @@ -763,7 +764,6 @@ pub trait EagerJoinStream { if batch.num_rows() == 0 { return Ok(StatefulStreamResult::Continue); } - self.set_state(EagerJoinStreamState::PullLeft); self.process_batch_from_right(batch) } @@ -1032,6 +1032,91 @@ impl StreamJoinMetrics { } } +/// Updates sorted filter expressions with corresponding node indices from the +/// expression interval graph. +/// +/// This function iterates through the provided sorted filter expressions, +/// gathers the corresponding node indices from the expression interval graph, +/// and then updates the sorted expressions with these indices. It ensures +/// that these sorted expressions are aligned with the structure of the graph. +fn update_sorted_exprs_with_node_indices( + graph: &mut ExprIntervalGraph, + sorted_exprs: &mut [SortedFilterExpr], +) { + // Extract filter expressions from the sorted expressions: + let filter_exprs = sorted_exprs + .iter() + .map(|expr| expr.filter_expr().clone()) + .collect::>(); + + // Gather corresponding node indices for the extracted filter expressions from the graph: + let child_node_indices = graph.gather_node_indices(&filter_exprs); + + // Iterate through the sorted expressions and the gathered node indices: + for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { + // Update each sorted expression with the corresponding node index: + sorted_expr.set_node_index(index); + } +} + +/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// +/// # Arguments +/// +/// * `filter` - The join filter to base the sorting on. +/// * `left` - The left execution plan. +/// * `right` - The right execution plan. +/// * `left_sort_exprs` - The expressions to sort on the left side. +/// * `right_sort_exprs` - The expressions to sort on the right side. +/// +/// # Returns +/// +/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. +pub fn prepare_sorted_exprs( + filter: &JoinFilter, + left: &Arc, + right: &Arc, + left_sort_exprs: &[PhysicalSortExpr], + right_sort_exprs: &[PhysicalSortExpr], +) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { + // Build the filter order for the left side + let err = || plan_datafusion_err!("Filter does not include the child order"); + + let left_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Left, + filter, + &left.schema(), + &left_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Build the filter order for the right side + let right_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Right, + filter, + &right.schema(), + &right_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Collect the sorted expressions + let mut sorted_exprs = + vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; + + // Build the expression interval graph + let mut graph = + ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; + + // Update sorted expressions with node indices + update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); + + // Swap and remove to get the final sorted filter expressions + let right_sorted_filter_expr = sorted_exprs.swap_remove(1); + let left_sorted_filter_expr = sorted_exprs.swap_remove(0); + + Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) +} + #[cfg(test)] pub mod tests { use std::sync::Arc; @@ -1043,62 +1128,15 @@ pub mod tests { }; use crate::{ expressions::{Column, PhysicalSortExpr}, + joins::test_utils::complicated_filter, joins::utils::{ColumnIndex, JoinFilter}, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{JoinSide, ScalarValue}; + use datafusion_common::JoinSide; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, cast, col, lit}; - - /// Filter expr for a + b > c + 10 AND a + b < c + 100 - pub(crate) fn complicated_filter( - filter_schema: &Schema, - ) -> Result> { - let left_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Gt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(10))), - filter_schema, - )?, - filter_schema, - )?; - - let right_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Lt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(100))), - filter_schema, - )?, - filter_schema, - )?; - binary(left_expr, Operator::And, right_expr, filter_schema) - } + use datafusion_physical_expr::expressions::{binary, cast, col}; #[test] fn test_column_exchange() -> Result<()> { diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index b9101b57c3e5..f071a7f6015a 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -36,13 +36,14 @@ use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, - get_pruning_semi_indices, record_visited_indices, EagerJoinStream, - EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, + get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices, + EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, + StreamJoinMetrics, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, - partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter, - JoinOn, StatefulStreamResult, + partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, + StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, @@ -936,7 +937,7 @@ impl OneSideHashJoiner { prune_length, self.deleted_offset as u64, HASHMAP_SHRINK_SCALE_FACTOR, - )?; + ); // Remove pruned rows from the visited rows set: for row in self.deleted_offset..(self.deleted_offset + prune_length) { self.visited_rows.remove(&row); diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c902ba85f271..ac805b50e6a5 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -25,7 +25,6 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::joins::stream_join_utils::{build_filter_input_order, SortedFilterExpr}; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; @@ -39,13 +38,11 @@ use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; use datafusion_common::{ - plan_datafusion_err, plan_err, DataFusionError, JoinSide, JoinType, Result, - SharedResult, + plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::merge_vectors; use datafusion_physical_expr::{ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr, @@ -1208,91 +1205,6 @@ impl BuildProbeJoinMetrics { } } -/// Updates sorted filter expressions with corresponding node indices from the -/// expression interval graph. -/// -/// This function iterates through the provided sorted filter expressions, -/// gathers the corresponding node indices from the expression interval graph, -/// and then updates the sorted expressions with these indices. It ensures -/// that these sorted expressions are aligned with the structure of the graph. -fn update_sorted_exprs_with_node_indices( - graph: &mut ExprIntervalGraph, - sorted_exprs: &mut [SortedFilterExpr], -) { - // Extract filter expressions from the sorted expressions: - let filter_exprs = sorted_exprs - .iter() - .map(|expr| expr.filter_expr().clone()) - .collect::>(); - - // Gather corresponding node indices for the extracted filter expressions from the graph: - let child_node_indices = graph.gather_node_indices(&filter_exprs); - - // Iterate through the sorted expressions and the gathered node indices: - for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { - // Update each sorted expression with the corresponding node index: - sorted_expr.set_node_index(index); - } -} - -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. -/// -/// # Arguments -/// -/// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. -/// * `left_sort_exprs` - The expressions to sort on the left side. -/// * `right_sort_exprs` - The expressions to sort on the right side. -/// -/// # Returns -/// -/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. -pub fn prepare_sorted_exprs( - filter: &JoinFilter, - left: &Arc, - right: &Arc, - left_sort_exprs: &[PhysicalSortExpr], - right_sort_exprs: &[PhysicalSortExpr], -) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); - - let left_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Left, - filter, - &left.schema(), - &left_sort_exprs[0], - )? - .ok_or_else(err)?; - - // Build the filter order for the right side - let right_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Right, - filter, - &right.schema(), - &right_sort_exprs[0], - )? - .ok_or_else(err)?; - - // Collect the sorted expressions - let mut sorted_exprs = - vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; - - // Build the expression interval graph - let mut graph = - ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; - - // Update sorted expressions with node indices - update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); - - // Swap and remove to get the final sorted filter expressions - let right_sorted_filter_expr = sorted_exprs.swap_remove(1); - let left_sorted_filter_expr = sorted_exprs.swap_remove(0); - - Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) -} - /// The `handle_state` macro is designed to process the result of a state-changing /// operation, encountered e.g. in implementations of `EagerJoinStream`. It /// operates on a `StatefulStreamResult` by matching its variants and executing From ec8fd44594cada9cb0189f56ddf586ec48175ce0 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Tue, 26 Dec 2023 01:01:10 +0300 Subject: [PATCH 29/63] [MINOR]: Add new test for filter pushdown into cross join (#8648) * Initial commit * Minor changes * Simplifications * Update UDF example * Address review --------- Co-authored-by: Mehmet Ozan Kabak --- .../optimizer/src/eliminate_cross_join.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 12 +++- datafusion/sqllogictest/src/test_context.rs | 61 ++++++++++++++----- datafusion/sqllogictest/test_files/joins.slt | 22 +++++++ 4 files changed, 78 insertions(+), 18 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 7c866950a622..d9e96a9f2543 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -45,6 +45,7 @@ impl EliminateCrossJoin { /// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' /// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) /// or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// 'select ... from a, b where a.x > b.y' /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately /// This fix helps to improve the performance of TPCH Q19. issue#78 diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4eed39a08941..9d277d18d2f7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -965,11 +965,11 @@ impl PushDownFilter { } } -/// Convert cross join to join by pushing down filter predicate to the join condition +/// Converts the given cross join to an inner join with an empty equality +/// predicate and an empty filter condition. fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { let CrossJoin { left, right, .. } = cross_join; let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - // predicate is given Ok(Join { left, right, @@ -982,7 +982,8 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { }) } -/// Converts the inner join with empty equality predicate and empty filter condition to the cross join +/// Converts the given inner join with an empty equality predicate and an +/// empty filter condition to a cross join. fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { if let LogicalPlan::Join(join) = &plan { // Can be converted back to cross join @@ -991,6 +992,11 @@ fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result .cross_join(join.right.as_ref().clone())? .build(); } + } else if let LogicalPlan::Filter(filter) = &plan { + let new_input = + convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; + return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) + .map(LogicalPlan::Filter); } Ok(plan) } diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 941dcb69d2f4..a5ce7ccb9fe0 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -15,31 +15,33 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; +use std::collections::HashMap; +use std::fs::File; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampNanosecondArray, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionState; -use datafusion::logical_expr::Expr; +use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility}; +use datafusion::physical_expr::functions::make_scalar_function; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ - arrow::{ - array::{ - BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, - }, - datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}, - record_batch::RecordBatch, - }, catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider}, datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_common::cast::as_float64_array; use datafusion_common::DataFusionError; + +use async_trait::async_trait; use log::info; -use std::collections::HashMap; -use std::fs::File; -use std::io::Write; -use std::path::Path; -use std::sync::Arc; use tempfile::TempDir; /// Context for running tests @@ -102,6 +104,8 @@ impl TestContext { } "joins.slt" => { info!("Registering partition table tables"); + let example_udf = create_example_udf(); + test_ctx.ctx.register_udf(example_udf); register_partition_table(&mut test_ctx).await; } "metadata.slt" => { @@ -348,3 +352,30 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { ctx.register_batch("table_with_metadata", batch).unwrap(); } + +/// Create a UDF function named "example". See the `sample_udf.rs` example +/// file for an explanation of the API. +fn create_example_udf() -> ScalarUDF { + let adder = make_scalar_function(|args: &[ArrayRef]| { + let lhs = as_float64_array(&args[0]).expect("cast failed"); + let rhs = as_float64_array(&args[1]).expect("cast failed"); + let array = lhs + .iter() + .zip(rhs.iter()) + .map(|(lhs, rhs)| match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Some(lhs + rhs), + _ => None, + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + }); + create_udf( + "example", + // Expects two f64 values: + vec![DataType::Float64, DataType::Float64], + // Returns an f64 value: + Arc::new(DataType::Float64), + Volatility::Immutable, + adder, + ) +} diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index eee213811f44..9a349f600091 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3483,6 +3483,28 @@ NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1 ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +# Currently datafusion cannot pushdown filter conditions with scalar UDF into +# cross join. +query TT +EXPLAIN SELECT * +FROM annotated_data as t1, annotated_data as t2 +WHERE EXAMPLE(t1.a, t2.a) > 3 +---- +logical_plan +Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) +--CrossJoin: +----SubqueryAlias: t1 +------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: t2 +------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3 +----CrossJoinExec +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + #### # Config teardown #### From e10d3e2a0267c70bf36373c6811906e5b9b47703 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Dec 2023 06:53:07 -0500 Subject: [PATCH 30/63] Rewrite bloom filters to use `contains` API (#8442) --- .../datasource/physical_plan/parquet/mod.rs | 1 + .../physical_plan/parquet/row_groups.rs | 245 +++++++----------- 2 files changed, 91 insertions(+), 155 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index ade149da6991..76a6cc297b0e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -522,6 +522,7 @@ impl FileOpener for ParquetOpener { if enable_bloom_filter && !row_groups.is_empty() { if let Some(predicate) = predicate { row_groups = row_groups::prune_row_groups_by_bloom_filters( + &file_schema, &mut builder, &row_groups, file_metadata.row_groups(), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 09e4907c9437..8a1abb7d965f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -18,8 +18,7 @@ use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_array::BooleanArray; use arrow_schema::FieldRef; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; +use datafusion_common::{Column, ScalarValue}; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; use parquet::{ @@ -27,19 +26,13 @@ use parquet::{ bloom_filter::Sbbf, file::metadata::RowGroupMetaData, }; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; +use std::collections::{HashMap, HashSet}; use crate::datasource::listing::FileRange; use crate::datasource::physical_plan::parquet::statistics::{ max_statistics, min_statistics, parquet_column, }; -use crate::logical_expr::Operator; -use crate::physical_expr::expressions as phys_expr; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; -use crate::physical_plan::PhysicalExpr; use super::ParquetFileMetrics; @@ -118,188 +111,129 @@ pub(crate) fn prune_row_groups_by_statistics( pub(crate) async fn prune_row_groups_by_bloom_filters< T: AsyncFileReader + Send + 'static, >( + arrow_schema: &Schema, builder: &mut ParquetRecordBatchStreamBuilder, row_groups: &[usize], groups: &[RowGroupMetaData], predicate: &PruningPredicate, metrics: &ParquetFileMetrics, ) -> Vec { - let bf_predicates = match BloomFilterPruningPredicate::try_new(predicate.orig_expr()) - { - Ok(predicates) => predicates, - Err(_) => { - return row_groups.to_vec(); - } - }; let mut filtered = Vec::with_capacity(groups.len()); for idx in row_groups { - let rg_metadata = &groups[*idx]; - // get all columns bloom filter - let mut column_sbbf = - HashMap::with_capacity(bf_predicates.required_columns.len()); - for column_name in bf_predicates.required_columns.iter() { - let column_idx = match rg_metadata - .columns() - .iter() - .enumerate() - .find(|(_, column)| column.column_path().string().eq(column_name)) - { - Some((column_idx, _)) => column_idx, - None => continue, + // get all columns in the predicate that we could use a bloom filter with + let literal_columns = predicate.literal_columns(); + let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); + + for column_name in literal_columns { + let Some((column_idx, _field)) = + parquet_column(builder.parquet_schema(), arrow_schema, &column_name) + else { + continue; }; + let bf = match builder .get_row_group_column_bloom_filter(*idx, column_idx) .await { - Ok(bf) => match bf { - Some(bf) => bf, - None => { - continue; - } - }, + Ok(Some(bf)) => bf, + Ok(None) => continue, // no bloom filter for this column Err(e) => { - log::error!("Error evaluating row group predicate values when using BloomFilterPruningPredicate {e}"); + log::debug!("Ignoring error reading bloom filter: {e}"); metrics.predicate_evaluation_errors.add(1); continue; } }; - column_sbbf.insert(column_name.to_owned(), bf); + column_sbbf.insert(column_name.to_string(), bf); } - if bf_predicates.prune(&column_sbbf) { + + let stats = BloomFilterStatistics { column_sbbf }; + + // Can this group be pruned? + let prune_group = match predicate.prune(&stats) { + Ok(values) => !values[0], + Err(e) => { + log::debug!("Error evaluating row group predicate on bloom filter: {e}"); + metrics.predicate_evaluation_errors.add(1); + false + } + }; + + if prune_group { metrics.row_groups_pruned.add(1); - continue; + } else { + filtered.push(*idx); } - filtered.push(*idx); } filtered } -struct BloomFilterPruningPredicate { - /// Actual pruning predicate - predicate_expr: Option, - /// The statistics required to evaluate this predicate - required_columns: Vec, +/// Implements `PruningStatistics` for Parquet Split Block Bloom Filters (SBBF) +struct BloomFilterStatistics { + /// Maps column name to the parquet bloom filter + column_sbbf: HashMap, } -impl BloomFilterPruningPredicate { - fn try_new(expr: &Arc) -> Result { - let binary_expr = expr.as_any().downcast_ref::(); - match binary_expr { - Some(binary_expr) => { - let columns = Self::get_predicate_columns(expr); - Ok(Self { - predicate_expr: Some(binary_expr.clone()), - required_columns: columns.into_iter().collect(), - }) - } - None => Err(DataFusionError::Execution( - "BloomFilterPruningPredicate only support binary expr".to_string(), - )), - } +impl PruningStatistics for BloomFilterStatistics { + fn min_values(&self, _column: &Column) -> Option { + None } - fn prune(&self, column_sbbf: &HashMap) -> bool { - Self::prune_expr_with_bloom_filter(self.predicate_expr.as_ref(), column_sbbf) + fn max_values(&self, _column: &Column) -> Option { + None } - /// Return true if the `expr` can be proved not `true` - /// based on the bloom filter. - /// - /// We only checked `BinaryExpr` but it also support `InList`, - /// Because of the `optimizer` will convert `InList` to `BinaryExpr`. - fn prune_expr_with_bloom_filter( - expr: Option<&phys_expr::BinaryExpr>, - column_sbbf: &HashMap, - ) -> bool { - let Some(expr) = expr else { - // unsupported predicate - return false; - }; - match expr.op() { - Operator::And | Operator::Or => { - let left = Self::prune_expr_with_bloom_filter( - expr.left().as_any().downcast_ref::(), - column_sbbf, - ); - let right = Self::prune_expr_with_bloom_filter( - expr.right() - .as_any() - .downcast_ref::(), - column_sbbf, - ); - match expr.op() { - Operator::And => left || right, - Operator::Or => left && right, - _ => false, - } - } - Operator::Eq => { - if let Some((col, val)) = Self::check_expr_is_col_equal_const(expr) { - if let Some(sbbf) = column_sbbf.get(col.name()) { - match val { - ScalarValue::Utf8(Some(v)) => !sbbf.check(&v.as_str()), - ScalarValue::Boolean(Some(v)) => !sbbf.check(&v), - ScalarValue::Float64(Some(v)) => !sbbf.check(&v), - ScalarValue::Float32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int64(Some(v)) => !sbbf.check(&v), - ScalarValue::Int32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int16(Some(v)) => !sbbf.check(&v), - ScalarValue::Int8(Some(v)) => !sbbf.check(&v), - _ => false, - } - } else { - false - } - } else { - false - } - } - _ => false, - } + fn num_containers(&self) -> usize { + 1 } - fn get_predicate_columns(expr: &Arc) -> HashSet { - let mut columns = HashSet::new(); - expr.apply(&mut |expr| { - if let Some(binary_expr) = - expr.as_any().downcast_ref::() - { - if let Some((column, _)) = - Self::check_expr_is_col_equal_const(binary_expr) - { - columns.insert(column.name().to_string()); - } - } - Ok(VisitRecursion::Continue) - }) - // no way to fail as only Ok(VisitRecursion::Continue) is returned - .unwrap(); - - columns + fn null_counts(&self, _column: &Column) -> Option { + None } - fn check_expr_is_col_equal_const( - exr: &phys_expr::BinaryExpr, - ) -> Option<(phys_expr::Column, ScalarValue)> { - if Operator::Eq.ne(exr.op()) { - return None; - } + /// Use bloom filters to determine if we are sure this column can not + /// possibly contain `values` + /// + /// The `contained` API returns false if the bloom filters knows that *ALL* + /// of the values in a column are not present. + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let sbbf = self.column_sbbf.get(column.name.as_str())?; - let left_any = exr.left().as_any(); - let right_any = exr.right().as_any(); - if let (Some(col), Some(liter)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - if let (Some(liter), Some(col)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - None + // Bloom filters are probabilistic data structures that can return false + // positives (i.e. it might return true even if the value is not + // present) however, the bloom filter will return `false` if the value is + // definitely not present. + + let known_not_present = values + .iter() + .map(|value| match value { + ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::Int16(Some(v)) => sbbf.check(v), + ScalarValue::Int8(Some(v)) => sbbf.check(v), + _ => true, + }) + // The row group doesn't contain any of the values if + // all the checks are false + .all(|v| !v); + + let contains = if known_not_present { + Some(false) + } else { + // Given the bloom filter is probabilistic, we can't be sure that + // the row group actually contains the values. Return `None` to + // indicate this uncertainty + None + }; + + Some(BooleanArray::from(vec![contains])) } } @@ -1367,6 +1301,7 @@ mod tests { let metadata = builder.metadata().clone(); let pruned_row_group = prune_row_groups_by_bloom_filters( + pruning_predicate.schema(), &mut builder, row_groups, metadata.row_groups(), From 4e4d0508587096551c9a34439703f765fd96edaa Mon Sep 17 00:00:00 2001 From: tushushu <33303747+tushushu@users.noreply.github.com> Date: Tue, 26 Dec 2023 19:54:27 +0800 Subject: [PATCH 31/63] Split equivalence code into smaller modules. (#8649) * refactor * refactor * fix imports * fix ordering * private func as pub * private as pub * fix import * fix mod func * fix add_equal_conditions_test * fix project_equivalence_properties_test * fix test_ordering_satisfy * fix test_ordering_satisfy_with_equivalence2 * fix other ordering tests * fix join_equivalence_properties * fix test_expr_consists_of_constants * fix test_bridge_groups * fix test_remove_redundant_entries_eq_group * fix proj tests * test_remove_redundant_entries_oeq_class * test_schema_normalize_expr_with_equivalence * test_normalize_ordering_equivalence_classes * test_get_indices_of_matching_sort_exprs_with_order_eq * test_contains_any * test_update_ordering * test_find_longest_permutation_random * test_find_longest_permutation * test_get_meet_ordering * test_get_finer * test_normalize_sort_reqs * test_schema_normalize_sort_requirement_with_equivalence * expose func and struct * remove unused export --- datafusion/physical-expr/src/equivalence.rs | 5327 ----------------- .../physical-expr/src/equivalence/class.rs | 598 ++ .../physical-expr/src/equivalence/mod.rs | 533 ++ .../physical-expr/src/equivalence/ordering.rs | 1159 ++++ .../src/equivalence/projection.rs | 1153 ++++ .../src/equivalence/properties.rs | 2062 +++++++ 6 files changed, 5505 insertions(+), 5327 deletions(-) delete mode 100644 datafusion/physical-expr/src/equivalence.rs create mode 100644 datafusion/physical-expr/src/equivalence/class.rs create mode 100644 datafusion/physical-expr/src/equivalence/mod.rs create mode 100644 datafusion/physical-expr/src/equivalence/ordering.rs create mode 100644 datafusion/physical-expr/src/equivalence/projection.rs create mode 100644 datafusion/physical-expr/src/equivalence/properties.rs diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs deleted file mode 100644 index defd7b5786a3..000000000000 --- a/datafusion/physical-expr/src/equivalence.rs +++ /dev/null @@ -1,5327 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::{HashMap, HashSet}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - -use crate::expressions::{Column, Literal}; -use crate::physical_expr::deduplicate_physical_exprs; -use crate::sort_properties::{ExprOrdering, SortProperties}; -use crate::{ - physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, - LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, - PhysicalSortRequirement, -}; - -use arrow::datatypes::SchemaRef; -use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; - -use indexmap::IndexSet; -use itertools::Itertools; - -/// An `EquivalenceClass` is a set of [`Arc`]s that are known -/// to have the same value for all tuples in a relation. These are generated by -/// equality predicates (e.g. `a = b`), typically equi-join conditions and -/// equality conditions in filters. -/// -/// Two `EquivalenceClass`es are equal if they contains the same expressions in -/// without any ordering. -#[derive(Debug, Clone)] -pub struct EquivalenceClass { - /// The expressions in this equivalence class. The order doesn't - /// matter for equivalence purposes - /// - /// TODO: use a HashSet for this instead of a Vec - exprs: Vec>, -} - -impl PartialEq for EquivalenceClass { - /// Returns true if other is equal in the sense - /// of bags (multi-sets), disregarding their orderings. - fn eq(&self, other: &Self) -> bool { - physical_exprs_bag_equal(&self.exprs, &other.exprs) - } -} - -impl EquivalenceClass { - /// Create a new empty equivalence class - pub fn new_empty() -> Self { - Self { exprs: vec![] } - } - - // Create a new equivalence class from a pre-existing `Vec` - pub fn new(mut exprs: Vec>) -> Self { - deduplicate_physical_exprs(&mut exprs); - Self { exprs } - } - - /// Return the inner vector of expressions - pub fn into_vec(self) -> Vec> { - self.exprs - } - - /// Return the "canonical" expression for this class (the first element) - /// if any - fn canonical_expr(&self) -> Option> { - self.exprs.first().cloned() - } - - /// Insert the expression into this class, meaning it is known to be equal to - /// all other expressions in this class - pub fn push(&mut self, expr: Arc) { - if !self.contains(&expr) { - self.exprs.push(expr); - } - } - - /// Inserts all the expressions from other into this class - pub fn extend(&mut self, other: Self) { - for expr in other.exprs { - // use push so entries are deduplicated - self.push(expr); - } - } - - /// Returns true if this equivalence class contains t expression - pub fn contains(&self, expr: &Arc) -> bool { - physical_exprs_contains(&self.exprs, expr) - } - - /// Returns true if this equivalence class has any entries in common with `other` - pub fn contains_any(&self, other: &Self) -> bool { - self.exprs.iter().any(|e| other.contains(e)) - } - - /// return the number of items in this class - pub fn len(&self) -> usize { - self.exprs.len() - } - - /// return true if this class is empty - pub fn is_empty(&self) -> bool { - self.exprs.is_empty() - } - - /// Iterate over all elements in this class, in some arbitrary order - pub fn iter(&self) -> impl Iterator> { - self.exprs.iter() - } - - /// Return a new equivalence class that have the specified offset added to - /// each expression (used when schemas are appended such as in joins) - pub fn with_offset(&self, offset: usize) -> Self { - let new_exprs = self - .exprs - .iter() - .cloned() - .map(|e| add_offset_to_expr(e, offset)) - .collect(); - Self::new(new_exprs) - } -} - -/// Stores the mapping between source expressions and target expressions for a -/// projection. -#[derive(Debug, Clone)] -pub struct ProjectionMapping { - /// Mapping between source expressions and target expressions. - /// Vector indices correspond to the indices after projection. - map: Vec<(Arc, Arc)>, -} - -impl ProjectionMapping { - /// Constructs the mapping between a projection's input and output - /// expressions. - /// - /// For example, given the input projection expressions (`a + b`, `c + d`) - /// and an output schema with two columns `"c + d"` and `"a + b"`, the - /// projection mapping would be: - /// - /// ```text - /// [0]: (c + d, col("c + d")) - /// [1]: (a + b, col("a + b")) - /// ``` - /// - /// where `col("c + d")` means the column named `"c + d"`. - pub fn try_new( - expr: &[(Arc, String)], - input_schema: &SchemaRef, - ) -> Result { - // Construct a map from the input expressions to the output expression of the projection: - expr.iter() - .enumerate() - .map(|(expr_idx, (expression, name))| { - let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - expression - .clone() - .transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => { - // Sometimes, an expression and its name in the input_schema - // doesn't match. This can cause problems, so we make sure - // that the expression name matches with the name in `input_schema`. - // Conceptually, `source_expr` and `expression` should be the same. - let idx = col.index(); - let matching_input_field = input_schema.field(idx); - let matching_input_column = - Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) - } - None => Ok(Transformed::No(e)), - }) - .map(|source_expr| (source_expr, target_expr)) - }) - .collect::>>() - .map(|map| Self { map }) - } - - /// Iterate over pairs of (source, target) expressions - pub fn iter( - &self, - ) -> impl Iterator, Arc)> + '_ { - self.map.iter() - } - - /// This function returns the target expression for a given source expression. - /// - /// # Arguments - /// - /// * `expr` - Source physical expression. - /// - /// # Returns - /// - /// An `Option` containing the target for the given source expression, - /// where a `None` value means that `expr` is not inside the mapping. - pub fn target_expr( - &self, - expr: &Arc, - ) -> Option> { - self.map - .iter() - .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| target.clone()) - } -} - -/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each -/// class represents a distinct equivalence class in a relation. -#[derive(Debug, Clone)] -pub struct EquivalenceGroup { - classes: Vec, -} - -impl EquivalenceGroup { - /// Creates an empty equivalence group. - fn empty() -> Self { - Self { classes: vec![] } - } - - /// Creates an equivalence group from the given equivalence classes. - fn new(classes: Vec) -> Self { - let mut result = Self { classes }; - result.remove_redundant_entries(); - result - } - - /// Returns how many equivalence classes there are in this group. - fn len(&self) -> usize { - self.classes.len() - } - - /// Checks whether this equivalence group is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalence classes in this group. - pub fn iter(&self) -> impl Iterator { - self.classes.iter() - } - - /// Adds the equality `left` = `right` to this equivalence group. - /// New equality conditions often arise after steps like `Filter(a = b)`, - /// `Alias(a, a as b)` etc. - fn add_equal_conditions( - &mut self, - left: &Arc, - right: &Arc, - ) { - let mut first_class = None; - let mut second_class = None; - for (idx, cls) in self.classes.iter().enumerate() { - if cls.contains(left) { - first_class = Some(idx); - } - if cls.contains(right) { - second_class = Some(idx); - } - } - match (first_class, second_class) { - (Some(mut first_idx), Some(mut second_idx)) => { - // If the given left and right sides belong to different classes, - // we should unify/bridge these classes. - if first_idx != second_idx { - // By convention, make sure `second_idx` is larger than `first_idx`. - if first_idx > second_idx { - (first_idx, second_idx) = (second_idx, first_idx); - } - // Remove the class at `second_idx` and merge its values with - // the class at `first_idx`. The convention above makes sure - // that `first_idx` is still valid after removing `second_idx`. - let other_class = self.classes.swap_remove(second_idx); - self.classes[first_idx].extend(other_class); - } - } - (Some(group_idx), None) => { - // Right side is new, extend left side's class: - self.classes[group_idx].push(right.clone()); - } - (None, Some(group_idx)) => { - // Left side is new, extend right side's class: - self.classes[group_idx].push(left.clone()); - } - (None, None) => { - // None of the expressions is among existing classes. - // Create a new equivalence class and extend the group. - self.classes - .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); - } - } - } - - /// Removes redundant entries from this group. - fn remove_redundant_entries(&mut self) { - // Remove duplicate entries from each equivalence class: - self.classes.retain_mut(|cls| { - // Keep groups that have at least two entries as singleton class is - // meaningless (i.e. it contains no non-trivial information): - cls.len() > 1 - }); - // Unify/bridge groups that have common expressions: - self.bridge_classes() - } - - /// This utility function unifies/bridges classes that have common expressions. - /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. - /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all - /// equal and belong to one class. This utility converts merges such classes. - fn bridge_classes(&mut self) { - let mut idx = 0; - while idx < self.classes.len() { - let mut next_idx = idx + 1; - let start_size = self.classes[idx].len(); - while next_idx < self.classes.len() { - if self.classes[idx].contains_any(&self.classes[next_idx]) { - let extension = self.classes.swap_remove(next_idx); - self.classes[idx].extend(extension); - } else { - next_idx += 1; - } - } - if self.classes[idx].len() > start_size { - continue; - } - idx += 1; - } - } - - /// Extends this equivalence group with the `other` equivalence group. - fn extend(&mut self, other: Self) { - self.classes.extend(other.classes); - self.remove_redundant_entries(); - } - - /// Normalizes the given physical expression according to this group. - /// The expression is replaced with the first expression in the equivalence - /// class it matches with (if any). - pub fn normalize_expr(&self, expr: Arc) -> Arc { - expr.clone() - .transform(&|expr| { - for cls in self.iter() { - if cls.contains(&expr) { - return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); - } - } - Ok(Transformed::No(expr)) - }) - .unwrap_or(expr) - } - - /// Normalizes the given sort expression according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the sort expression as is. - pub fn normalize_sort_expr( - &self, - mut sort_expr: PhysicalSortExpr, - ) -> PhysicalSortExpr { - sort_expr.expr = self.normalize_expr(sort_expr.expr); - sort_expr - } - - /// Normalizes the given sort requirement according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the given sort requirement as is. - pub fn normalize_sort_requirement( - &self, - mut sort_requirement: PhysicalSortRequirement, - ) -> PhysicalSortRequirement { - sort_requirement.expr = self.normalize_expr(sort_requirement.expr); - sort_requirement - } - - /// This function applies the `normalize_expr` function for all expressions - /// in `exprs` and returns the corresponding normalized physical expressions. - pub fn normalize_exprs( - &self, - exprs: impl IntoIterator>, - ) -> Vec> { - exprs - .into_iter() - .map(|expr| self.normalize_expr(expr)) - .collect() - } - - /// This function applies the `normalize_sort_expr` function for all sort - /// expressions in `sort_exprs` and returns the corresponding normalized - /// sort expressions. - pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) - } - - /// This function applies the `normalize_sort_requirement` function for all - /// requirements in `sort_reqs` and returns the corresponding normalized - /// sort requirements. - pub fn normalize_sort_requirements( - &self, - sort_reqs: LexRequirementRef, - ) -> LexRequirement { - collapse_lex_req( - sort_reqs - .iter() - .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) - .collect(), - ) - } - - /// Projects `expr` according to the given projection mapping. - /// If the resulting expression is invalid after projection, returns `None`. - fn project_expr( - &self, - mapping: &ProjectionMapping, - expr: &Arc, - ) -> Option> { - // First, we try to project expressions with an exact match. If we are - // unable to do this, we consult equivalence classes. - if let Some(target) = mapping.target_expr(expr) { - // If we match the source, we can project directly: - return Some(target); - } else { - // If the given expression is not inside the mapping, try to project - // expressions considering the equivalence classes. - for (source, target) in mapping.iter() { - // If we match an equivalent expression to `source`, then we can - // project. For example, if we have the mapping `(a as a1, a + c)` - // and the equivalence class `(a, b)`, expression `b` projects to `a1`. - if self - .get_equivalence_class(source) - .map_or(false, |group| group.contains(expr)) - { - return Some(target.clone()); - } - } - } - // Project a non-leaf expression by projecting its children. - let children = expr.children(); - if children.is_empty() { - // Leaf expression should be inside mapping. - return None; - } - children - .into_iter() - .map(|child| self.project_expr(mapping, &child)) - .collect::>>() - .map(|children| expr.clone().with_new_children(children).unwrap()) - } - - /// Projects this equivalence group according to the given projection mapping. - pub fn project(&self, mapping: &ProjectionMapping) -> Self { - let projected_classes = self.iter().filter_map(|cls| { - let new_class = cls - .iter() - .filter_map(|expr| self.project_expr(mapping, expr)) - .collect::>(); - (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) - }); - // TODO: Convert the algorithm below to a version that uses `HashMap`. - // once `Arc` can be stored in `HashMap`. - // See issue: https://github.com/apache/arrow-datafusion/issues/8027 - let mut new_classes = vec![]; - for (source, target) in mapping.iter() { - if new_classes.is_empty() { - new_classes.push((source, vec![target.clone()])); - } - if let Some((_, values)) = - new_classes.iter_mut().find(|(key, _)| key.eq(source)) - { - if !physical_exprs_contains(values, target) { - values.push(target.clone()); - } - } - } - // Only add equivalence classes with at least two members as singleton - // equivalence classes are meaningless. - let new_classes = new_classes - .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)) - .map(EquivalenceClass::new); - - let classes = projected_classes.chain(new_classes).collect(); - Self::new(classes) - } - - /// Returns the equivalence class containing `expr`. If no equivalence class - /// contains `expr`, returns `None`. - fn get_equivalence_class( - &self, - expr: &Arc, - ) -> Option<&EquivalenceClass> { - self.iter().find(|cls| cls.contains(expr)) - } - - /// Combine equivalence groups of the given join children. - pub fn join( - &self, - right_equivalences: &Self, - join_type: &JoinType, - left_size: usize, - on: &[(Column, Column)], - ) -> Self { - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let mut result = Self::new( - self.iter() - .cloned() - .chain( - right_equivalences - .iter() - .map(|cls| cls.with_offset(left_size)), - ) - .collect(), - ); - // In we have an inner join, expressions in the "on" condition - // are equal in the resulting table. - if join_type == &JoinType::Inner { - for (lhs, rhs) in on.iter() { - let index = rhs.index() + left_size; - let new_lhs = Arc::new(lhs.clone()) as _; - let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; - result.add_equal_conditions(&new_lhs, &new_rhs); - } - } - result - } - JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), - JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), - } - } -} - -/// This function constructs a duplicate-free `LexOrderingReq` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. -pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { - let mut output = Vec::::new(); - for item in input { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } - } - output -} - -/// This function constructs a duplicate-free `LexOrdering` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. -pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { - let mut output = Vec::::new(); - for item in input { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } - } - output -} - -/// An `OrderingEquivalenceClass` object keeps track of different alternative -/// orderings than can describe a schema. For example, consider the following table: -/// -/// ```text -/// |a|b|c|d| -/// |1|4|3|1| -/// |2|3|3|2| -/// |3|1|2|2| -/// |3|2|1|3| -/// ``` -/// -/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table -/// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct OrderingEquivalenceClass { - orderings: Vec, -} - -impl OrderingEquivalenceClass { - /// Creates new empty ordering equivalence class. - fn empty() -> Self { - Self { orderings: vec![] } - } - - /// Clears (empties) this ordering equivalence class. - pub fn clear(&mut self) { - self.orderings.clear(); - } - - /// Creates new ordering equivalence class from the given orderings. - pub fn new(orderings: Vec) -> Self { - let mut result = Self { orderings }; - result.remove_redundant_entries(); - result - } - - /// Checks whether `ordering` is a member of this equivalence class. - pub fn contains(&self, ordering: &LexOrdering) -> bool { - self.orderings.contains(ordering) - } - - /// Adds `ordering` to this equivalence class. - #[allow(dead_code)] - fn push(&mut self, ordering: LexOrdering) { - self.orderings.push(ordering); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Checks whether this ordering equivalence class is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalent orderings in this class. - pub fn iter(&self) -> impl Iterator { - self.orderings.iter() - } - - /// Returns how many equivalent orderings there are in this class. - pub fn len(&self) -> usize { - self.orderings.len() - } - - /// Extend this ordering equivalence class with the `other` class. - pub fn extend(&mut self, other: Self) { - self.orderings.extend(other.orderings); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Adds new orderings into this ordering equivalence class. - pub fn add_new_orderings( - &mut self, - orderings: impl IntoIterator, - ) { - self.orderings.extend(orderings); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Removes redundant orderings from this equivalence class. For instance, - /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is - /// no need to keep ordering `[a ASC, b ASC]` in the state. - fn remove_redundant_entries(&mut self) { - let mut work = true; - while work { - work = false; - let mut idx = 0; - while idx < self.orderings.len() { - let mut ordering_idx = idx + 1; - let mut removal = self.orderings[idx].is_empty(); - while ordering_idx < self.orderings.len() { - work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); - if self.orderings[idx].is_empty() { - removal = true; - break; - } - work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); - if self.orderings[ordering_idx].is_empty() { - self.orderings.swap_remove(ordering_idx); - } else { - ordering_idx += 1; - } - } - if removal { - self.orderings.swap_remove(idx); - } else { - idx += 1; - } - } - } - } - - /// Returns the concatenation of all the orderings. This enables merge - /// operations to preserve all equivalent orderings simultaneously. - pub fn output_ordering(&self) -> Option { - let output_ordering = self.orderings.iter().flatten().cloned().collect(); - let output_ordering = collapse_lex_ordering(output_ordering); - (!output_ordering.is_empty()).then_some(output_ordering) - } - - // Append orderings in `other` to all existing orderings in this equivalence - // class. - pub fn join_suffix(mut self, other: &Self) -> Self { - let n_ordering = self.orderings.len(); - // Replicate entries before cross product - let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); - self.orderings = self - .orderings - .iter() - .cloned() - .cycle() - .take(n_cross) - .collect(); - // Suffix orderings of other to the current orderings. - for (outer_idx, ordering) in other.iter().enumerate() { - for idx in 0..n_ordering { - // Calculate cross product index - let idx = outer_idx * n_ordering + idx; - self.orderings[idx].extend(ordering.iter().cloned()); - } - } - self - } - - /// Adds `offset` value to the index of each expression inside this - /// ordering equivalence class. - pub fn add_offset(&mut self, offset: usize) { - for ordering in self.orderings.iter_mut() { - for sort_expr in ordering { - sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); - } - } - } - - /// Gets sort options associated with this expression if it is a leading - /// ordering expression. Otherwise, returns `None`. - fn get_options(&self, expr: &Arc) -> Option { - for ordering in self.iter() { - let leading_ordering = &ordering[0]; - if leading_ordering.expr.eq(expr) { - return Some(leading_ordering.options); - } - } - None - } -} - -/// Adds the `offset` value to `Column` indices inside `expr`. This function is -/// generally used during the update of the right table schema in join operations. -pub fn add_offset_to_expr( - expr: Arc, - offset: usize, -) -> Arc { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( - col.name(), - offset + col.index(), - )))), - None => Ok(Transformed::No(e)), - }) - .unwrap() - // Note that we can safely unwrap here since our transform always returns - // an `Ok` value. -} - -/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of -/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. -fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { - let length = orderings[idx].len(); - let other_length = orderings[pre_idx].len(); - for overlap in 1..=length.min(other_length) { - if orderings[idx][length - overlap..] == orderings[pre_idx][..overlap] { - orderings[idx].truncate(length - overlap); - return true; - } - } - false -} - -/// A `EquivalenceProperties` object stores useful information related to a schema. -/// Currently, it keeps track of: -/// - Equivalent expressions, e.g expressions that have same value. -/// - Valid sort expressions (orderings) for the schema. -/// - Constants expressions (e.g expressions that are known to have constant values). -/// -/// Consider table below: -/// -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 9 | -/// | 2 | 8 | -/// | 3 | 7 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// -/// where both `a ASC` and `b DESC` can describe the table ordering. With -/// `EquivalenceProperties`, we can keep track of these different valid sort -/// expressions and treat `a ASC` and `b DESC` on an equal footing. -/// -/// Similarly, consider the table below: -/// -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 1 | -/// | 2 | 2 | -/// | 3 | 3 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// -/// where columns `a` and `b` always have the same value. We keep track of such -/// equivalences inside this object. With this information, we can optimize -/// things like partitioning. For example, if the partition requirement is -/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that -/// the existing partitioning satisfies the requirement. -#[derive(Debug, Clone)] -pub struct EquivalenceProperties { - /// Collection of equivalence classes that store expressions with the same - /// value. - eq_group: EquivalenceGroup, - /// Equivalent sort expressions for this table. - oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant throughout the table. - /// TODO: We do not need to track constants separately, they can be tracked - /// inside `eq_groups` as `Literal` expressions. - constants: Vec>, - /// Schema associated with this object. - schema: SchemaRef, -} - -impl EquivalenceProperties { - /// Creates an empty `EquivalenceProperties` object. - pub fn new(schema: SchemaRef) -> Self { - Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::empty(), - constants: vec![], - schema, - } - } - - /// Creates a new `EquivalenceProperties` object with the given orderings. - pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { - Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), - constants: vec![], - schema, - } - } - - /// Returns the associated schema. - pub fn schema(&self) -> &SchemaRef { - &self.schema - } - - /// Returns a reference to the ordering equivalence class within. - pub fn oeq_class(&self) -> &OrderingEquivalenceClass { - &self.oeq_class - } - - /// Returns a reference to the equivalence group within. - pub fn eq_group(&self) -> &EquivalenceGroup { - &self.eq_group - } - - /// Returns a reference to the constant expressions - pub fn constants(&self) -> &[Arc] { - &self.constants - } - - /// Returns the normalized version of the ordering equivalence class within. - /// Normalization removes constants and duplicates as well as standardizing - /// expressions according to the equivalence group within. - pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { - OrderingEquivalenceClass::new( - self.oeq_class - .iter() - .map(|ordering| self.normalize_sort_exprs(ordering)) - .collect(), - ) - } - - /// Extends this `EquivalenceProperties` with the `other` object. - pub fn extend(mut self, other: Self) -> Self { - self.eq_group.extend(other.eq_group); - self.oeq_class.extend(other.oeq_class); - self.add_constants(other.constants) - } - - /// Clears (empties) the ordering equivalence class within this object. - /// Call this method when existing orderings are invalidated. - pub fn clear_orderings(&mut self) { - self.oeq_class.clear(); - } - - /// Extends this `EquivalenceProperties` by adding the orderings inside the - /// ordering equivalence class `other`. - pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { - self.oeq_class.extend(other); - } - - /// Adds new orderings into the existing ordering equivalence class. - pub fn add_new_orderings( - &mut self, - orderings: impl IntoIterator, - ) { - self.oeq_class.add_new_orderings(orderings); - } - - /// Incorporates the given equivalence group to into the existing - /// equivalence group within. - pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { - self.eq_group.extend(other_eq_group); - } - - /// Adds a new equality condition into the existing equivalence group. - /// If the given equality defines a new equivalence class, adds this new - /// equivalence class to the equivalence group. - pub fn add_equal_conditions( - &mut self, - left: &Arc, - right: &Arc, - ) { - self.eq_group.add_equal_conditions(left, right); - } - - /// Track/register physical expressions with constant values. - pub fn add_constants( - mut self, - constants: impl IntoIterator>, - ) -> Self { - for expr in self.eq_group.normalize_exprs(constants) { - if !physical_exprs_contains(&self.constants, &expr) { - self.constants.push(expr); - } - } - self - } - - /// Updates the ordering equivalence group within assuming that the table - /// is re-sorted according to the argument `sort_exprs`. Note that constants - /// and equivalence classes are unchanged as they are unaffected by a re-sort. - pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { - // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. - self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); - self - } - - /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the - /// equivalence group and the ordering equivalence class within. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) - } - - /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the - /// equivalence group and the ordering equivalence class within. It works by: - /// - Removing expressions that have a constant value from the given requirement. - /// - Replacing sections that belong to some equivalence class in the equivalence - /// group with the first entry in the matching equivalence class. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_requirements( - &self, - sort_reqs: LexRequirementRef, - ) -> LexRequirement { - let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); - // Prune redundant sections in the requirement: - collapse_lex_req( - normalized_sort_reqs - .iter() - .filter(|&order| { - !physical_exprs_contains(&constants_normalized, &order.expr) - }) - .cloned() - .collect(), - ) - } - - /// Checks whether the given ordering is satisfied by any of the existing - /// orderings. - pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { - // Convert the given sort expressions to sort requirements: - let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); - self.ordering_satisfy_requirement(&sort_requirements) - } - - /// Checks whether the given sort requirements are satisfied by any of the - /// existing orderings. - pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { - let mut eq_properties = self.clone(); - // First, standardize the given requirement: - let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); - for normalized_req in normalized_reqs { - // Check whether given ordering is satisfied - if !eq_properties.ordering_satisfy_single(&normalized_req) { - return false; - } - // Treat satisfied keys as constants in subsequent iterations. We - // can do this because the "next" key only matters in a lexicographical - // ordering when the keys to its left have the same values. - // - // Note that these expressions are not properly "constants". This is just - // an implementation strategy confined to this function. - // - // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, - // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. - // From the analysis above, we know that `[a ASC]` is satisfied. Then, - // we add column `a` as constant to the algorithm state. This enables us - // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. - eq_properties = - eq_properties.add_constants(std::iter::once(normalized_req.expr)); - } - true - } - - /// Determines whether the ordering specified by the given sort requirement - /// is satisfied based on the orderings within, equivalence classes, and - /// constant expressions. - /// - /// # Arguments - /// - /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering - /// satisfaction check will be done. - /// - /// # Returns - /// - /// Returns `true` if the specified ordering is satisfied, `false` otherwise. - fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { - let expr_ordering = self.get_expr_ordering(req.expr.clone()); - let ExprOrdering { expr, state, .. } = expr_ordering; - match state { - SortProperties::Ordered(options) => { - let sort_expr = PhysicalSortExpr { expr, options }; - sort_expr.satisfy(req, self.schema()) - } - // Singleton expressions satisfies any ordering. - SortProperties::Singleton => true, - SortProperties::Unordered => false, - } - } - - /// Checks whether the `given`` sort requirements are equal or more specific - /// than the `reference` sort requirements. - pub fn requirements_compatible( - &self, - given: LexRequirementRef, - reference: LexRequirementRef, - ) -> bool { - let normalized_given = self.normalize_sort_requirements(given); - let normalized_reference = self.normalize_sort_requirements(reference); - - (normalized_reference.len() <= normalized_given.len()) - && normalized_reference - .into_iter() - .zip(normalized_given) - .all(|(reference, given)| given.compatible(&reference)) - } - - /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking - /// any ties by choosing `lhs`. - /// - /// The finer ordering is the ordering that satisfies both of the orderings. - /// If the orderings are incomparable, returns `None`. - /// - /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is - /// the latter. - pub fn get_finer_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - // Convert the given sort expressions to sort requirements: - let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); - let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); - let finer = self.get_finer_requirement(&lhs, &rhs); - // Convert the chosen sort requirements back to sort expressions: - finer.map(PhysicalSortRequirement::to_sort_exprs) - } - - /// Returns the finer ordering among the requirements `lhs` and `rhs`, - /// breaking any ties by choosing `lhs`. - /// - /// The finer requirements are the ones that satisfy both of the given - /// requirements. If the requirements are incomparable, returns `None`. - /// - /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` - /// is the latter. - pub fn get_finer_requirement( - &self, - req1: LexRequirementRef, - req2: LexRequirementRef, - ) -> Option { - let mut lhs = self.normalize_sort_requirements(req1); - let mut rhs = self.normalize_sort_requirements(req2); - lhs.iter_mut() - .zip(rhs.iter_mut()) - .all(|(lhs, rhs)| { - lhs.expr.eq(&rhs.expr) - && match (lhs.options, rhs.options) { - (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, - (Some(options), None) => { - rhs.options = Some(options); - true - } - (None, Some(options)) => { - lhs.options = Some(options); - true - } - (None, None) => true, - } - }) - .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) - } - - /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). - /// The meet of a set of orderings is the finest ordering that is satisfied - /// by all the orderings in that set. For details, see: - /// - /// - /// - /// If there is no ordering that satisfies both `lhs` and `rhs`, returns - /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` - /// is `[a ASC]`. - pub fn get_meet_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - let lhs = self.normalize_sort_exprs(lhs); - let rhs = self.normalize_sort_exprs(rhs); - let mut meet = vec![]; - for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { - if lhs.eq(&rhs) { - meet.push(lhs); - } else { - break; - } - } - (!meet.is_empty()).then_some(meet) - } - - /// Projects argument `expr` according to `projection_mapping`, taking - /// equivalences into account. - /// - /// For example, assume that columns `a` and `c` are always equal, and that - /// `projection_mapping` encodes following mapping: - /// - /// ```text - /// a -> a1 - /// b -> b1 - /// ``` - /// - /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to - /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. - pub fn project_expr( - &self, - expr: &Arc, - projection_mapping: &ProjectionMapping, - ) -> Option> { - self.eq_group.project_expr(projection_mapping, expr) - } - - /// Constructs a dependency map based on existing orderings referred to in - /// the projection. - /// - /// This function analyzes the orderings in the normalized order-equivalence - /// class and builds a dependency map. The dependency map captures relationships - /// between expressions within the orderings, helping to identify dependencies - /// and construct valid projected orderings during projection operations. - /// - /// # Parameters - /// - /// - `mapping`: A reference to the `ProjectionMapping` that defines the - /// relationship between source and target expressions. - /// - /// # Returns - /// - /// A [`DependencyMap`] representing the dependency map, where each - /// [`DependencyNode`] contains dependencies for the key [`PhysicalSortExpr`]. - /// - /// # Example - /// - /// Assume we have two equivalent orderings: `[a ASC, b ASC]` and `[a ASC, c ASC]`, - /// and the projection mapping is `[a -> a_new, b -> b_new, b + c -> b + c]`. - /// Then, the dependency map will be: - /// - /// ```text - /// a ASC: Node {Some(a_new ASC), HashSet{}} - /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} - /// c ASC: Node {None, HashSet{a ASC}} - /// ``` - fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = HashMap::new(); - for ordering in self.normalized_oeq_class().iter() { - for (idx, sort_expr) in ordering.iter().enumerate() { - let target_sort_expr = - self.project_expr(&sort_expr.expr, mapping).map(|expr| { - PhysicalSortExpr { - expr, - options: sort_expr.options, - } - }); - let is_projected = target_sort_expr.is_some(); - if is_projected - || mapping - .iter() - .any(|(source, _)| expr_refers(source, &sort_expr.expr)) - { - // Previous ordering is a dependency. Note that there is no, - // dependency for a leading ordering (i.e. the first sort - // expression). - let dependency = idx.checked_sub(1).map(|a| &ordering[a]); - // Add sort expressions that can be projected or referred to - // by any of the projection expressions to the dependency map: - dependency_map - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.clone(), - dependencies: HashSet::new(), - }) - .insert_dependency(dependency); - } - if !is_projected { - // If we can not project, stop constructing the dependency - // map as remaining dependencies will be invalid after projection. - break; - } - } - } - dependency_map - } - - /// Returns a new `ProjectionMapping` where source expressions are normalized. - /// - /// This normalization ensures that source expressions are transformed into a - /// consistent representation. This is beneficial for algorithms that rely on - /// exact equalities, as it allows for more precise and reliable comparisons. - /// - /// # Parameters - /// - /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. - /// - /// # Returns - /// - /// A new `ProjectionMapping` with normalized source expressions. - fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { - // Construct the mapping where source expressions are normalized. In this way - // In the algorithms below we can work on exact equalities - ProjectionMapping { - map: mapping - .iter() - .map(|(source, target)| { - let normalized_source = self.eq_group.normalize_expr(source.clone()); - (normalized_source, target.clone()) - }) - .collect(), - } - } - - /// Computes projected orderings based on a given projection mapping. - /// - /// This function takes a `ProjectionMapping` and computes the possible - /// orderings for the projected expressions. It considers dependencies - /// between expressions and generates valid orderings according to the - /// specified sort properties. - /// - /// # Parameters - /// - /// - `mapping`: A reference to the `ProjectionMapping` that defines the - /// relationship between source and target expressions. - /// - /// # Returns - /// - /// A vector of `LexOrdering` containing all valid orderings after projection. - fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { - let mapping = self.normalized_mapping(mapping); - - // Get dependency map for existing orderings: - let dependency_map = self.construct_dependency_map(&mapping); - - let orderings = mapping.iter().flat_map(|(source, target)| { - referred_dependencies(&dependency_map, source) - .into_iter() - .filter_map(|relevant_deps| { - if let SortProperties::Ordered(options) = - get_expr_ordering(source, &relevant_deps) - { - Some((options, relevant_deps)) - } else { - // Do not consider unordered cases - None - } - }) - .flat_map(|(options, relevant_deps)| { - let sort_expr = PhysicalSortExpr { - expr: target.clone(), - options, - }; - // Generate dependent orderings (i.e. prefixes for `sort_expr`): - let mut dependency_orderings = - generate_dependency_orderings(&relevant_deps, &dependency_map); - // Append `sort_expr` to the dependent orderings: - for ordering in dependency_orderings.iter_mut() { - ordering.push(sort_expr.clone()); - } - dependency_orderings - }) - }); - - // Add valid projected orderings. For example, if existing ordering is - // `a + b` and projection is `[a -> a_new, b -> b_new]`, we need to - // preserve `a_new + b_new` as ordered. Please note that `a_new` and - // `b_new` themselves need not be ordered. Such dependencies cannot be - // deduced via the pass above. - let projected_orderings = dependency_map.iter().flat_map(|(sort_expr, node)| { - let mut prefixes = construct_prefix_orderings(sort_expr, &dependency_map); - if prefixes.is_empty() { - // If prefix is empty, there is no dependency. Insert - // empty ordering: - prefixes = vec![vec![]]; - } - // Append current ordering on top its dependencies: - for ordering in prefixes.iter_mut() { - if let Some(target) = &node.target_sort_expr { - ordering.push(target.clone()) - } - } - prefixes - }); - - // Simplify each ordering by removing redundant sections: - orderings - .chain(projected_orderings) - .map(collapse_lex_ordering) - .collect() - } - - /// Projects constants based on the provided `ProjectionMapping`. - /// - /// This function takes a `ProjectionMapping` and identifies/projects - /// constants based on the existing constants and the mapping. It ensures - /// that constants are appropriately propagated through the projection. - /// - /// # Arguments - /// - /// - `mapping`: A reference to a `ProjectionMapping` representing the - /// mapping of source expressions to target expressions in the projection. - /// - /// # Returns - /// - /// Returns a `Vec>` containing the projected constants. - fn projected_constants( - &self, - mapping: &ProjectionMapping, - ) -> Vec> { - // First, project existing constants. For example, assume that `a + b` - // is known to be constant. If the projection were `a as a_new`, `b as b_new`, - // then we would project constant `a + b` as `a_new + b_new`. - let mut projected_constants = self - .constants - .iter() - .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) - .collect::>(); - // Add projection expressions that are known to be constant: - for (source, target) in mapping.iter() { - if self.is_expr_constant(source) - && !physical_exprs_contains(&projected_constants, target) - { - projected_constants.push(target.clone()); - } - } - projected_constants - } - - /// Projects the equivalences within according to `projection_mapping` - /// and `output_schema`. - pub fn project( - &self, - projection_mapping: &ProjectionMapping, - output_schema: SchemaRef, - ) -> Self { - let projected_constants = self.projected_constants(projection_mapping); - let projected_eq_group = self.eq_group.project(projection_mapping); - let projected_orderings = self.projected_orderings(projection_mapping); - Self { - eq_group: projected_eq_group, - oeq_class: OrderingEquivalenceClass::new(projected_orderings), - constants: projected_constants, - schema: output_schema, - } - } - - /// Returns the longest (potentially partial) permutation satisfying the - /// existing ordering. For example, if we have the equivalent orderings - /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, - /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. - /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied - /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` - /// inside the argument `exprs` (respectively). For the mathematical - /// definition of "partial permutation", see: - /// - /// - pub fn find_longest_permutation( - &self, - exprs: &[Arc], - ) -> (LexOrdering, Vec) { - let mut eq_properties = self.clone(); - let mut result = vec![]; - // The algorithm is as follows: - // - Iterate over all the expressions and insert ordered expressions - // into the result. - // - Treat inserted expressions as constants (i.e. add them as constants - // to the state). - // - Continue the above procedure until no expression is inserted; i.e. - // the algorithm reaches a fixed point. - // This algorithm should reach a fixed point in at most `exprs.len()` - // iterations. - let mut search_indices = (0..exprs.len()).collect::>(); - for _idx in 0..exprs.len() { - // Get ordered expressions with their indices. - let ordered_exprs = search_indices - .iter() - .flat_map(|&idx| { - let ExprOrdering { expr, state, .. } = - eq_properties.get_expr_ordering(exprs[idx].clone()); - if let SortProperties::Ordered(options) = state { - Some((PhysicalSortExpr { expr, options }, idx)) - } else { - None - } - }) - .collect::>(); - // We reached a fixed point, exit. - if ordered_exprs.is_empty() { - break; - } - // Remove indices that have an ordering from `search_indices`, and - // treat ordered expressions as constants in subsequent iterations. - // We can do this because the "next" key only matters in a lexicographical - // ordering when the keys to its left have the same values. - // - // Note that these expressions are not properly "constants". This is just - // an implementation strategy confined to this function. - for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = - eq_properties.add_constants(std::iter::once(expr.clone())); - search_indices.remove(idx); - } - // Add new ordered section to the state. - result.extend(ordered_exprs); - } - result.into_iter().unzip() - } - - /// This function determines whether the provided expression is constant - /// based on the known constants. - /// - /// # Arguments - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant according to equivalence - /// group, `false` otherwise. - fn is_expr_constant(&self, expr: &Arc) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. - let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); - let normalized_expr = self.eq_group.normalize_expr(expr.clone()); - is_constant_recurse(&normalized_constants, &normalized_expr) - } - - /// Retrieves the ordering information for a given physical expression. - /// - /// This function constructs an `ExprOrdering` object for the provided - /// expression, which encapsulates information about the expression's - /// ordering, including its [`SortProperties`]. - /// - /// # Arguments - /// - /// - `expr`: An `Arc` representing the physical expression - /// for which ordering information is sought. - /// - /// # Returns - /// - /// Returns an `ExprOrdering` object containing the ordering information for - /// the given expression. - pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { - ExprOrdering::new(expr.clone()) - .transform_up(&|expr| Ok(update_ordering(expr, self))) - // Guaranteed to always return `Ok`. - .unwrap() - } -} - -/// This function determines whether the provided expression is constant -/// based on the known constants. -/// -/// # Arguments -/// -/// - `constants`: A `&[Arc]` containing expressions known to -/// be a constant. -/// - `expr`: A reference to a `Arc` representing the expression -/// to check. -/// -/// # Returns -/// -/// Returns `true` if the expression is constant according to equivalence -/// group, `false` otherwise. -fn is_constant_recurse( - constants: &[Arc], - expr: &Arc, -) -> bool { - if physical_exprs_contains(constants, expr) { - return true; - } - let children = expr.children(); - !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) -} - -/// This function examines whether a referring expression directly refers to a -/// given referred expression or if any of its children in the expression tree -/// refer to the specified expression. -/// -/// # Parameters -/// -/// - `referring_expr`: A reference to the referring expression (`Arc`). -/// - `referred_expr`: A reference to the referred expression (`Arc`) -/// -/// # Returns -/// -/// A boolean value indicating whether `referring_expr` refers (needs it to evaluate its result) -/// `referred_expr` or not. -fn expr_refers( - referring_expr: &Arc, - referred_expr: &Arc, -) -> bool { - referring_expr.eq(referred_expr) - || referring_expr - .children() - .iter() - .any(|child| expr_refers(child, referred_expr)) -} - -/// Wrapper struct for `Arc` to use them as keys in a hash map. -#[derive(Debug, Clone)] -struct ExprWrapper(Arc); - -impl PartialEq for ExprWrapper { - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} - -impl Eq for ExprWrapper {} - -impl Hash for ExprWrapper { - fn hash(&self, state: &mut H) { - self.0.hash(state); - } -} - -/// This function analyzes the dependency map to collect referred dependencies for -/// a given source expression. -/// -/// # Parameters -/// -/// - `dependency_map`: A reference to the `DependencyMap` where each -/// `PhysicalSortExpr` is associated with a `DependencyNode`. -/// - `source`: A reference to the source expression (`Arc`) -/// for which relevant dependencies need to be identified. -/// -/// # Returns -/// -/// A `Vec` containing the dependencies for the given source -/// expression. These dependencies are expressions that are referred to by -/// the source expression based on the provided dependency map. -fn referred_dependencies( - dependency_map: &DependencyMap, - source: &Arc, -) -> Vec { - // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: - let mut expr_to_sort_exprs = HashMap::::new(); - for sort_expr in dependency_map - .keys() - .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) - { - let key = ExprWrapper(sort_expr.expr.clone()); - expr_to_sort_exprs - .entry(key) - .or_default() - .insert(sort_expr.clone()); - } - - // Generate all valid dependencies for the source. For example, if the source - // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get - // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. - expr_to_sort_exprs - .values() - .multi_cartesian_product() - .map(|referred_deps| referred_deps.into_iter().cloned().collect()) - .collect() -} - -/// This function recursively analyzes the dependencies of the given sort -/// expression within the given dependency map to construct lexicographical -/// orderings that include the sort expression and its dependencies. -/// -/// # Parameters -/// -/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) -/// for which lexicographical orderings satisfying its dependencies are to be -/// constructed. -/// - `dependency_map`: A reference to the `DependencyMap` that contains -/// dependencies for different `PhysicalSortExpr`s. -/// -/// # Returns -/// -/// A vector of lexicographical orderings (`Vec`) based on the given -/// sort expression and its dependencies. -fn construct_orderings( - referred_sort_expr: &PhysicalSortExpr, - dependency_map: &DependencyMap, -) -> Vec { - // We are sure that `referred_sort_expr` is inside `dependency_map`. - let node = &dependency_map[referred_sort_expr]; - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. - let target_sort_expr = node.target_sort_expr.clone().unwrap(); - if node.dependencies.is_empty() { - vec![vec![target_sort_expr]] - } else { - node.dependencies - .iter() - .flat_map(|dep| { - let mut orderings = construct_orderings(dep, dependency_map); - for ordering in orderings.iter_mut() { - ordering.push(target_sort_expr.clone()) - } - orderings - }) - .collect() - } -} - -/// This function retrieves the dependencies of the given relevant sort expression -/// from the given dependency map. It then constructs prefix orderings by recursively -/// analyzing the dependencies and include them in the orderings. -/// -/// # Parameters -/// -/// - `relevant_sort_expr`: A reference to the relevant sort expression -/// (`PhysicalSortExpr`) for which prefix orderings are to be constructed. -/// - `dependency_map`: A reference to the `DependencyMap` containing dependencies. -/// -/// # Returns -/// -/// A vector of prefix orderings (`Vec`) based on the given relevant -/// sort expression and its dependencies. -fn construct_prefix_orderings( - relevant_sort_expr: &PhysicalSortExpr, - dependency_map: &DependencyMap, -) -> Vec { - dependency_map[relevant_sort_expr] - .dependencies - .iter() - .flat_map(|dep| construct_orderings(dep, dependency_map)) - .collect() -} - -/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies -/// (`dependency_map`), this function generates all possible prefix orderings -/// based on the given dependencies. -/// -/// # Parameters -/// -/// * `dependencies` - A reference to the dependencies. -/// * `dependency_map` - A reference to the map of dependencies for expressions. -/// -/// # Returns -/// -/// A vector of lexical orderings (`Vec`) representing all valid orderings -/// based on the given dependencies. -fn generate_dependency_orderings( - dependencies: &Dependencies, - dependency_map: &DependencyMap, -) -> Vec { - // Construct all the valid prefix orderings for each expression appearing - // in the projection: - let relevant_prefixes = dependencies - .iter() - .flat_map(|dep| { - let prefixes = construct_prefix_orderings(dep, dependency_map); - (!prefixes.is_empty()).then_some(prefixes) - }) - .collect::>(); - - // No dependency, dependent is a leading ordering. - if relevant_prefixes.is_empty() { - // Return an empty ordering: - return vec![vec![]]; - } - - // Generate all possible orderings where dependencies are satisfied for the - // current projection expression. For example, if expression is `a + b ASC`, - // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` - // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and - // `[d DESC, c ASC, a + b ASC]`. - relevant_prefixes - .into_iter() - .multi_cartesian_product() - .flat_map(|prefix_orderings| { - prefix_orderings - .iter() - .permutations(prefix_orderings.len()) - .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) - .collect::>() - }) - .collect() -} - -/// This function examines the given expression and the sort expressions it -/// refers to determine the ordering properties of the expression. -/// -/// # Parameters -/// -/// - `expr`: A reference to the source expression (`Arc`) for -/// which ordering properties need to be determined. -/// - `dependencies`: A reference to `Dependencies`, containing sort expressions -/// referred to by `expr`. -/// -/// # Returns -/// -/// A `SortProperties` indicating the ordering information of the given expression. -fn get_expr_ordering( - expr: &Arc, - dependencies: &Dependencies, -) -> SortProperties { - if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { - // If exact match is found, return its ordering. - SortProperties::Ordered(column_order.options) - } else { - // Find orderings of its children - let child_states = expr - .children() - .iter() - .map(|child| get_expr_ordering(child, dependencies)) - .collect::>(); - // Calculate expression ordering using ordering of its children. - expr.get_ordering(&child_states) - } -} - -/// Represents a node in the dependency map used to construct projected orderings. -/// -/// A `DependencyNode` contains information about a particular sort expression, -/// including its target sort expression and a set of dependencies on other sort -/// expressions. -/// -/// # Fields -/// -/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target -/// sort expression associated with the node. It is `None` if the sort expression -/// cannot be projected. -/// - `dependencies`: A [`Dependencies`] containing dependencies on other sort -/// expressions that are referred to by the target sort expression. -#[derive(Debug, Clone, PartialEq, Eq)] -struct DependencyNode { - target_sort_expr: Option, - dependencies: Dependencies, -} - -impl DependencyNode { - // Insert dependency to the state (if exists). - fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { - if let Some(dep) = dependency { - self.dependencies.insert(dep.clone()); - } - } -} - -type DependencyMap = HashMap; -type Dependencies = HashSet; - -/// Calculate ordering equivalence properties for the given join operation. -pub fn join_equivalence_properties( - left: EquivalenceProperties, - right: EquivalenceProperties, - join_type: &JoinType, - join_schema: SchemaRef, - maintains_input_order: &[bool], - probe_side: Option, - on: &[(Column, Column)], -) -> EquivalenceProperties { - let left_size = left.schema.fields.len(); - let mut result = EquivalenceProperties::new(join_schema); - result.add_equivalence_group(left.eq_group().join( - right.eq_group(), - join_type, - left_size, - on, - )); - - let left_oeq_class = left.oeq_class; - let mut right_oeq_class = right.oeq_class; - match maintains_input_order { - [true, false] => { - // In this special case, right side ordering can be prefixed with - // the left side ordering. - if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - join_type, - left_size, - ); - - // Right side ordering equivalence properties should be prepended - // with those of the left side while constructing output ordering - // equivalence properties since stream side is the left side. - // - // For example, if the right side ordering equivalences contain - // `b ASC`, and the left side ordering equivalences contain `a ASC`, - // then we should add `a ASC, b ASC` to the ordering equivalences - // of the join output. - let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); - } else { - result.add_ordering_equivalence_class(left_oeq_class); - } - } - [false, true] => { - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - join_type, - left_size, - ); - // In this special case, left side ordering can be prefixed with - // the right side ordering. - if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { - // Left side ordering equivalence properties should be prepended - // with those of the right side while constructing output ordering - // equivalence properties since stream side is the right side. - // - // For example, if the left side ordering equivalences contain - // `a ASC`, and the right side ordering equivalences contain `b ASC`, - // then we should add `b ASC, a ASC` to the ordering equivalences - // of the join output. - let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); - } else { - result.add_ordering_equivalence_class(right_oeq_class); - } - } - [false, false] => {} - [true, true] => unreachable!("Cannot maintain ordering of both sides"), - _ => unreachable!("Join operators can not have more than two children"), - } - result -} - -/// In the context of a join, update the right side `OrderingEquivalenceClass` -/// so that they point to valid indices in the join output schema. -/// -/// To do so, we increment column indices by the size of the left table when -/// join schema consists of a combination of the left and right schemas. This -/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, -/// indices do not change. -fn updated_right_ordering_equivalence_class( - right_oeq_class: &mut OrderingEquivalenceClass, - join_type: &JoinType, - left_size: usize, -) { - if matches!( - join_type, - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right - ) { - right_oeq_class.add_offset(left_size); - } -} - -/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. -/// The node can either be a leaf node, or an intermediate node: -/// - If it is a leaf node, we directly find the order of the node by looking -/// at the given sort expression and equivalence properties if it is a `Column` -/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark -/// it as singleton so that it can cooperate with all ordered columns. -/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` -/// and operator has its own rules on how to propagate the children orderings. -/// However, before we engage in recursion, we check whether this intermediate -/// node directly matches with the sort expression. If there is a match, the -/// sort expression emerges at that node immediately, discarding the recursive -/// result coming from its children. -fn update_ordering( - mut node: ExprOrdering, - eq_properties: &EquivalenceProperties, -) -> Transformed { - // We have a Column, which is one of the two possible leaf node types: - let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); - if eq_properties.is_expr_constant(&normalized_expr) { - node.state = SortProperties::Singleton; - } else if let Some(options) = eq_properties - .normalized_oeq_class() - .get_options(&normalized_expr) - { - node.state = SortProperties::Ordered(options); - } else if !node.expr.children().is_empty() { - // We have an intermediate (non-leaf) node, account for its children: - node.state = node.expr.get_ordering(&node.children_state()); - } else if node.expr.as_any().is::() { - // We have a Literal, which is the other possible leaf node type: - node.state = node.expr.get_ordering(&[]); - } else { - return Transformed::No(node); - } - Transformed::Yes(node) -} - -#[cfg(test)] -mod tests { - use std::ops::Not; - use std::sync::Arc; - - use super::*; - use crate::execution_props::ExecutionProps; - use crate::expressions::{col, lit, BinaryExpr, Column, Literal}; - use crate::functions::create_physical_expr; - - use arrow::compute::{lexsort_to_indices, SortColumn}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; - use arrow_schema::{Fields, SortOptions, TimeUnit}; - use datafusion_common::{plan_datafusion_err, DataFusionError, Result, ScalarValue}; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - - use itertools::{izip, Itertools}; - use rand::rngs::StdRng; - use rand::seq::SliceRandom; - use rand::{Rng, SeedableRng}; - - fn output_schema( - mapping: &ProjectionMapping, - input_schema: &Arc, - ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); - - let output_schema = Arc::new(Schema::new_with_metadata( - fields?, - input_schema.metadata().clone(), - )); - - Ok(output_schema) - } - - // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) - fn create_test_schema() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let f = Field::new("f", DataType::Int32, true); - let g = Field::new("g", DataType::Int32, true); - let h = Field::new("h", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); - - Ok(schema) - } - - /// Construct a schema with following properties - /// Schema satisfies following orderings: - /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - /// and - /// Column [a=c] (e.g they are aliases). - fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions(col_a, col_c); - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let orderings = vec![ - // [a ASC] - vec![(col_a, option_asc)], - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [e DESC, f ASC, g ASC] - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - ]; - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - Ok((test_schema, eq_properties)) - } - - // Generate a schema which consists of 6 columns (a, b, c, d, e, f) - fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Float64, true); - let b = Field::new("b", DataType::Float64, true); - let c = Field::new("c", DataType::Float64, true); - let d = Field::new("d", DataType::Float64, true); - let e = Field::new("e", DataType::Float64, true); - let f = Field::new("f", DataType::Float64, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - /// Construct a schema with random ordering - /// among column a, b, c, d - /// where - /// Column [a=f] (e.g they are aliases). - /// Column e is constant. - fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema_2()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f); - // Column e has constant value. - eq_properties = eq_properties.add_constants([col_e.clone()]); - - // Randomly order columns for sorting - let mut rng = StdRng::seed_from_u64(seed); - let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted - - let options_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); - remaining_exprs.shuffle(&mut rng); - - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); - - eq_properties.add_new_orderings([ordering]); - } - - Ok((test_schema, eq_properties)) - } - - // Convert each tuple to PhysicalSortRequirement - fn convert_to_sort_reqs( - in_data: &[(&Arc, Option)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| { - PhysicalSortRequirement::new((*expr).clone(), *options) - }) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), - options: *options, - }) - .collect() - } - - // Convert each inner tuple to PhysicalSortExpr - fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec> { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs_owned( - in_data: &[(Arc, SortOptions)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), - options: *options, - }) - .collect() - } - - // Convert each inner tuple to PhysicalSortExpr - fn convert_to_orderings_owned( - orderings: &[Vec<(Arc, SortOptions)>], - ) -> Vec> { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) - .collect() - } - - // Apply projection to the input_data, return projected equivalence properties and record batch - fn apply_projection( - proj_exprs: Vec<(Arc, String)>, - input_data: &RecordBatch, - input_eq_properties: &EquivalenceProperties, - ) -> Result<(RecordBatch, EquivalenceProperties)> { - let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let output_schema = output_schema(&projection_mapping, &input_schema)?; - let num_rows = input_data.num_rows(); - // Apply projection to the input record batch. - let projected_values = projection_mapping - .iter() - .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) - .collect::>>()?; - let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(output_schema.clone()) - } else { - RecordBatch::try_new(output_schema.clone(), projected_values)? - }; - - let projected_eq = - input_eq_properties.project(&projection_mapping, output_schema); - Ok((projected_batch, projected_eq)) - } - - #[test] - fn add_equal_conditions_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("x", DataType::Int64, true), - Field::new("y", DataType::Int64, true), - ])); - - let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; - - // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - - // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - - // b and c are aliases. Exising equivalence class should expand, - // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - - // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); - assert_eq!(eq_properties.eq_group().len(), 2); - - // This equality bridges distinct equality sets. - // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); - - Ok(()) - } - - #[test] - fn project_equivalence_properties_test() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ])); - - let input_properties = EquivalenceProperties::new(input_schema.clone()); - let col_a = col("a", &input_schema)?; - - // a as a1, a as a2, a as a3, a as a3 - let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), - ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let out_schema = output_schema(&projection_mapping, &input_schema)?; - // a as a1, a as a2, a as a3, a as a3 - let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), - ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - // a as a1, a as a2, a as a3, a as a3 - let col_a1 = &col("a1", &out_schema)?; - let col_a2 = &col("a2", &out_schema)?; - let col_a3 = &col("a3", &out_schema)?; - let col_a4 = &col("a4", &out_schema)?; - let out_properties = input_properties.project(&projection_mapping, out_schema); - - // At the output a1=a2=a3=a4 - assert_eq!(out_properties.eq_group().len(), 1); - let eq_class = &out_properties.eq_group().classes[0]; - assert_eq!(eq_class.len(), 4); - assert!(eq_class.contains(col_a1)); - assert!(eq_class.contains(col_a2)); - assert!(eq_class.contains(col_a3)); - assert!(eq_class.contains(col_a4)); - - Ok(()) - } - - #[test] - fn test_ordering_satisfy() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - ])); - let crude = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }]; - let finer = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - // finer ordering satisfies, crude ordering should return true - let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); - eq_properties_finer.oeq_class.push(finer.clone()); - assert!(eq_properties_finer.ordering_satisfy(&crude)); - - // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); - eq_properties_crude.oeq_class.push(crude.clone()); - assert!(!eq_properties_crude.ordering_satisfy(&finer)); - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, 625, 5)?; - - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option_asc)], true), - (vec![(col_a, option_desc)], false), - // Test whether equivalence works as expected - (vec![(col_c, option_asc)], true), - (vec![(col_c, option_desc)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option_asc)], true), - (vec![(col_d, option_asc), (col_b, option_asc)], true), - (vec![(col_d, option_desc), (col_b, option_asc)], false), - ( - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - true, - ), - (vec![(col_e, option_desc), (col_f, option_asc)], true), - (vec![(col_e, option_asc), (col_f, option_asc)], false), - (vec![(col_e, option_desc), (col_b, option_asc)], false), - (vec![(col_e, option_asc), (col_b, option_asc)], false), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_f, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_f, option_asc), - ], - false, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_b, option_asc), - ], - false, - ), - (vec![(col_d, option_asc), (col_e, option_desc)], true), - ( - vec![ - (col_d, option_asc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_f, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - (col_f, option_asc), - ], - true, - ), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: expr.clone(), - options, - }) - .collect::>(); - - // Check expected result with experimental result. - assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, - expected - ); - assert_eq!( - eq_properties.ordering_satisfy(&required), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence2() -> Result<()> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let floor_a = &create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let floor_f = &create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("f", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let exp_a = &create_physical_expr( - &BuiltinScalarFunction::Exp, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - let options = SortOptions { - descending: false, - nulls_first: false, - }; - - let test_cases = vec![ - // ------------ TEST CASE 1 ------------ - ( - // orderings - vec![ - // [a ASC, d ASC, b ASC] - vec![(col_a, options), (col_d, options), (col_b, options)], - // [c ASC] - vec![(col_c, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, b ASC], requirement is not satisfied. - vec![(col_a, options), (col_b, options)], - // expected: requirement is not satisfied. - false, - ), - // ------------ TEST CASE 2 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [floor(a) ASC], - vec![(floor_a, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 2.1 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [floor(f) ASC], (Please note that a=f) - vec![(floor_f, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 3 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, c ASC, a+b ASC], - vec![(col_a, options), (col_c, options), (&a_plus_b, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 4 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC, d ASC] - vec![ - (col_a, options), - (col_b, options), - (col_c, options), - (col_d, options), - ], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [floor(a) ASC, a+b ASC], - vec![(floor_a, options), (&a_plus_b, options)], - // expected: requirement is satisfied. - false, - ), - // ------------ TEST CASE 5 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC, d ASC] - vec![ - (col_a, options), - (col_b, options), - (col_c, options), - (col_d, options), - ], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [exp(a) ASC, a+b ASC], - vec![(exp_a, options), (&a_plus_b, options)], - // expected: requirement is not satisfied. - // TODO: If we know that exp function is 1-to-1 function. - // we could have deduced that above requirement is satisfied. - false, - ), - // ------------ TEST CASE 6 ------------ - ( - // orderings - vec![ - // [a ASC, d ASC, b ASC] - vec![(col_a, options), (col_d, options), (col_b, options)], - // [c ASC] - vec![(col_c, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, d ASC, floor(a) ASC], - vec![(col_a, options), (col_d, options), (floor_a, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 7 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, floor(a) ASC, a + b ASC], - vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], - // expected: requirement is not satisfied. - false, - ), - // ------------ TEST CASE 8 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC] - vec![(col_a, options), (col_b, options), (col_c, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], - vec![ - (col_a, options), - (col_c, options), - (&floor_a, options), - (&a_plus_b, options), - ], - // expected: requirement is not satisfied. - false, - ), - // ------------ TEST CASE 9 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC, d ASC] - vec![ - (col_a, options), - (col_b, options), - (col_c, options), - (col_d, options), - ], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, b ASC, c ASC, floor(a) ASC], - vec![ - (col_a, options), - (col_b, options), - (&col_c, options), - (&floor_a, options), - ], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 10 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, options), (col_b, options)], - // [c ASC, a ASC] - vec![(col_c, options), (col_a, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [c ASC, d ASC, a + b ASC], - vec![(col_c, options), (col_d, options), (&a_plus_b, options)], - // expected: requirement is satisfied. - true, - ), - ]; - - for (orderings, eq_group, constants, reqs, expected) in test_cases { - let err_msg = - format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - let eq_group = eq_group - .into_iter() - .map(|eq_class| { - let eq_classes = eq_class.into_iter().cloned().collect::>(); - EquivalenceClass::new(eq_classes) - }) - .collect::>(); - let eq_group = EquivalenceGroup::new(eq_group); - eq_properties.add_equivalence_group(eq_group); - - let constants = constants.into_iter().cloned(); - eq_properties = eq_properties.add_constants(constants); - - let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(&reqs), - expected, - "{}", - err_msg - ); - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 5; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = vec![ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - ]; - - for n_req in 0..=col_exprs.len() { - for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = vec![ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - (expected | false), - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_different_lengths() -> Result<()> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let options = SortOptions { - descending: false, - nulls_first: false, - }; - // a=c (e.g they are aliases). - let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c); - - let orderings = vec![ - vec![(col_a, options)], - vec![(col_e, options)], - vec![(col_d, options), (col_f, options)], - ]; - let orderings = convert_to_orderings(&orderings); - - // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. - eq_properties.add_new_orderings(orderings); - - // First entry in the tuple is required ordering, second entry is the expected flag - // that indicates whether this required ordering is satisfied. - // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. - let test_cases = vec![ - // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied - ( - vec![(col_c, options), (col_a, options), (col_e, options)], - true, - ), - (vec![(col_c, options), (col_b, options)], false), - (vec![(col_c, options), (col_d, options)], true), - ( - vec![(col_d, options), (col_f, options), (col_b, options)], - false, - ), - (vec![(col_d, options), (col_f, options)], true), - ]; - - for (reqs, expected) in test_cases { - let err_msg = - format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); - let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(&reqs), - expected, - "{}", - err_msg - ); - } - - Ok(()) - } - - #[test] - fn test_bridge_groups() -> Result<()> { - // First entry in the tuple is argument, second entry is the bridged result - let test_cases = vec![ - // ------- TEST CASE 1 -----------// - ( - vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], - // Expected is compared with set equality. Order of the specific results may change. - vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], - ), - // ------- TEST CASE 2 -----------// - ( - vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], - // Expected - vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], - ), - ]; - for (entries, expected) in test_cases { - let entries = entries - .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) - .map(EquivalenceClass::new) - .collect::>(); - let expected = expected - .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) - .map(EquivalenceClass::new) - .collect::>(); - let mut eq_groups = EquivalenceGroup::new(entries.clone()); - eq_groups.bridge_classes(); - let eq_groups = eq_groups.classes; - let err_msg = format!( - "error in test entries: {:?}, expected: {:?}, actual:{:?}", - entries, expected, eq_groups - ); - assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); - for idx in 0..eq_groups.len() { - assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); - } - } - Ok(()) - } - - #[test] - fn test_remove_redundant_entries_eq_group() -> Result<()> { - let entries = vec![ - EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), - // This group is meaningless should be removed - EquivalenceClass::new(vec![lit(3), lit(3)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), - ]; - // Given equivalences classes are not in succinct form. - // Expected form is the most plain representation that is functionally same. - let expected = vec![ - EquivalenceClass::new(vec![lit(1), lit(2)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), - ]; - let mut eq_groups = EquivalenceGroup::new(entries); - eq_groups.remove_redundant_entries(); - - let eq_groups = eq_groups.classes; - assert_eq!(eq_groups.len(), expected.len()); - assert_eq!(eq_groups.len(), 2); - - assert_eq!(eq_groups[0], expected[0]); - assert_eq!(eq_groups[1], expected[1]); - Ok(()) - } - - #[test] - fn test_remove_redundant_entries_oeq_class() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let col_e = &col("e", &schema)?; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - - // First entry in the tuple is the given orderings for the table - // Second entry is the simplest version of the given orderings that is functionally equivalent. - let test_cases = vec![ - // ------- TEST CASE 1 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - ], - ), - // ------- TEST CASE 2 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - ), - // ------- TEST CASE 3 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b DESC] - vec![(col_a, option_asc), (col_b, option_desc)], - // [a ASC] - vec![(col_a, option_asc)], - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b DESC] - vec![(col_a, option_asc), (col_b, option_desc)], - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - ), - // ------- TEST CASE 4 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [a ASC] - vec![(col_a, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - ), - // ------- TEST CASE 5 --------- - // Empty ordering - ( - vec![vec![]], - // No ordering in the state (empty ordering is ignored). - vec![], - ), - // ------- TEST CASE 6 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [b ASC] - vec![(col_b, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC] - vec![(col_a, option_asc)], - // [b ASC] - vec![(col_b, option_asc)], - ], - ), - // ------- TEST CASE 7 --------- - // b, a - // c, a - // d, b, c - ( - // ORDERINGS GIVEN - vec![ - // [b ASC, a ASC] - vec![(col_b, option_asc), (col_a, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC, b ASC, c ASC] - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [b ASC, a ASC] - vec![(col_b, option_asc), (col_a, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC] - vec![(col_d, option_asc)], - ], - ), - // ------- TEST CASE 8 --------- - // b, e - // c, a - // d, b, e, c, a - ( - // ORDERINGS GIVEN - vec![ - // [b ASC, e ASC] - vec![(col_b, option_asc), (col_e, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC, b ASC, e ASC, c ASC, a ASC] - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_c, option_asc), - (col_a, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [b ASC, e ASC] - vec![(col_b, option_asc), (col_e, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC] - vec![(col_d, option_asc)], - ], - ), - // ------- TEST CASE 9 --------- - // b - // a, b, c - // d, a, b - ( - // ORDERINGS GIVEN - vec![ - // [b ASC] - vec![(col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [d ASC, a ASC, b ASC] - vec![ - (col_d, option_asc), - (col_a, option_asc), - (col_b, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [b ASC] - vec![(col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [d ASC] - vec![(col_d, option_asc)], - ], - ), - ]; - for (orderings, expected) in test_cases { - let orderings = convert_to_orderings(&orderings); - let expected = convert_to_orderings(&expected); - let actual = OrderingEquivalenceClass::new(orderings.clone()); - let actual = actual.orderings; - let err_msg = format!( - "orderings: {:?}, expected: {:?}, actual :{:?}", - orderings, expected, actual - ); - assert_eq!(actual.len(), expected.len(), "{}", err_msg); - for elem in actual { - assert!(expected.contains(&elem), "{}", err_msg); - } - } - - Ok(()) - } - - #[test] - fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { - let join_type = JoinType::Inner; - // Join right child schema - let child_fields: Fields = ["x", "y", "z", "w"] - .into_iter() - .map(|name| Field::new(name, DataType::Int32, true)) - .collect(); - let child_schema = Schema::new(child_fields); - let col_x = &col("x", &child_schema)?; - let col_y = &col("y", &child_schema)?; - let col_z = &col("z", &child_schema)?; - let col_w = &col("w", &child_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - // [x ASC, y ASC], [z ASC, w ASC] - let orderings = vec![ - vec![(col_x, option_asc), (col_y, option_asc)], - vec![(col_z, option_asc), (col_w, option_asc)], - ]; - let orderings = convert_to_orderings(&orderings); - // Right child ordering equivalences - let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); - - let left_columns_len = 4; - - let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] - .into_iter() - .map(|name| Field::new(name, DataType::Int32, true)) - .collect(); - - // Join Schema - let schema = Schema::new(fields); - let col_a = &col("a", &schema)?; - let col_d = &col("d", &schema)?; - let col_x = &col("x", &schema)?; - let col_y = &col("y", &schema)?; - let col_z = &col("z", &schema)?; - let col_w = &col("w", &schema)?; - - let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); - // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x); - join_eq_properties.add_equal_conditions(col_d, col_w); - - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - &join_type, - left_columns_len, - ); - join_eq_properties.add_ordering_equivalence_class(right_oeq_class); - let result = join_eq_properties.oeq_class().clone(); - - // [x ASC, y ASC], [z ASC, w ASC] - let orderings = vec![ - vec![(col_x, option_asc), (col_y, option_asc)], - vec![(col_z, option_asc), (col_w, option_asc)], - ]; - let orderings = convert_to_orderings(&orderings); - let expected = OrderingEquivalenceClass::new(orderings); - - assert_eq!(result, expected); - - Ok(()) - } - - /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. - /// - /// The function works by adding a unique column of ascending integers to the original table. This column ensures - /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can - /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce - /// deterministic sorting results. - /// - /// If the table remains the same after sorting with the added unique column, it indicates that the table was - /// already sorted according to `required_ordering` to begin with. - fn is_table_same_after_sort( - mut required_ordering: Vec, - batch: RecordBatch, - ) -> Result { - // Clone the original schema and columns - let original_schema = batch.schema(); - let mut columns = batch.columns().to_vec(); - - // Create a new unique column - let n_row = batch.num_rows(); - let vals: Vec = (0..n_row).collect::>(); - let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); - let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(unique_col.clone()); - - // Create a new schema with the added unique column - let unique_col_name = "unique"; - let unique_field = - Arc::new(Field::new(unique_col_name, DataType::Float64, false)); - let fields: Vec<_> = original_schema - .fields() - .iter() - .cloned() - .chain(std::iter::once(unique_field)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create a new batch with the added column - let new_batch = RecordBatch::try_new(schema.clone(), columns)?; - - // Add the unique column to the required ordering to ensure deterministic results - required_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), - options: Default::default(), - }); - - // Convert the required ordering to a list of SortColumn - let sort_columns = required_ordering - .iter() - .map(|order_expr| { - let expr_result = order_expr.expr.evaluate(&new_batch)?; - let values = expr_result.into_array(new_batch.num_rows())?; - Ok(SortColumn { - values, - options: Some(order_expr.options), - }) - }) - .collect::>>()?; - - // Check if the indices after sorting match the initial ordering - let sorted_indices = lexsort_to_indices(&sort_columns, None)?; - let original_indices = UInt32Array::from_iter_values(0..n_row as u32); - - Ok(sorted_indices == original_indices) - } - - // If we already generated a random result for one of the - // expressions in the equivalence classes. For other expressions in the same - // equivalence class use same result. This util gets already calculated result, when available. - fn get_representative_arr( - eq_group: &EquivalenceClass, - existing_vec: &[Option], - schema: SchemaRef, - ) -> Option { - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - if let Some(res) = &existing_vec[idx] { - return Some(res.clone()); - } - } - None - } - - // Generate a table that satisfies the given equivalence properties; i.e. - // equivalences, ordering equivalences, and constants. - fn generate_table_for_eq_properties( - eq_properties: &EquivalenceProperties, - n_elem: usize, - n_distinct: usize, - ) -> Result { - let mut rng = StdRng::seed_from_u64(23); - - let schema = eq_properties.schema(); - let mut schema_vec = vec![None; schema.fields.len()]; - - // Utility closure to generate random array - let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) - .collect(); - Arc::new(Float64Array::from_iter_values(values)) - }; - - // Fill constant columns - for constant in &eq_properties.constants { - let col = constant.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) - as ArrayRef; - schema_vec[idx] = Some(arr); - } - - // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { - let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_array(n_elem, n_distinct); - ( - SortColumn { - values: arr, - options: Some(*options), - }, - idx, - ) - }) - .unzip(); - - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; - for (idx, arr) in izip!(indices, sort_arrs) { - schema_vec[idx] = Some(arr); - } - } - - // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { - let representative_array = - get_representative_arr(eq_group, &schema_vec, schema.clone()) - .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(representative_array.clone()); - } - } - - let res: Vec<_> = schema_vec - .into_iter() - .zip(schema.fields.iter()) - .map(|(elem, field)| { - ( - field.name(), - // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), - ) - }) - .collect(); - - Ok(RecordBatch::try_from_iter(res)?) - } - - #[test] - fn test_schema_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - // Assume that column a and c are aliases. - let (_test_schema, eq_properties) = create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. - let expressions = vec![ - // Normalized version of the column a and c should go to a - // (by convention all the expressions inside equivalence class are mapped to the first entry - // in this case a is the first entry in the equivalence class.) - (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), - // Cannot normalize column b - (&col_b_expr, &col_b_expr), - ]; - let eq_group = eq_properties.eq_group(); - for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&eq_group.normalize_expr(expr.clone())), - "error in test: expr: {expr:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - // Assume that column a and c are aliases. - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - - // Test cases for equivalence normalization - // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is - // expected PhysicalSortRequirement after normalization. - let test_cases = vec![ - (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), - // In the normalized version column c should be replace with column a - (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), - (vec![(col_c, None)], vec![(col_a, None)]), - (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), - ]; - for (reqs, expected) in test_cases.into_iter() { - let reqs = convert_to_sort_reqs(&reqs); - let expected = convert_to_sort_reqs(&expected); - - let normalized = eq_properties.normalize_sort_requirements(&reqs); - assert!( - expected.eq(&normalized), - "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_normalize_sort_reqs() -> Result<()> { - // Schema satisfies following properties - // a=c - // and following orderings are valid - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, Some(option_asc))], - ), - ( - vec![(col_a, Some(option_desc))], - vec![(col_a, Some(option_desc))], - ), - (vec![(col_a, None)], vec![(col_a, None)]), - // Test whether equivalence works as expected - ( - vec![(col_c, Some(option_asc))], - vec![(col_a, Some(option_asc))], - ), - (vec![(col_c, None)], vec![(col_a, None)]), - // Test whether ordering equivalence works as expected - ( - vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], - ), - ( - vec![(col_d, None), (col_b, None)], - vec![(col_d, None), (col_b, None)], - ), - ( - vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], - vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], - ), - // We should be able to normalize in compatible requirements also (not exactly equal) - ( - vec![(col_e, Some(option_desc)), (col_f, None)], - vec![(col_e, Some(option_desc)), (col_f, None)], - ), - ( - vec![(col_e, None), (col_f, None)], - vec![(col_e, None), (col_f, None)], - ), - ]; - - for (reqs, expected_normalized) in requirements.into_iter() { - let req = convert_to_sort_reqs(&reqs); - let expected_normalized = convert_to_sort_reqs(&expected_normalized); - - assert_eq!( - eq_properties.normalize_sort_requirements(&req), - expected_normalized - ); - } - - Ok(()) - } - - #[test] - fn test_get_finer() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. - // Third entry is the expected result. - let tests_cases = vec![ - // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC)] - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, None), (col_b, Some(option_asc))], - Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] - ( - vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ], - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - Some(vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] - // result should be None - ( - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], - None, - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_reqs(&lhs); - let rhs = convert_to_sort_reqs(&rhs); - let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); - let finer = eq_properties.get_finer_requirement(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - - #[test] - fn test_get_meet_ordering() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let tests_cases = vec![ - // Get meet ordering between [a ASC] and [a ASC, b ASC] - // result should be [a ASC] - ( - vec![(col_a, option_asc)], - vec![(col_a, option_asc), (col_b, option_asc)], - Some(vec![(col_a, option_asc)]), - ), - // Get meet ordering between [a ASC] and [a DESC] - // result should be None. - (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), - // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] - // result should be [a ASC]. - ( - vec![(col_a, option_asc), (col_b, option_asc)], - vec![(col_a, option_asc), (col_b, option_desc)], - Some(vec![(col_a, option_asc)]), - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_exprs(&lhs); - let rhs = convert_to_sort_exprs(&rhs); - let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); - let finer = eq_properties.get_meet_ordering(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - - #[test] - fn test_find_longest_permutation() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - // At below we add [d ASC, h DESC] also, for test purposes - let (test_schema, mut eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_h = &col("h", &test_schema)?; - // a + d - let a_plus_d = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // [d ASC, h ASC] also satisfies schema. - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: col_d.clone(), - options: option_asc, - }, - PhysicalSortExpr { - expr: col_h.clone(), - options: option_desc, - }, - ]]); - let test_cases = vec![ - // TEST CASE 1 - (vec![col_a], vec![(col_a, option_asc)]), - // TEST CASE 2 - (vec![col_c], vec![(col_c, option_asc)]), - // TEST CASE 3 - ( - vec![col_d, col_e, col_b], - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - ), - // TEST CASE 4 - (vec![col_b], vec![]), - // TEST CASE 5 - (vec![col_d], vec![(col_d, option_asc)]), - // TEST CASE 5 - (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), - // TEST CASE 6 - ( - vec![col_b, col_d], - vec![(col_d, option_asc), (col_b, option_asc)], - ), - // TEST CASE 6 - ( - vec![col_c, col_e], - vec![(col_c, option_asc), (col_e, option_desc)], - ), - ]; - for (exprs, expected) in test_cases { - let exprs = exprs.into_iter().cloned().collect::>(); - let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); - assert_eq!(actual, expected); - } - - Ok(()) - } - - #[test] - fn test_find_longest_permutation_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = vec![ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = - eq_properties.find_longest_permutation(&exprs); - // Make sure that find_longest_permutation return values are consistent - let ordering2 = indices - .iter() - .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: exprs[idx].clone(), - options: sort_expr.options, - }) - .collect::>(); - assert_eq!( - ordering, ordering2, - "indices and lexicographical ordering do not match" - ); - - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_update_ordering() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - ]); - - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a); - // [b ASC], [d ASC] - eq_properties.add_new_orderings(vec![ - vec![PhysicalSortExpr { - expr: col_b.clone(), - options: option_asc, - }], - vec![PhysicalSortExpr { - expr: col_d.clone(), - options: option_asc, - }], - ]); - - let test_cases = vec![ - // d + b - ( - Arc::new(BinaryExpr::new( - col_d.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc, - SortProperties::Ordered(option_asc), - ), - // b - (col_b.clone(), SortProperties::Ordered(option_asc)), - // a - (col_a.clone(), SortProperties::Ordered(option_asc)), - // a + c - ( - Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_c.clone(), - )), - SortProperties::Unordered, - ), - ]; - for (expr, expected) in test_cases { - let leading_orderings = eq_properties - .oeq_class() - .iter() - .flat_map(|ordering| ordering.first().cloned()) - .collect::>(); - let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); - let err_msg = format!( - "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", - expr, expected, expr_ordering.state - ); - assert_eq!(expr_ordering.state, expected, "{}", err_msg); - } - - Ok(()) - } - - #[test] - fn test_contains_any() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); - let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); - let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); - - // lit_true is common - assert!(cls1.contains_any(&cls2)); - // there is no common entry - assert!(!cls1.contains_any(&cls3)); - assert!(!cls2.contains_any(&cls3)); - } - - #[test] - fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { - let sort_options = SortOptions::default(); - let sort_options_not = SortOptions::default().not(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0, 1]); - assert_eq!( - result, - vec![ - PhysicalSortExpr { - expr: col_b.clone(), - options: sort_options_not - }, - PhysicalSortExpr { - expr: col_a.clone(), - options: sort_options - } - ] - ); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([ - vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }], - vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ], - ]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0, 1]); - assert_eq!( - result, - vec![ - PhysicalSortExpr { - expr: col_b.clone(), - options: sort_options_not - }, - PhysicalSortExpr { - expr: col_a.clone(), - options: sort_options - } - ] - ); - - let required_columns = [ - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("a", 0)) as _, - ]; - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - - // not satisfied orders - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]]); - let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0]); - - Ok(()) - } - - #[test] - fn test_normalize_ordering_equivalence_classes() -> Result<()> { - let sort_options = SortOptions::default(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let col_a_expr = col("a", &schema)?; - let col_b_expr = col("b", &schema)?; - let col_c_expr = col("c", &schema)?; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); - let others = vec![ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), - options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), - options: sort_options, - }], - ]; - eq_properties.add_new_orderings(others); - - let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); - expected_eqs.add_new_orderings([ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), - options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), - options: sort_options, - }], - ]); - - let oeq_class = eq_properties.oeq_class().clone(); - let expected = expected_eqs.oeq_class(); - assert!(oeq_class.eq(expected)); - - Ok(()) - } - - #[test] - fn project_orderings() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("e", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let col_e = &col("e", &schema)?; - let col_ts = &col("ts", &schema)?; - let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) - as Arc; - let date_bin_func = &create_physical_expr( - &BuiltinScalarFunction::DateBin, - &[interval, col_ts.clone()], - &schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - let b_plus_e = Arc::new(BinaryExpr::new( - col_b.clone(), - Operator::Plus, - col_e.clone(), - )) as Arc; - let c_plus_d = Arc::new(BinaryExpr::new( - col_c.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - - let test_cases = vec![ - // ---------- TEST CASE 1 ------------ - ( - // orderings - vec![ - // [b ASC] - vec![(col_b, option_asc)], - ], - // projection exprs - vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())], - // expected - vec![ - // [b_new ASC] - vec![("b_new", option_asc)], - ], - ), - // ---------- TEST CASE 2 ------------ - ( - // orderings - vec![ - // empty ordering - ], - // projection exprs - vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())], - // expected - vec![ - // no ordering at the output - ], - ), - // ---------- TEST CASE 3 ------------ - ( - // orderings - vec![ - // [ts ASC] - vec![(col_ts, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_ts, "ts_new".to_string()), - (date_bin_func, "date_bin_res".to_string()), - ], - // expected - vec![ - // [date_bin_res ASC] - vec![("date_bin_res", option_asc)], - // [ts_new ASC] - vec![("ts_new", option_asc)], - ], - ), - // ---------- TEST CASE 4 ------------ - ( - // orderings - vec![ - // [a ASC, ts ASC] - vec![(col_a, option_asc), (col_ts, option_asc)], - // [b ASC, ts ASC] - vec![(col_b, option_asc), (col_ts, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_ts, "ts_new".to_string()), - (date_bin_func, "date_bin_res".to_string()), - ], - // expected - vec![ - // [a_new ASC, ts_new ASC] - vec![("a_new", option_asc), ("ts_new", option_asc)], - // [a_new ASC, date_bin_res ASC] - vec![("a_new", option_asc), ("date_bin_res", option_asc)], - // [b_new ASC, ts_new ASC] - vec![("b_new", option_asc), ("ts_new", option_asc)], - // [b_new ASC, date_bin_res ASC] - vec![("b_new", option_asc), ("date_bin_res", option_asc)], - ], - ), - // ---------- TEST CASE 5 ------------ - ( - // orderings - vec![ - // [a + b ASC] - vec![(&a_plus_b, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a + b ASC] - vec![("a+b", option_asc)], - ], - ), - // ---------- TEST CASE 6 ------------ - ( - // orderings - vec![ - // [a + b ASC, c ASC] - vec![(&a_plus_b, option_asc), (&col_c, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_c, "c_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a + b ASC, c_new ASC] - vec![("a+b", option_asc), ("c_new", option_asc)], - ], - ), - // ------- TEST CASE 7 ---------- - ( - vec![ - // [a ASC, b ASC, c ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, d ASC] - vec![(col_a, option_asc), (col_d, option_asc)], - ], - // b as b_new, a as a_new, d as d_new b+d - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_d, "d_new".to_string()), - (&b_plus_d, "b+d".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b_new", option_asc)], - // [a_new ASC, d_new ASC] - vec![("a_new", option_asc), ("d_new", option_asc)], - // [a_new ASC, b+d ASC] - vec![("a_new", option_asc), ("b+d", option_asc)], - ], - ), - // ------- TEST CASE 8 ---------- - ( - // orderings - vec![ - // [b+d ASC] - vec![(&b_plus_d, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_d, "d_new".to_string()), - (&b_plus_d, "b+d".to_string()), - ], - // expected - vec![ - // [b+d ASC] - vec![("b+d", option_asc)], - ], - ), - // ------- TEST CASE 9 ---------- - ( - // orderings - vec![ - // [a ASC, d ASC, b ASC] - vec![ - (col_a, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - // [c ASC] - vec![(col_c, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_d, "d_new".to_string()), - (col_c, "c_new".to_string()), - ], - // expected - vec![ - // [a_new ASC, d_new ASC, b_new ASC] - vec![ - ("a_new", option_asc), - ("d_new", option_asc), - ("b_new", option_asc), - ], - // [c_new ASC], - vec![("c_new", option_asc)], - ], - ), - // ------- TEST CASE 10 ---------- - ( - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [a ASC, d ASC] - vec![(col_a, option_asc), (col_d, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_c, "c_new".to_string()), - (&c_plus_d, "c+d".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC, c_new ASC] - vec![ - ("a_new", option_asc), - ("b_new", option_asc), - ("c_new", option_asc), - ], - // [a_new ASC, b_new ASC, c+d ASC] - vec![ - ("a_new", option_asc), - ("b_new", option_asc), - ("c+d", option_asc), - ], - ], - ), - // ------- TEST CASE 11 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, d ASC] - vec![(col_a, option_asc), (col_d, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (&b_plus_d, "b+d".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b_new", option_asc)], - // [a_new ASC, b + d ASC] - vec![("a_new", option_asc), ("b+d", option_asc)], - ], - ), - // ------- TEST CASE 12 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - // proj exprs - vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())], - // expected - vec![ - // [a_new ASC] - vec![("a_new", option_asc)], - ], - ), - // ------- TEST CASE 13 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [a ASC, a + b ASC, c ASC] - vec![ - (col_a, option_asc), - (&a_plus_b, option_asc), - (col_c, option_asc), - ], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC, c_new ASC] - vec![ - ("a_new", option_asc), - ("b_new", option_asc), - ("c_new", option_asc), - ], - // [a_new ASC, a+b ASC, c_new ASC] - vec![ - ("a_new", option_asc), - ("a+b", option_asc), - ("c_new", option_asc), - ], - ], - ), - // ------- TEST CASE 14 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [c ASC, b ASC] - vec![(col_c, option_asc), (col_b, option_asc)], - // [d ASC, e ASC] - vec![(col_d, option_asc), (col_e, option_asc)], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_d, "d_new".to_string()), - (col_a, "a_new".to_string()), - (&b_plus_e, "b+e".to_string()), - ], - // expected - vec![ - // [a_new ASC, d_new ASC, b+e ASC] - vec![ - ("a_new", option_asc), - ("d_new", option_asc), - ("b+e", option_asc), - ], - // [d_new ASC, a_new ASC, b+e ASC] - vec![ - ("d_new", option_asc), - ("a_new", option_asc), - ("b+e", option_asc), - ], - // [c_new ASC, d_new ASC, b+e ASC] - vec![ - ("c_new", option_asc), - ("d_new", option_asc), - ("b+e", option_asc), - ], - // [d_new ASC, c_new ASC, b+e ASC] - vec![ - ("d_new", option_asc), - ("c_new", option_asc), - ("b+e", option_asc), - ], - ], - ), - // ------- TEST CASE 15 ---------- - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![ - (col_a, option_asc), - (col_c, option_asc), - (&col_b, option_asc), - ], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_a, "a_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a_new ASC, d_new ASC, b+e ASC] - vec![ - ("a_new", option_asc), - ("c_new", option_asc), - ("a+b", option_asc), - ], - ], - ), - // ------- TEST CASE 16 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [c ASC, b DESC] - vec![(col_c, option_asc), (col_b, option_desc)], - // [e ASC] - vec![(col_e, option_asc)], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_a, "a_new".to_string()), - (col_b, "b_new".to_string()), - (&b_plus_e, "b+e".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b_new", option_asc)], - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b+e", option_asc)], - // [c_new ASC, b_new DESC] - vec![("c_new", option_asc), ("b_new", option_desc)], - ], - ), - ]; - - for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() - { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); - - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; - let output_schema = output_schema(&projection_mapping, &schema)?; - - let expected = expected - .into_iter() - .map(|ordering| { - ordering - .into_iter() - .map(|(name, options)| { - (col(name, &output_schema).unwrap(), options) - }) - .collect::>() - }) - .collect::>(); - let expected = convert_to_orderings_owned(&expected); - - let projected_eq = eq_properties.project(&projection_mapping, output_schema); - let orderings = projected_eq.oeq_class(); - - let err_msg = format!( - "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping - ); - - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); - for expected_ordering in &expected { - assert!(orderings.contains(expected_ordering), "{}", err_msg) - } - } - - Ok(()) - } - - #[test] - fn project_orderings2() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_ts = &col("ts", &schema)?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) - as Arc; - let date_bin_ts = &create_physical_expr( - &BuiltinScalarFunction::DateBin, - &[interval, col_ts.clone()], - &schema, - &ExecutionProps::default(), - )?; - - let round_c = &create_physical_expr( - &BuiltinScalarFunction::Round, - &[col_c.clone()], - &schema, - &ExecutionProps::default(), - )?; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - let proj_exprs = vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_c, "c_new".to_string()), - (date_bin_ts, "date_bin_res".to_string()), - (round_c, "round_c_res".to_string()), - ]; - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; - let output_schema = output_schema(&projection_mapping, &schema)?; - - let col_a_new = &col("a_new", &output_schema)?; - let col_b_new = &col("b_new", &output_schema)?; - let col_c_new = &col("c_new", &output_schema)?; - let col_date_bin_res = &col("date_bin_res", &output_schema)?; - let col_round_c_res = &col("round_c_res", &output_schema)?; - let a_new_plus_b_new = Arc::new(BinaryExpr::new( - col_a_new.clone(), - Operator::Plus, - col_b_new.clone(), - )) as Arc; - - let test_cases = vec![ - // ---------- TEST CASE 1 ------------ - ( - // orderings - vec![ - // [a ASC] - vec![(col_a, option_asc)], - ], - // expected - vec![ - // [b_new ASC] - vec![(col_a_new, option_asc)], - ], - ), - // ---------- TEST CASE 2 ------------ - ( - // orderings - vec![ - // [a+b ASC] - vec![(&a_plus_b, option_asc)], - ], - // expected - vec![ - // [b_new ASC] - vec![(&a_new_plus_b_new, option_asc)], - ], - ), - // ---------- TEST CASE 3 ------------ - ( - // orderings - vec![ - // [a ASC, ts ASC] - vec![(col_a, option_asc), (col_ts, option_asc)], - ], - // expected - vec![ - // [a_new ASC, date_bin_res ASC] - vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], - ], - ), - // ---------- TEST CASE 4 ------------ - ( - // orderings - vec![ - // [a ASC, ts ASC, b ASC] - vec![ - (col_a, option_asc), - (col_ts, option_asc), - (col_b, option_asc), - ], - ], - // expected - vec![ - // [a_new ASC, date_bin_res ASC] - // Please note that result is not [a_new ASC, date_bin_res ASC, b_new ASC] - // because, datebin_res may not be 1-1 function. Hence without introducing ts - // dependency we cannot guarantee any ordering after date_bin_res column. - vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], - ], - ), - // ---------- TEST CASE 5 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - // expected - vec![ - // [a_new ASC, round_c_res ASC, c_new ASC] - vec![(col_a_new, option_asc), (col_round_c_res, option_asc)], - // [a_new ASC, c_new ASC] - vec![(col_a_new, option_asc), (col_c_new, option_asc)], - ], - ), - // ---------- TEST CASE 6 ------------ - ( - // orderings - vec![ - // [c ASC, b ASC] - vec![(col_c, option_asc), (col_b, option_asc)], - ], - // expected - vec![ - // [round_c_res ASC] - vec![(col_round_c_res, option_asc)], - // [c_new ASC, b_new ASC] - vec![(col_c_new, option_asc), (col_b_new, option_asc)], - ], - ), - // ---------- TEST CASE 7 ------------ - ( - // orderings - vec![ - // [a+b ASC, c ASC] - vec![(&a_plus_b, option_asc), (col_c, option_asc)], - ], - // expected - vec![ - // [a+b ASC, round(c) ASC, c_new ASC] - vec![ - (&a_new_plus_b_new, option_asc), - (&col_round_c_res, option_asc), - ], - // [a+b ASC, c_new ASC] - vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], - ], - ), - ]; - - for (idx, (orderings, expected)) in test_cases.iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); - - let orderings = convert_to_orderings(orderings); - eq_properties.add_new_orderings(orderings); - - let expected = convert_to_orderings(expected); - - let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); - let orderings = projected_eq.oeq_class(); - - let err_msg = format!( - "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping - ); - - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); - for expected_ordering in &expected { - assert!(orderings.contains(expected_ordering), "{}", err_msg) - } - } - Ok(()) - } - - #[test] - fn project_orderings3() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("e", DataType::Int32, true), - Field::new("f", DataType::Int32, true), - ])); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let col_e = &col("e", &schema)?; - let col_f = &col("f", &schema)?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - let proj_exprs = vec![ - (col_c, "c_new".to_string()), - (col_d, "d_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ]; - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; - let output_schema = output_schema(&projection_mapping, &schema)?; - - let col_a_plus_b_new = &col("a+b", &output_schema)?; - let col_c_new = &col("c_new", &output_schema)?; - let col_d_new = &col("d_new", &output_schema)?; - - let test_cases = vec![ - // ---------- TEST CASE 1 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - ], - // equal conditions - vec![], - // expected - vec![ - // [d_new ASC, c_new ASC, a+b ASC] - vec![ - (col_d_new, option_asc), - (col_c_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - // [c_new ASC, d_new ASC, a+b ASC] - vec![ - (col_c_new, option_asc), - (col_d_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - ], - ), - // ---------- TEST CASE 2 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [c ASC, e ASC], Please note that a=e - vec![(col_c, option_asc), (col_e, option_asc)], - ], - // equal conditions - vec![(col_e, col_a)], - // expected - vec![ - // [d_new ASC, c_new ASC, a+b ASC] - vec![ - (col_d_new, option_asc), - (col_c_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - // [c_new ASC, d_new ASC, a+b ASC] - vec![ - (col_c_new, option_asc), - (col_d_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - ], - ), - // ---------- TEST CASE 3 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [c ASC, e ASC], Please note that a=f - vec![(col_c, option_asc), (col_e, option_asc)], - ], - // equal conditions - vec![(col_a, col_f)], - // expected - vec![ - // [d_new ASC] - vec![(col_d_new, option_asc)], - // [c_new ASC] - vec![(col_c_new, option_asc)], - ], - ), - ]; - for (orderings, equal_columns, expected) in test_cases { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); - for (lhs, rhs) in equal_columns { - eq_properties.add_equal_conditions(lhs, rhs); - } - - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - - let expected = convert_to_orderings(&expected); - - let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); - let orderings = projected_eq.oeq_class(); - - let err_msg = format!( - "actual: {:?}, expected: {:?}, projection_mapping: {:?}", - orderings.orderings, expected, projection_mapping - ); - - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); - for expected_ordering in &expected { - assert!(orderings.contains(expected_ordering), "{}", err_msg) - } - } - - Ok(()) - } - - #[test] - fn project_orderings_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - // Make sure each ordering after projection is valid. - for ordering in projected_eq.oeq_class().iter() { - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs - ); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, - "{}", - err_msg - ); - } - } - } - } - - Ok(()) - } - - #[test] - fn ordering_satisfy_after_projection_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; - - let projected_exprs = projection_mapping - .iter() - .map(|(_source, target)| target.clone()) - .collect::>(); - - for n_req in 0..=projected_exprs.len() { - for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - projected_eq.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - } - } - - Ok(()) - } - - #[test] - fn test_expr_consists_of_constants() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_d = col("d", &schema)?; - let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - - let constants = vec![col_a.clone(), col_b.clone()]; - let expr = b_plus_d.clone(); - assert!(!is_constant_recurse(&constants, &expr)); - - let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; - let expr = b_plus_d.clone(); - assert!(is_constant_recurse(&constants, &expr)); - Ok(()) - } - - #[test] - fn test_join_equivalence_properties() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(col_a.clone(), offset); - let col_b2 = &add_offset_to_expr(col_b.clone(), offset); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let test_cases = vec![ - // ------- TEST CASE 1 -------- - // [a ASC], [b ASC] - ( - // [a ASC], [b ASC] - vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], - // [a ASC], [b ASC] - vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], - // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC] - vec![ - vec![(col_a, option_asc), (col_a2, option_asc)], - vec![(col_a, option_asc), (col_b2, option_asc)], - vec![(col_b, option_asc), (col_a2, option_asc)], - vec![(col_b, option_asc), (col_b2, option_asc)], - ], - ), - // ------- TEST CASE 2 -------- - // [a ASC], [b ASC] - ( - // [a ASC], [b ASC], [c ASC] - vec![ - vec![(col_a, option_asc)], - vec![(col_b, option_asc)], - vec![(col_c, option_asc)], - ], - // [a ASC], [b ASC] - vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], - // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC], [c ASC, a2 ASC], [c ASC, b2 ASC] - vec![ - vec![(col_a, option_asc), (col_a2, option_asc)], - vec![(col_a, option_asc), (col_b2, option_asc)], - vec![(col_b, option_asc), (col_a2, option_asc)], - vec![(col_b, option_asc), (col_b2, option_asc)], - vec![(col_c, option_asc), (col_a2, option_asc)], - vec![(col_c, option_asc), (col_b2, option_asc)], - ], - ), - ]; - for (left_orderings, right_orderings, expected) in test_cases { - let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); - let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); - let left_orderings = convert_to_orderings(&left_orderings); - let right_orderings = convert_to_orderings(&right_orderings); - let expected = convert_to_orderings(&expected); - left_eq_properties.add_new_orderings(left_orderings); - right_eq_properties.add_new_orderings(right_orderings); - let join_eq = join_equivalence_properties( - left_eq_properties, - right_eq_properties, - &JoinType::Inner, - Arc::new(Schema::empty()), - &[true, false], - Some(JoinSide::Left), - &[], - ); - let orderings = &join_eq.oeq_class.orderings; - let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); - assert_eq!( - join_eq.oeq_class.orderings.len(), - expected.len(), - "{}", - err_msg - ); - for ordering in orderings { - assert!( - expected.contains(ordering), - "{}, ordering: {:?}", - err_msg, - ordering - ); - } - } - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs new file mode 100644 index 000000000000..f0bd1740d5d2 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -0,0 +1,598 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; +use crate::{ + expressions::Column, physical_expr::deduplicate_physical_exprs, + physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, + LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, +}; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{tree_node::Transformed, JoinType}; +use std::sync::Arc; + +/// An `EquivalenceClass` is a set of [`Arc`]s that are known +/// to have the same value for all tuples in a relation. These are generated by +/// equality predicates (e.g. `a = b`), typically equi-join conditions and +/// equality conditions in filters. +/// +/// Two `EquivalenceClass`es are equal if they contains the same expressions in +/// without any ordering. +#[derive(Debug, Clone)] +pub struct EquivalenceClass { + /// The expressions in this equivalence class. The order doesn't + /// matter for equivalence purposes + /// + /// TODO: use a HashSet for this instead of a Vec + exprs: Vec>, +} + +impl PartialEq for EquivalenceClass { + /// Returns true if other is equal in the sense + /// of bags (multi-sets), disregarding their orderings. + fn eq(&self, other: &Self) -> bool { + physical_exprs_bag_equal(&self.exprs, &other.exprs) + } +} + +impl EquivalenceClass { + /// Create a new empty equivalence class + pub fn new_empty() -> Self { + Self { exprs: vec![] } + } + + // Create a new equivalence class from a pre-existing `Vec` + pub fn new(mut exprs: Vec>) -> Self { + deduplicate_physical_exprs(&mut exprs); + Self { exprs } + } + + /// Return the inner vector of expressions + pub fn into_vec(self) -> Vec> { + self.exprs + } + + /// Return the "canonical" expression for this class (the first element) + /// if any + fn canonical_expr(&self) -> Option> { + self.exprs.first().cloned() + } + + /// Insert the expression into this class, meaning it is known to be equal to + /// all other expressions in this class + pub fn push(&mut self, expr: Arc) { + if !self.contains(&expr) { + self.exprs.push(expr); + } + } + + /// Inserts all the expressions from other into this class + pub fn extend(&mut self, other: Self) { + for expr in other.exprs { + // use push so entries are deduplicated + self.push(expr); + } + } + + /// Returns true if this equivalence class contains t expression + pub fn contains(&self, expr: &Arc) -> bool { + physical_exprs_contains(&self.exprs, expr) + } + + /// Returns true if this equivalence class has any entries in common with `other` + pub fn contains_any(&self, other: &Self) -> bool { + self.exprs.iter().any(|e| other.contains(e)) + } + + /// return the number of items in this class + pub fn len(&self) -> usize { + self.exprs.len() + } + + /// return true if this class is empty + pub fn is_empty(&self) -> bool { + self.exprs.is_empty() + } + + /// Iterate over all elements in this class, in some arbitrary order + pub fn iter(&self) -> impl Iterator> { + self.exprs.iter() + } + + /// Return a new equivalence class that have the specified offset added to + /// each expression (used when schemas are appended such as in joins) + pub fn with_offset(&self, offset: usize) -> Self { + let new_exprs = self + .exprs + .iter() + .cloned() + .map(|e| add_offset_to_expr(e, offset)) + .collect(); + Self::new(new_exprs) + } +} + +/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each +/// class represents a distinct equivalence class in a relation. +#[derive(Debug, Clone)] +pub struct EquivalenceGroup { + pub classes: Vec, +} + +impl EquivalenceGroup { + /// Creates an empty equivalence group. + pub fn empty() -> Self { + Self { classes: vec![] } + } + + /// Creates an equivalence group from the given equivalence classes. + pub fn new(classes: Vec) -> Self { + let mut result = Self { classes }; + result.remove_redundant_entries(); + result + } + + /// Returns how many equivalence classes there are in this group. + pub fn len(&self) -> usize { + self.classes.len() + } + + /// Checks whether this equivalence group is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalence classes in this group. + pub fn iter(&self) -> impl Iterator { + self.classes.iter() + } + + /// Adds the equality `left` = `right` to this equivalence group. + /// New equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + let mut first_class = None; + let mut second_class = None; + for (idx, cls) in self.classes.iter().enumerate() { + if cls.contains(left) { + first_class = Some(idx); + } + if cls.contains(right) { + second_class = Some(idx); + } + } + match (first_class, second_class) { + (Some(mut first_idx), Some(mut second_idx)) => { + // If the given left and right sides belong to different classes, + // we should unify/bridge these classes. + if first_idx != second_idx { + // By convention, make sure `second_idx` is larger than `first_idx`. + if first_idx > second_idx { + (first_idx, second_idx) = (second_idx, first_idx); + } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.classes.swap_remove(second_idx); + self.classes[first_idx].extend(other_class); + } + } + (Some(group_idx), None) => { + // Right side is new, extend left side's class: + self.classes[group_idx].push(right.clone()); + } + (None, Some(group_idx)) => { + // Left side is new, extend right side's class: + self.classes[group_idx].push(left.clone()); + } + (None, None) => { + // None of the expressions is among existing classes. + // Create a new equivalence class and extend the group. + self.classes + .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); + } + } + } + + /// Removes redundant entries from this group. + fn remove_redundant_entries(&mut self) { + // Remove duplicate entries from each equivalence class: + self.classes.retain_mut(|cls| { + // Keep groups that have at least two entries as singleton class is + // meaningless (i.e. it contains no non-trivial information): + cls.len() > 1 + }); + // Unify/bridge groups that have common expressions: + self.bridge_classes() + } + + /// This utility function unifies/bridges classes that have common expressions. + /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. + /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all + /// equal and belong to one class. This utility converts merges such classes. + fn bridge_classes(&mut self) { + let mut idx = 0; + while idx < self.classes.len() { + let mut next_idx = idx + 1; + let start_size = self.classes[idx].len(); + while next_idx < self.classes.len() { + if self.classes[idx].contains_any(&self.classes[next_idx]) { + let extension = self.classes.swap_remove(next_idx); + self.classes[idx].extend(extension); + } else { + next_idx += 1; + } + } + if self.classes[idx].len() > start_size { + continue; + } + idx += 1; + } + } + + /// Extends this equivalence group with the `other` equivalence group. + pub fn extend(&mut self, other: Self) { + self.classes.extend(other.classes); + self.remove_redundant_entries(); + } + + /// Normalizes the given physical expression according to this group. + /// The expression is replaced with the first expression in the equivalence + /// class it matches with (if any). + pub fn normalize_expr(&self, expr: Arc) -> Arc { + expr.clone() + .transform(&|expr| { + for cls in self.iter() { + if cls.contains(&expr) { + return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + } + } + Ok(Transformed::No(expr)) + }) + .unwrap_or(expr) + } + + /// Normalizes the given sort expression according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the sort expression as is. + pub fn normalize_sort_expr( + &self, + mut sort_expr: PhysicalSortExpr, + ) -> PhysicalSortExpr { + sort_expr.expr = self.normalize_expr(sort_expr.expr); + sort_expr + } + + /// Normalizes the given sort requirement according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the given sort requirement as is. + pub fn normalize_sort_requirement( + &self, + mut sort_requirement: PhysicalSortRequirement, + ) -> PhysicalSortRequirement { + sort_requirement.expr = self.normalize_expr(sort_requirement.expr); + sort_requirement + } + + /// This function applies the `normalize_expr` function for all expressions + /// in `exprs` and returns the corresponding normalized physical expressions. + pub fn normalize_exprs( + &self, + exprs: impl IntoIterator>, + ) -> Vec> { + exprs + .into_iter() + .map(|expr| self.normalize_expr(expr)) + .collect() + } + + /// This function applies the `normalize_sort_expr` function for all sort + /// expressions in `sort_exprs` and returns the corresponding normalized + /// sort expressions. + pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// This function applies the `normalize_sort_requirement` function for all + /// requirements in `sort_reqs` and returns the corresponding normalized + /// sort requirements. + pub fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + collapse_lex_req( + sort_reqs + .iter() + .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) + .collect(), + ) + } + + /// Projects `expr` according to the given projection mapping. + /// If the resulting expression is invalid after projection, returns `None`. + pub fn project_expr( + &self, + mapping: &ProjectionMapping, + expr: &Arc, + ) -> Option> { + // First, we try to project expressions with an exact match. If we are + // unable to do this, we consult equivalence classes. + if let Some(target) = mapping.target_expr(expr) { + // If we match the source, we can project directly: + return Some(target); + } else { + // If the given expression is not inside the mapping, try to project + // expressions considering the equivalence classes. + for (source, target) in mapping.iter() { + // If we match an equivalent expression to `source`, then we can + // project. For example, if we have the mapping `(a as a1, a + c)` + // and the equivalence class `(a, b)`, expression `b` projects to `a1`. + if self + .get_equivalence_class(source) + .map_or(false, |group| group.contains(expr)) + { + return Some(target.clone()); + } + } + } + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // Leaf expression should be inside mapping. + return None; + } + children + .into_iter() + .map(|child| self.project_expr(mapping, &child)) + .collect::>>() + .map(|children| expr.clone().with_new_children(children).unwrap()) + } + + /// Projects this equivalence group according to the given projection mapping. + pub fn project(&self, mapping: &ProjectionMapping) -> Self { + let projected_classes = self.iter().filter_map(|cls| { + let new_class = cls + .iter() + .filter_map(|expr| self.project_expr(mapping, expr)) + .collect::>(); + (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + }); + // TODO: Convert the algorithm below to a version that uses `HashMap`. + // once `Arc` can be stored in `HashMap`. + // See issue: https://github.com/apache/arrow-datafusion/issues/8027 + let mut new_classes = vec![]; + for (source, target) in mapping.iter() { + if new_classes.is_empty() { + new_classes.push((source, vec![target.clone()])); + } + if let Some((_, values)) = + new_classes.iter_mut().find(|(key, _)| key.eq(source)) + { + if !physical_exprs_contains(values, target) { + values.push(target.clone()); + } + } + } + // Only add equivalence classes with at least two members as singleton + // equivalence classes are meaningless. + let new_classes = new_classes + .into_iter() + .filter_map(|(_, values)| (values.len() > 1).then_some(values)) + .map(EquivalenceClass::new); + + let classes = projected_classes.chain(new_classes).collect(); + Self::new(classes) + } + + /// Returns the equivalence class containing `expr`. If no equivalence class + /// contains `expr`, returns `None`. + fn get_equivalence_class( + &self, + expr: &Arc, + ) -> Option<&EquivalenceClass> { + self.iter().find(|cls| cls.contains(expr)) + } + + /// Combine equivalence groups of the given join children. + pub fn join( + &self, + right_equivalences: &Self, + join_type: &JoinType, + left_size: usize, + on: &[(Column, Column)], + ) -> Self { + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let mut result = Self::new( + self.iter() + .cloned() + .chain( + right_equivalences + .iter() + .map(|cls| cls.with_offset(left_size)), + ) + .collect(), + ); + // In we have an inner join, expressions in the "on" condition + // are equal in the resulting table. + if join_type == &JoinType::Inner { + for (lhs, rhs) in on.iter() { + let index = rhs.index() + left_size; + let new_lhs = Arc::new(lhs.clone()) as _; + let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; + result.add_equal_conditions(&new_lhs, &new_rhs); + } + } + result + } + JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), + JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::equivalence::tests::create_test_params; + use crate::equivalence::{EquivalenceClass, EquivalenceGroup}; + use crate::expressions::lit; + use crate::expressions::Column; + use crate::expressions::Literal; + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_bridge_groups() -> Result<()> { + // First entry in the tuple is argument, second entry is the bridged result + let test_cases = vec![ + // ------- TEST CASE 1 -----------// + ( + vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], + // Expected is compared with set equality. Order of the specific results may change. + vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], + ), + // ------- TEST CASE 2 -----------// + ( + vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], + // Expected + vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], + ), + ]; + for (entries, expected) in test_cases { + let entries = entries + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let expected = expected + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let mut eq_groups = EquivalenceGroup::new(entries.clone()); + eq_groups.bridge_classes(); + let eq_groups = eq_groups.classes; + let err_msg = format!( + "error in test entries: {:?}, expected: {:?}, actual:{:?}", + entries, expected, eq_groups + ); + assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); + for idx in 0..eq_groups.len() { + assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); + } + } + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_eq_group() -> Result<()> { + let entries = vec![ + EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), + // This group is meaningless should be removed + EquivalenceClass::new(vec![lit(3), lit(3)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + // Given equivalences classes are not in succinct form. + // Expected form is the most plain representation that is functionally same. + let expected = vec![ + EquivalenceClass::new(vec![lit(1), lit(2)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + let mut eq_groups = EquivalenceGroup::new(entries); + eq_groups.remove_redundant_entries(); + + let eq_groups = eq_groups.classes; + assert_eq!(eq_groups.len(), expected.len()); + assert_eq!(eq_groups.len(), 2); + + assert_eq!(eq_groups[0], expected[0]); + assert_eq!(eq_groups[1], expected[1]); + Ok(()) + } + + #[test] + fn test_schema_normalize_expr_with_equivalence() -> Result<()> { + let col_a = &Column::new("a", 0); + let col_b = &Column::new("b", 1); + let col_c = &Column::new("c", 2); + // Assume that column a and c are aliases. + let (_test_schema, eq_properties) = create_test_params()?; + + let col_a_expr = Arc::new(col_a.clone()) as Arc; + let col_b_expr = Arc::new(col_b.clone()) as Arc; + let col_c_expr = Arc::new(col_c.clone()) as Arc; + // Test cases for equivalence normalization, + // First entry in the tuple is argument, second entry is expected result after normalization. + let expressions = vec![ + // Normalized version of the column a and c should go to a + // (by convention all the expressions inside equivalence class are mapped to the first entry + // in this case a is the first entry in the equivalence class.) + (&col_a_expr, &col_a_expr), + (&col_c_expr, &col_a_expr), + // Cannot normalize column b + (&col_b_expr, &col_b_expr), + ]; + let eq_group = eq_properties.eq_group(); + for (expr, expected_eq) in expressions { + assert!( + expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + "error in test: expr: {expr:?}" + ); + } + + Ok(()) + } + + #[test] + fn test_contains_any() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); + let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); + let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + + // lit_true is common + assert!(cls1.contains_any(&cls2)); + // there is no common entry + assert!(!cls1.contains_any(&cls3)); + assert!(!cls2.contains_any(&cls3)); + } +} diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs new file mode 100644 index 000000000000..387dce2cdc8b --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -0,0 +1,533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod class; +mod ordering; +mod projection; +mod properties; +use crate::expressions::Column; +use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; +pub use class::{EquivalenceClass, EquivalenceGroup}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +pub use ordering::OrderingEquivalenceClass; +pub use projection::ProjectionMapping; +pub use properties::{join_equivalence_properties, EquivalenceProperties}; +use std::sync::Arc; + +/// This function constructs a duplicate-free `LexOrderingReq` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. +pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: usize, +) -> Arc { + expr.transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + col.name(), + offset + col.index(), + )))), + None => Ok(Transformed::No(e)), + }) + .unwrap() + // Note that we can safely unwrap here since our transform always returns + // an `Ok` value. +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, Column}; + use crate::PhysicalSortExpr; + use arrow::compute::{lexsort_to_indices, SortColumn}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; + use arrow_schema::{SchemaRef, SortOptions}; + use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; + use itertools::izip; + use rand::rngs::StdRng; + use rand::seq::SliceRandom; + use rand::{Rng, SeedableRng}; + use std::sync::Arc; + + pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, + ) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) + } + + // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) + pub fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) + } + + /// Construct a schema with following properties + /// Schema satisfies following orderings: + /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + /// and + /// Column [a=c] (e.g they are aliases). + pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + eq_properties.add_equal_conditions(col_a, col_c); + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) + } + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // Convert each tuple to PhysicalSortRequirement + pub fn convert_to_sort_reqs( + in_data: &[(&Arc, Option)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| { + PhysicalSortRequirement::new((*expr).clone(), *options) + }) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs_owned( + in_data: &[(Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings_owned( + orderings: &[Vec<(Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) + .collect() + } + + // Apply projection to the input_data, return projected equivalence properties and record batch + pub fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, + ) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(output_schema.clone()) + } else { + RecordBatch::try_new(output_schema.clone(), projected_values)? + }; + + let projected_eq = + input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) + } + + #[test] + fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Exising equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) + } + + /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. + /// + /// The function works by adding a unique column of ascending integers to the original table. This column ensures + /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can + /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce + /// deterministic sorting results. + /// + /// If the table remains the same after sorting with the added unique column, it indicates that the table was + /// already sorted according to `required_ordering` to begin with. + pub fn is_table_same_after_sort( + mut required_ordering: Vec, + batch: RecordBatch, + ) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(unique_col.clone()); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = + Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(schema.clone(), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. + pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) + as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } +} diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs new file mode 100644 index 000000000000..1a414592ce4c --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -0,0 +1,1159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::SortOptions; +use std::hash::Hash; +use std::sync::Arc; + +use crate::equivalence::add_offset_to_expr; +use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; + +/// An `OrderingEquivalenceClass` object keeps track of different alternative +/// orderings than can describe a schema. For example, consider the following table: +/// +/// ```text +/// |a|b|c|d| +/// |1|4|3|1| +/// |2|3|3|2| +/// |3|1|2|2| +/// |3|2|1|3| +/// ``` +/// +/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table +/// ordering. In this case, we say that these orderings are equivalent. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct OrderingEquivalenceClass { + pub orderings: Vec, +} + +impl OrderingEquivalenceClass { + /// Creates new empty ordering equivalence class. + pub fn empty() -> Self { + Self { orderings: vec![] } + } + + /// Clears (empties) this ordering equivalence class. + pub fn clear(&mut self) { + self.orderings.clear(); + } + + /// Creates new ordering equivalence class from the given orderings. + pub fn new(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } + + /// Checks whether `ordering` is a member of this equivalence class. + pub fn contains(&self, ordering: &LexOrdering) -> bool { + self.orderings.contains(ordering) + } + + /// Adds `ordering` to this equivalence class. + #[allow(dead_code)] + fn push(&mut self, ordering: LexOrdering) { + self.orderings.push(ordering); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Checks whether this ordering equivalence class is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalent orderings in this class. + pub fn iter(&self) -> impl Iterator { + self.orderings.iter() + } + + /// Returns how many equivalent orderings there are in this class. + pub fn len(&self) -> usize { + self.orderings.len() + } + + /// Extend this ordering equivalence class with the `other` class. + pub fn extend(&mut self, other: Self) { + self.orderings.extend(other.orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Adds new orderings into this ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.orderings.extend(orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Removes redundant orderings from this equivalence class. For instance, + /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is + /// no need to keep ordering `[a ASC, b ASC]` in the state. + fn remove_redundant_entries(&mut self) { + let mut work = true; + while work { + work = false; + let mut idx = 0; + while idx < self.orderings.len() { + let mut ordering_idx = idx + 1; + let mut removal = self.orderings[idx].is_empty(); + while ordering_idx < self.orderings.len() { + work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); + if self.orderings[idx].is_empty() { + removal = true; + break; + } + work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); + if self.orderings[ordering_idx].is_empty() { + self.orderings.swap_remove(ordering_idx); + } else { + ordering_idx += 1; + } + } + if removal { + self.orderings.swap_remove(idx); + } else { + idx += 1; + } + } + } + } + + /// Returns the concatenation of all the orderings. This enables merge + /// operations to preserve all equivalent orderings simultaneously. + pub fn output_ordering(&self) -> Option { + let output_ordering = self.orderings.iter().flatten().cloned().collect(); + let output_ordering = collapse_lex_ordering(output_ordering); + (!output_ordering.is_empty()).then_some(output_ordering) + } + + // Append orderings in `other` to all existing orderings in this equivalence + // class. + pub fn join_suffix(mut self, other: &Self) -> Self { + let n_ordering = self.orderings.len(); + // Replicate entries before cross product + let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); + self.orderings = self + .orderings + .iter() + .cloned() + .cycle() + .take(n_cross) + .collect(); + // Suffix orderings of other to the current orderings. + for (outer_idx, ordering) in other.iter().enumerate() { + for idx in 0..n_ordering { + // Calculate cross product index + let idx = outer_idx * n_ordering + idx; + self.orderings[idx].extend(ordering.iter().cloned()); + } + } + self + } + + /// Adds `offset` value to the index of each expression inside this + /// ordering equivalence class. + pub fn add_offset(&mut self, offset: usize) { + for ordering in self.orderings.iter_mut() { + for sort_expr in ordering { + sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); + } + } + } + + /// Gets sort options associated with this expression if it is a leading + /// ordering expression. Otherwise, returns `None`. + pub fn get_options(&self, expr: &Arc) -> Option { + for ordering in self.iter() { + let leading_ordering = &ordering[0]; + if leading_ordering.expr.eq(expr) { + return Some(leading_ordering.options); + } + } + None + } +} + +/// This function constructs a duplicate-free `LexOrdering` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. +pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of +/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. +fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { + let length = orderings[idx].len(); + let other_length = orderings[pre_idx].len(); + for overlap in 1..=length.min(other_length) { + if orderings[idx][length - overlap..] == orderings[pre_idx][..overlap] { + orderings[idx].truncate(length - overlap); + return true; + } + } + false +} + +#[cfg(test)] +mod tests { + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, create_random_schema, + create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, + }; + use crate::equivalence::{tests::create_test_schema, EquivalenceProperties}; + use crate::equivalence::{ + EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::Column; + use crate::expressions::{col, BinaryExpr}; + use crate::functions::create_physical_expr; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn test_ordering_satisfy() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ])); + let crude = vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]; + let finer = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + ]; + // finer ordering satisfies, crude ordering should return true + let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + eq_properties_finer.oeq_class.push(finer.clone()); + assert!(eq_properties_finer.ordering_satisfy(&crude)); + + // Crude ordering doesn't satisfy finer ordering. should return false + let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); + eq_properties_crude.oeq_class.push(crude.clone()); + assert!(!eq_properties_crude.ordering_satisfy(&finer)); + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence2() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let floor_a = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let floor_f = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("f", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let exp_a = &create_physical_expr( + &BuiltinScalarFunction::Exp, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let test_cases = vec![ + // ------------ TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC], requirement is not satisfied. + vec![(col_a, options), (col_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC], + vec![(floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 2.1 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(f) ASC], (Please note that a=f) + vec![(floor_f, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, a+b ASC], + vec![(col_a, options), (col_c, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC, a+b ASC], + vec![(floor_a, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + false, + ), + // ------------ TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [exp(a) ASC, a+b ASC], + vec![(exp_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + // TODO: If we know that exp function is 1-to-1 function. + // we could have deduced that above requirement is satisfied. + false, + ), + // ------------ TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, d ASC, floor(a) ASC], + vec![(col_a, options), (col_d, options), (floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, floor(a) ASC, a + b ASC], + vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 8 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], + vec![ + (col_a, options), + (col_c, options), + (&floor_a, options), + (&a_plus_b, options), + ], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 9 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC, c ASC, floor(a) ASC], + vec![ + (col_a, options), + (col_b, options), + (&col_c, options), + (&floor_a, options), + ], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 10 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, options), (col_b, options)], + // [c ASC, a ASC] + vec![(col_c, options), (col_a, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [c ASC, d ASC, a + b ASC], + vec![(col_c, options), (col_d, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + ]; + + for (orderings, eq_group, constants, reqs, expected) in test_cases { + let err_msg = + format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + let eq_group = eq_group + .into_iter() + .map(|eq_class| { + let eq_classes = eq_class.into_iter().cloned().collect::>(); + EquivalenceClass::new(eq_classes) + }) + .collect::>(); + let eq_group = EquivalenceGroup::new(eq_group); + eq_properties.add_equivalence_group(eq_group); + + let constants = constants.into_iter().cloned(); + eq_properties = eq_properties.add_constants(constants); + + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: expr.clone(), + options, + }) + .collect::>(); + + // Check expected result with experimental result. + assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(&required), + expected, + "{err_msg}" + ); + } + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_different_lengths() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + // a=c (e.g they are aliases). + let mut eq_properties = EquivalenceProperties::new(test_schema); + eq_properties.add_equal_conditions(col_a, col_c); + + let orderings = vec![ + vec![(col_a, options)], + vec![(col_e, options)], + vec![(col_d, options), (col_f, options)], + ]; + let orderings = convert_to_orderings(&orderings); + + // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. + eq_properties.add_new_orderings(orderings); + + // First entry in the tuple is required ordering, second entry is the expected flag + // that indicates whether this required ordering is satisfied. + // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. + let test_cases = vec![ + // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied + ( + vec![(col_c, options), (col_a, options), (col_e, options)], + true, + ), + (vec![(col_c, options), (col_b, options)], false), + (vec![(col_c, options), (col_d, options)], true), + ( + vec![(col_d, options), (col_f, options), (col_b, options)], + false, + ), + (vec![(col_d, options), (col_f, options)], true), + ]; + + for (reqs, expected) in test_cases { + let err_msg = + format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_oeq_class() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // First entry in the tuple is the given orderings for the table + // Second entry is the simplest version of the given orderings that is functionally equivalent. + let test_cases = vec![ + // ------- TEST CASE 1 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + ), + // ------- TEST CASE 2 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 3 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC] + vec![(col_a, option_asc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + ), + // ------- TEST CASE 4 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC] + vec![(col_a, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 5 --------- + // Empty ordering + ( + vec![vec![]], + // No ordering in the state (empty ordering is ignored). + vec![], + ), + // ------- TEST CASE 6 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + ), + // ------- TEST CASE 7 --------- + // b, a + // c, a + // d, b, c + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, c ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 8 --------- + // b, e + // c, a + // d, b, e, c, a + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, e ASC, c ASC, a ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_c, option_asc), + (col_a, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 9 --------- + // b + // a, b, c + // d, a, b + ( + // ORDERINGS GIVEN + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC, a ASC, b ASC] + vec![ + (col_d, option_asc), + (col_a, option_asc), + (col_b, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + ]; + for (orderings, expected) in test_cases { + let orderings = convert_to_orderings(&orderings); + let expected = convert_to_orderings(&expected); + let actual = OrderingEquivalenceClass::new(orderings.clone()); + let actual = actual.orderings; + let err_msg = format!( + "orderings: {:?}, expected: {:?}, actual :{:?}", + orderings, expected, actual + ); + assert_eq!(actual.len(), expected.len(), "{}", err_msg); + for elem in actual { + assert!(expected.contains(&elem), "{}", err_msg); + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs new file mode 100644 index 000000000000..0f92b2c2f431 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -0,0 +1,1153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::expressions::Column; +use crate::PhysicalExpr; + +use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; + +/// Stores the mapping between source expressions and target expressions for a +/// projection. +#[derive(Debug, Clone)] +pub struct ProjectionMapping { + /// Mapping between source expressions and target expressions. + /// Vector indices correspond to the indices after projection. + pub map: Vec<(Arc, Arc)>, +} + +impl ProjectionMapping { + /// Constructs the mapping between a projection's input and output + /// expressions. + /// + /// For example, given the input projection expressions (`a + b`, `c + d`) + /// and an output schema with two columns `"c + d"` and `"a + b"`, the + /// projection mapping would be: + /// + /// ```text + /// [0]: (c + d, col("c + d")) + /// [1]: (a + b, col("a + b")) + /// ``` + /// + /// where `col("c + d")` means the column named `"c + d"`. + pub fn try_new( + expr: &[(Arc, String)], + input_schema: &SchemaRef, + ) -> Result { + // Construct a map from the input expressions to the output expression of the projection: + expr.iter() + .enumerate() + .map(|(expr_idx, (expression, name))| { + let target_expr = Arc::new(Column::new(name, expr_idx)) as _; + expression + .clone() + .transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_input_field = input_schema.field(idx); + let matching_input_column = + Column::new(matching_input_field.name(), idx); + Ok(Transformed::Yes(Arc::new(matching_input_column))) + } + None => Ok(Transformed::No(e)), + }) + .map(|source_expr| (source_expr, target_expr)) + }) + .collect::>>() + .map(|map| Self { map }) + } + + /// Iterate over pairs of (source, target) expressions + pub fn iter( + &self, + ) -> impl Iterator, Arc)> + '_ { + self.map.iter() + } + + /// This function returns the target expression for a given source expression. + /// + /// # Arguments + /// + /// * `expr` - Source physical expression. + /// + /// # Returns + /// + /// An `Option` containing the target for the given source expression, + /// where a `None` value means that `expr` is not inside the mapping. + pub fn target_expr( + &self, + expr: &Arc, + ) -> Option> { + self.map + .iter() + .find(|(source, _)| source.eq(expr)) + .map(|(_, target)| target.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::equivalence::tests::{ + apply_projection, convert_to_orderings, convert_to_orderings_owned, + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + output_schema, + }; + use crate::equivalence::EquivalenceProperties; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Literal}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn project_orderings() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_ts = &col("ts", &schema)?; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_func = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + let b_plus_e = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_e.clone(), + )) as Arc; + let c_plus_d = Arc::new(BinaryExpr::new( + col_c.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [b ASC] + vec![(col_b, option_asc)], + ], + // projection exprs + vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [b_new ASC] + vec![("b_new", option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // empty ordering + ], + // projection exprs + vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())], + // expected + vec![ + // no ordering at the output + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [ts ASC] + vec![(col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [date_bin_res ASC] + vec![("date_bin_res", option_asc)], + // [ts_new ASC] + vec![("ts_new", option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + // [b ASC, ts ASC] + vec![(col_b, option_asc), (col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [a_new ASC, ts_new ASC] + vec![("a_new", option_asc), ("ts_new", option_asc)], + // [a_new ASC, date_bin_res ASC] + vec![("a_new", option_asc), ("date_bin_res", option_asc)], + // [b_new ASC, ts_new ASC] + vec![("b_new", option_asc), ("ts_new", option_asc)], + // [b_new ASC, date_bin_res ASC] + vec![("b_new", option_asc), ("date_bin_res", option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a + b ASC] + vec![(&a_plus_b, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC] + vec![("a+b", option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a + b ASC, c ASC] + vec![(&a_plus_b, option_asc), (&col_c, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC, c_new ASC] + vec![("a+b", option_asc), ("c_new", option_asc)], + ], + ), + // ------- TEST CASE 7 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // b as b_new, a as a_new, d as d_new b+d + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, d_new ASC] + vec![("a_new", option_asc), ("d_new", option_asc)], + // [a_new ASC, b+d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 8 ---------- + ( + // orderings + vec![ + // [b+d ASC] + vec![(&b_plus_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [b+d ASC] + vec![("b+d", option_asc)], + ], + ), + // ------- TEST CASE 9 ---------- + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![ + (col_a, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + // [c ASC] + vec![(col_c, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (col_c, "c_new".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b_new ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b_new", option_asc), + ], + // [c_new ASC], + vec![("c_new", option_asc)], + ], + ), + // ------- TEST CASE 10 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&c_plus_d, "c+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, b_new ASC, c+d ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c+d", option_asc), + ], + ], + ), + // ------- TEST CASE 11 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b + d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 12 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [a_new ASC] + vec![("a_new", option_asc)], + ], + ), + // ------- TEST CASE 13 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, a + b ASC, c ASC] + vec![ + (col_a, option_asc), + (&a_plus_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, a+b ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("a+b", option_asc), + ("c_new", option_asc), + ], + ], + ), + // ------- TEST CASE 14 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + // [d ASC, e ASC] + vec![(col_d, option_asc), (col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, a_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("a_new", option_asc), + ("b+e", option_asc), + ], + // [c_new ASC, d_new ASC, b+e ASC] + vec![ + ("c_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, c_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("c_new", option_asc), + ("b+e", option_asc), + ], + ], + ), + // ------- TEST CASE 15 ---------- + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (&col_b, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("c_new", option_asc), + ("a+b", option_asc), + ], + ], + ), + // ------- TEST CASE 16 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b DESC] + vec![(col_c, option_asc), (col_b, option_desc)], + // [e ASC] + vec![(col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (col_b, "b_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b+e", option_asc)], + // [c_new ASC, b_new DESC] + vec![("c_new", option_asc), ("b_new", option_desc)], + ], + ), + ]; + + for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() + { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let expected = expected + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|(name, options)| { + (col(name, &output_schema).unwrap(), options) + }) + .collect::>() + }) + .collect::>(); + let expected = convert_to_orderings_owned(&expected); + + let projected_eq = eq_properties.project(&projection_mapping, output_schema); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings2() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_ts = &col("ts", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_ts = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let round_c = &create_physical_expr( + &BuiltinScalarFunction::Round, + &[col_c.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (date_bin_ts, "date_bin_res".to_string()), + (round_c, "round_c_res".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_new = &col("a_new", &output_schema)?; + let col_b_new = &col("b_new", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_date_bin_res = &col("date_bin_res", &output_schema)?; + let col_round_c_res = &col("round_c_res", &output_schema)?; + let a_new_plus_b_new = Arc::new(BinaryExpr::new( + col_a_new.clone(), + Operator::Plus, + col_b_new.clone(), + )) as Arc; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC] + vec![(col_a, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(col_a_new, option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a+b ASC] + vec![(&a_plus_b, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(&a_new_plus_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC, b ASC] + vec![ + (col_a, option_asc), + (col_ts, option_asc), + (col_b, option_asc), + ], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + // Please note that result is not [a_new ASC, date_bin_res ASC, b_new ASC] + // because, datebin_res may not be 1-1 function. Hence without introducing ts + // dependency we cannot guarantee any ordering after date_bin_res column. + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a_new ASC, round_c_res ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_round_c_res, option_asc)], + // [a_new ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_c_new, option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + ], + // expected + vec![ + // [round_c_res ASC] + vec![(col_round_c_res, option_asc)], + // [c_new ASC, b_new ASC] + vec![(col_c_new, option_asc), (col_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a+b ASC, c ASC] + vec![(&a_plus_b, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a+b ASC, round(c) ASC, c_new ASC] + vec![ + (&a_new_plus_b_new, option_asc), + (&col_round_c_res, option_asc), + ], + // [a+b ASC, c_new ASC] + vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], + ], + ), + ]; + + for (idx, (orderings, expected)) in test_cases.iter().enumerate() { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + Ok(()) + } + + #[test] + fn project_orderings3() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Int32, true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_f = &col("f", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_plus_b_new = &col("a+b", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_d_new = &col("d_new", &output_schema)?; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + ], + // equal conditions + vec![], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=e + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_e, col_a)], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=f + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_a, col_f)], + // expected + vec![ + // [d_new ASC] + vec![(col_d_new, option_asc)], + // [c_new ASC] + vec![(col_c_new, option_asc)], + ], + ), + ]; + for (orderings, equal_columns, expected) in test_cases { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + for (lhs, rhs) in equal_columns { + eq_properties.add_equal_conditions(lhs, rhs); + } + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(&expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "actual: {:?}, expected: {:?}, projection_mapping: {:?}", + orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) + } + + #[test] + fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| target.clone()) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + projected_eq.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs new file mode 100644 index 000000000000..31c1cf61193a --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -0,0 +1,2062 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::Column; +use arrow_schema::SchemaRef; +use datafusion_common::{JoinSide, JoinType}; +use indexmap::IndexSet; +use itertools::Itertools; +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::equivalence::{ + collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, +}; + +use crate::expressions::Literal; +use crate::sort_properties::{ExprOrdering, SortProperties}; +use crate::{ + physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, + LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_common::tree_node::{Transformed, TreeNode}; + +use super::ordering::collapse_lex_ordering; + +/// A `EquivalenceProperties` object stores useful information related to a schema. +/// Currently, it keeps track of: +/// - Equivalent expressions, e.g expressions that have same value. +/// - Valid sort expressions (orderings) for the schema. +/// - Constants expressions (e.g expressions that are known to have constant values). +/// +/// Consider table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 9 | +/// | 2 | 8 | +/// | 3 | 7 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where both `a ASC` and `b DESC` can describe the table ordering. With +/// `EquivalenceProperties`, we can keep track of these different valid sort +/// expressions and treat `a ASC` and `b DESC` on an equal footing. +/// +/// Similarly, consider the table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 1 | +/// | 2 | 2 | +/// | 3 | 3 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where columns `a` and `b` always have the same value. We keep track of such +/// equivalences inside this object. With this information, we can optimize +/// things like partitioning. For example, if the partition requirement is +/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that +/// the existing partitioning satisfies the requirement. +#[derive(Debug, Clone)] +pub struct EquivalenceProperties { + /// Collection of equivalence classes that store expressions with the same + /// value. + pub eq_group: EquivalenceGroup, + /// Equivalent sort expressions for this table. + pub oeq_class: OrderingEquivalenceClass, + /// Expressions whose values are constant throughout the table. + /// TODO: We do not need to track constants separately, they can be tracked + /// inside `eq_groups` as `Literal` expressions. + pub constants: Vec>, + /// Schema associated with this object. + schema: SchemaRef, +} + +impl EquivalenceProperties { + /// Creates an empty `EquivalenceProperties` object. + pub fn new(schema: SchemaRef) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::empty(), + constants: vec![], + schema, + } + } + + /// Creates a new `EquivalenceProperties` object with the given orderings. + pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), + constants: vec![], + schema, + } + } + + /// Returns the associated schema. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Returns a reference to the ordering equivalence class within. + pub fn oeq_class(&self) -> &OrderingEquivalenceClass { + &self.oeq_class + } + + /// Returns a reference to the equivalence group within. + pub fn eq_group(&self) -> &EquivalenceGroup { + &self.eq_group + } + + /// Returns a reference to the constant expressions + pub fn constants(&self) -> &[Arc] { + &self.constants + } + + /// Returns the normalized version of the ordering equivalence class within. + /// Normalization removes constants and duplicates as well as standardizing + /// expressions according to the equivalence group within. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + OrderingEquivalenceClass::new( + self.oeq_class + .iter() + .map(|ordering| self.normalize_sort_exprs(ordering)) + .collect(), + ) + } + + /// Extends this `EquivalenceProperties` with the `other` object. + pub fn extend(mut self, other: Self) -> Self { + self.eq_group.extend(other.eq_group); + self.oeq_class.extend(other.oeq_class); + self.add_constants(other.constants) + } + + /// Clears (empties) the ordering equivalence class within this object. + /// Call this method when existing orderings are invalidated. + pub fn clear_orderings(&mut self) { + self.oeq_class.clear(); + } + + /// Extends this `EquivalenceProperties` by adding the orderings inside the + /// ordering equivalence class `other`. + pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { + self.oeq_class.extend(other); + } + + /// Adds new orderings into the existing ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.oeq_class.add_new_orderings(orderings); + } + + /// Incorporates the given equivalence group to into the existing + /// equivalence group within. + pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { + self.eq_group.extend(other_eq_group); + } + + /// Adds a new equality condition into the existing equivalence group. + /// If the given equality defines a new equivalence class, adds this new + /// equivalence class to the equivalence group. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + self.eq_group.add_equal_conditions(left, right); + } + + /// Track/register physical expressions with constant values. + pub fn add_constants( + mut self, + constants: impl IntoIterator>, + ) -> Self { + for expr in self.eq_group.normalize_exprs(constants) { + if !physical_exprs_contains(&self.constants, &expr) { + self.constants.push(expr); + } + } + self + } + + /// Updates the ordering equivalence group within assuming that the table + /// is re-sorted according to the argument `sort_exprs`. Note that constants + /// and equivalence classes are unchanged as they are unaffected by a re-sort. + pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { + // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. + self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); + self + } + + /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the + /// equivalence group and the ordering equivalence class within. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the + /// equivalence group and the ordering equivalence class within. It works by: + /// - Removing expressions that have a constant value from the given requirement. + /// - Replacing sections that belong to some equivalence class in the equivalence + /// group with the first entry in the matching equivalence class. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); + let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); + // Prune redundant sections in the requirement: + collapse_lex_req( + normalized_sort_reqs + .iter() + .filter(|&order| { + !physical_exprs_contains(&constants_normalized, &order.expr) + }) + .cloned() + .collect(), + ) + } + + /// Checks whether the given ordering is satisfied by any of the existing + /// orderings. + pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { + // Convert the given sort expressions to sort requirements: + let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); + self.ordering_satisfy_requirement(&sort_requirements) + } + + /// Checks whether the given sort requirements are satisfied by any of the + /// existing orderings. + pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { + let mut eq_properties = self.clone(); + // First, standardize the given requirement: + let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); + for normalized_req in normalized_reqs { + // Check whether given ordering is satisfied + if !eq_properties.ordering_satisfy_single(&normalized_req) { + return false; + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + eq_properties = + eq_properties.add_constants(std::iter::once(normalized_req.expr)); + } + true + } + + /// Determines whether the ordering specified by the given sort requirement + /// is satisfied based on the orderings within, equivalence classes, and + /// constant expressions. + /// + /// # Arguments + /// + /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering + /// satisfaction check will be done. + /// + /// # Returns + /// + /// Returns `true` if the specified ordering is satisfied, `false` otherwise. + fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { + let expr_ordering = self.get_expr_ordering(req.expr.clone()); + let ExprOrdering { expr, state, .. } = expr_ordering; + match state { + SortProperties::Ordered(options) => { + let sort_expr = PhysicalSortExpr { expr, options }; + sort_expr.satisfy(req, self.schema()) + } + // Singleton expressions satisfies any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + } + } + + /// Checks whether the `given`` sort requirements are equal or more specific + /// than the `reference` sort requirements. + pub fn requirements_compatible( + &self, + given: LexRequirementRef, + reference: LexRequirementRef, + ) -> bool { + let normalized_given = self.normalize_sort_requirements(given); + let normalized_reference = self.normalize_sort_requirements(reference); + + (normalized_reference.len() <= normalized_given.len()) + && normalized_reference + .into_iter() + .zip(normalized_given) + .all(|(reference, given)| given.compatible(&reference)) + } + + /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking + /// any ties by choosing `lhs`. + /// + /// The finer ordering is the ordering that satisfies both of the orderings. + /// If the orderings are incomparable, returns `None`. + /// + /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is + /// the latter. + pub fn get_finer_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + // Convert the given sort expressions to sort requirements: + let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); + let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); + let finer = self.get_finer_requirement(&lhs, &rhs); + // Convert the chosen sort requirements back to sort expressions: + finer.map(PhysicalSortRequirement::to_sort_exprs) + } + + /// Returns the finer ordering among the requirements `lhs` and `rhs`, + /// breaking any ties by choosing `lhs`. + /// + /// The finer requirements are the ones that satisfy both of the given + /// requirements. If the requirements are incomparable, returns `None`. + /// + /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` + /// is the latter. + pub fn get_finer_requirement( + &self, + req1: LexRequirementRef, + req2: LexRequirementRef, + ) -> Option { + let mut lhs = self.normalize_sort_requirements(req1); + let mut rhs = self.normalize_sort_requirements(req2); + lhs.iter_mut() + .zip(rhs.iter_mut()) + .all(|(lhs, rhs)| { + lhs.expr.eq(&rhs.expr) + && match (lhs.options, rhs.options) { + (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, + (Some(options), None) => { + rhs.options = Some(options); + true + } + (None, Some(options)) => { + lhs.options = Some(options); + true + } + (None, None) => true, + } + }) + .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) + } + + /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). + /// The meet of a set of orderings is the finest ordering that is satisfied + /// by all the orderings in that set. For details, see: + /// + /// + /// + /// If there is no ordering that satisfies both `lhs` and `rhs`, returns + /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` + /// is `[a ASC]`. + pub fn get_meet_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + let lhs = self.normalize_sort_exprs(lhs); + let rhs = self.normalize_sort_exprs(rhs); + let mut meet = vec![]; + for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { + if lhs.eq(&rhs) { + meet.push(lhs); + } else { + break; + } + } + (!meet.is_empty()).then_some(meet) + } + + /// Projects argument `expr` according to `projection_mapping`, taking + /// equivalences into account. + /// + /// For example, assume that columns `a` and `c` are always equal, and that + /// `projection_mapping` encodes following mapping: + /// + /// ```text + /// a -> a1 + /// b -> b1 + /// ``` + /// + /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to + /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. + pub fn project_expr( + &self, + expr: &Arc, + projection_mapping: &ProjectionMapping, + ) -> Option> { + self.eq_group.project_expr(projection_mapping, expr) + } + + /// Constructs a dependency map based on existing orderings referred to in + /// the projection. + /// + /// This function analyzes the orderings in the normalized order-equivalence + /// class and builds a dependency map. The dependency map captures relationships + /// between expressions within the orderings, helping to identify dependencies + /// and construct valid projected orderings during projection operations. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A [`DependencyMap`] representing the dependency map, where each + /// [`DependencyNode`] contains dependencies for the key [`PhysicalSortExpr`]. + /// + /// # Example + /// + /// Assume we have two equivalent orderings: `[a ASC, b ASC]` and `[a ASC, c ASC]`, + /// and the projection mapping is `[a -> a_new, b -> b_new, b + c -> b + c]`. + /// Then, the dependency map will be: + /// + /// ```text + /// a ASC: Node {Some(a_new ASC), HashSet{}} + /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} + /// c ASC: Node {None, HashSet{a ASC}} + /// ``` + fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { + let mut dependency_map = HashMap::new(); + for ordering in self.normalized_oeq_class().iter() { + for (idx, sort_expr) in ordering.iter().enumerate() { + let target_sort_expr = + self.project_expr(&sort_expr.expr, mapping).map(|expr| { + PhysicalSortExpr { + expr, + options: sort_expr.options, + } + }); + let is_projected = target_sort_expr.is_some(); + if is_projected + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Previous ordering is a dependency. Note that there is no, + // dependency for a leading ordering (i.e. the first sort + // expression). + let dependency = idx.checked_sub(1).map(|a| &ordering[a]); + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + dependency_map + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.clone(), + dependencies: HashSet::new(), + }) + .insert_dependency(dependency); + } + if !is_projected { + // If we can not project, stop constructing the dependency + // map as remaining dependencies will be invalid after projection. + break; + } + } + } + dependency_map + } + + /// Returns a new `ProjectionMapping` where source expressions are normalized. + /// + /// This normalization ensures that source expressions are transformed into a + /// consistent representation. This is beneficial for algorithms that rely on + /// exact equalities, as it allows for more precise and reliable comparisons. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// + /// # Returns + /// + /// A new `ProjectionMapping` with normalized source expressions. + fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + // Construct the mapping where source expressions are normalized. In this way + // In the algorithms below we can work on exact equalities + ProjectionMapping { + map: mapping + .iter() + .map(|(source, target)| { + let normalized_source = self.eq_group.normalize_expr(source.clone()); + (normalized_source, target.clone()) + }) + .collect(), + } + } + + /// Computes projected orderings based on a given projection mapping. + /// + /// This function takes a `ProjectionMapping` and computes the possible + /// orderings for the projected expressions. It considers dependencies + /// between expressions and generates valid orderings according to the + /// specified sort properties. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A vector of `LexOrdering` containing all valid orderings after projection. + fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { + let mapping = self.normalized_mapping(mapping); + + // Get dependency map for existing orderings: + let dependency_map = self.construct_dependency_map(&mapping); + + let orderings = mapping.iter().flat_map(|(source, target)| { + referred_dependencies(&dependency_map, source) + .into_iter() + .filter_map(|relevant_deps| { + if let SortProperties::Ordered(options) = + get_expr_ordering(source, &relevant_deps) + { + Some((options, relevant_deps)) + } else { + // Do not consider unordered cases + None + } + }) + .flat_map(|(options, relevant_deps)| { + let sort_expr = PhysicalSortExpr { + expr: target.clone(), + options, + }; + // Generate dependent orderings (i.e. prefixes for `sort_expr`): + let mut dependency_orderings = + generate_dependency_orderings(&relevant_deps, &dependency_map); + // Append `sort_expr` to the dependent orderings: + for ordering in dependency_orderings.iter_mut() { + ordering.push(sort_expr.clone()); + } + dependency_orderings + }) + }); + + // Add valid projected orderings. For example, if existing ordering is + // `a + b` and projection is `[a -> a_new, b -> b_new]`, we need to + // preserve `a_new + b_new` as ordered. Please note that `a_new` and + // `b_new` themselves need not be ordered. Such dependencies cannot be + // deduced via the pass above. + let projected_orderings = dependency_map.iter().flat_map(|(sort_expr, node)| { + let mut prefixes = construct_prefix_orderings(sort_expr, &dependency_map); + if prefixes.is_empty() { + // If prefix is empty, there is no dependency. Insert + // empty ordering: + prefixes = vec![vec![]]; + } + // Append current ordering on top its dependencies: + for ordering in prefixes.iter_mut() { + if let Some(target) = &node.target_sort_expr { + ordering.push(target.clone()) + } + } + prefixes + }); + + // Simplify each ordering by removing redundant sections: + orderings + .chain(projected_orderings) + .map(collapse_lex_ordering) + .collect() + } + + /// Projects constants based on the provided `ProjectionMapping`. + /// + /// This function takes a `ProjectionMapping` and identifies/projects + /// constants based on the existing constants and the mapping. It ensures + /// that constants are appropriately propagated through the projection. + /// + /// # Arguments + /// + /// - `mapping`: A reference to a `ProjectionMapping` representing the + /// mapping of source expressions to target expressions in the projection. + /// + /// # Returns + /// + /// Returns a `Vec>` containing the projected constants. + fn projected_constants( + &self, + mapping: &ProjectionMapping, + ) -> Vec> { + // First, project existing constants. For example, assume that `a + b` + // is known to be constant. If the projection were `a as a_new`, `b as b_new`, + // then we would project constant `a + b` as `a_new + b_new`. + let mut projected_constants = self + .constants + .iter() + .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .collect::>(); + // Add projection expressions that are known to be constant: + for (source, target) in mapping.iter() { + if self.is_expr_constant(source) + && !physical_exprs_contains(&projected_constants, target) + { + projected_constants.push(target.clone()); + } + } + projected_constants + } + + /// Projects the equivalences within according to `projection_mapping` + /// and `output_schema`. + pub fn project( + &self, + projection_mapping: &ProjectionMapping, + output_schema: SchemaRef, + ) -> Self { + let projected_constants = self.projected_constants(projection_mapping); + let projected_eq_group = self.eq_group.project(projection_mapping); + let projected_orderings = self.projected_orderings(projection_mapping); + Self { + eq_group: projected_eq_group, + oeq_class: OrderingEquivalenceClass::new(projected_orderings), + constants: projected_constants, + schema: output_schema, + } + } + + /// Returns the longest (potentially partial) permutation satisfying the + /// existing ordering. For example, if we have the equivalent orderings + /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, + /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. + /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied + /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` + /// inside the argument `exprs` (respectively). For the mathematical + /// definition of "partial permutation", see: + /// + /// + pub fn find_longest_permutation( + &self, + exprs: &[Arc], + ) -> (LexOrdering, Vec) { + let mut eq_properties = self.clone(); + let mut result = vec![]; + // The algorithm is as follows: + // - Iterate over all the expressions and insert ordered expressions + // into the result. + // - Treat inserted expressions as constants (i.e. add them as constants + // to the state). + // - Continue the above procedure until no expression is inserted; i.e. + // the algorithm reaches a fixed point. + // This algorithm should reach a fixed point in at most `exprs.len()` + // iterations. + let mut search_indices = (0..exprs.len()).collect::>(); + for _idx in 0..exprs.len() { + // Get ordered expressions with their indices. + let ordered_exprs = search_indices + .iter() + .flat_map(|&idx| { + let ExprOrdering { expr, state, .. } = + eq_properties.get_expr_ordering(exprs[idx].clone()); + if let SortProperties::Ordered(options) = state { + Some((PhysicalSortExpr { expr, options }, idx)) + } else { + None + } + }) + .collect::>(); + // We reached a fixed point, exit. + if ordered_exprs.is_empty() { + break; + } + // Remove indices that have an ordering from `search_indices`, and + // treat ordered expressions as constants in subsequent iterations. + // We can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { + eq_properties = + eq_properties.add_constants(std::iter::once(expr.clone())); + search_indices.remove(idx); + } + // Add new ordered section to the state. + result.extend(ordered_exprs); + } + result.into_iter().unzip() + } + + /// This function determines whether the provided expression is constant + /// based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant according to equivalence + /// group, `false` otherwise. + fn is_expr_constant(&self, expr: &Arc) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); + let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// Retrieves the ordering information for a given physical expression. + /// + /// This function constructs an `ExprOrdering` object for the provided + /// expression, which encapsulates information about the expression's + /// ordering, including its [`SortProperties`]. + /// + /// # Arguments + /// + /// - `expr`: An `Arc` representing the physical expression + /// for which ordering information is sought. + /// + /// # Returns + /// + /// Returns an `ExprOrdering` object containing the ordering information for + /// the given expression. + pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { + ExprOrdering::new(expr.clone()) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. + .unwrap() + } +} + +/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. +/// The node can either be a leaf node, or an intermediate node: +/// - If it is a leaf node, we directly find the order of the node by looking +/// at the given sort expression and equivalence properties if it is a `Column` +/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark +/// it as singleton so that it can cooperate with all ordered columns. +/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` +/// and operator has its own rules on how to propagate the children orderings. +/// However, before we engage in recursion, we check whether this intermediate +/// node directly matches with the sort expression. If there is a match, the +/// sort expression emerges at that node immediately, discarding the recursive +/// result coming from its children. +fn update_ordering( + mut node: ExprOrdering, + eq_properties: &EquivalenceProperties, +) -> Transformed { + // We have a Column, which is one of the two possible leaf node types: + let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + if eq_properties.is_expr_constant(&normalized_expr) { + node.state = SortProperties::Singleton; + } else if let Some(options) = eq_properties + .normalized_oeq_class() + .get_options(&normalized_expr) + { + node.state = SortProperties::Ordered(options); + } else if !node.expr.children().is_empty() { + // We have an intermediate (non-leaf) node, account for its children: + node.state = node.expr.get_ordering(&node.children_state()); + } else if node.expr.as_any().is::() { + // We have a Literal, which is the other possible leaf node type: + node.state = node.expr.get_ordering(&[]); + } else { + return Transformed::No(node); + } + Transformed::Yes(node) +} + +/// This function determines whether the provided expression is constant +/// based on the known constants. +/// +/// # Arguments +/// +/// - `constants`: A `&[Arc]` containing expressions known to +/// be a constant. +/// - `expr`: A reference to a `Arc` representing the expression +/// to check. +/// +/// # Returns +/// +/// Returns `true` if the expression is constant according to equivalence +/// group, `false` otherwise. +fn is_constant_recurse( + constants: &[Arc], + expr: &Arc, +) -> bool { + if physical_exprs_contains(constants, expr) { + return true; + } + let children = expr.children(); + !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) +} + +/// This function examines whether a referring expression directly refers to a +/// given referred expression or if any of its children in the expression tree +/// refer to the specified expression. +/// +/// # Parameters +/// +/// - `referring_expr`: A reference to the referring expression (`Arc`). +/// - `referred_expr`: A reference to the referred expression (`Arc`) +/// +/// # Returns +/// +/// A boolean value indicating whether `referring_expr` refers (needs it to evaluate its result) +/// `referred_expr` or not. +fn expr_refers( + referring_expr: &Arc, + referred_expr: &Arc, +) -> bool { + referring_expr.eq(referred_expr) + || referring_expr + .children() + .iter() + .any(|child| expr_refers(child, referred_expr)) +} + +/// This function analyzes the dependency map to collect referred dependencies for +/// a given source expression. +/// +/// # Parameters +/// +/// - `dependency_map`: A reference to the `DependencyMap` where each +/// `PhysicalSortExpr` is associated with a `DependencyNode`. +/// - `source`: A reference to the source expression (`Arc`) +/// for which relevant dependencies need to be identified. +/// +/// # Returns +/// +/// A `Vec` containing the dependencies for the given source +/// expression. These dependencies are expressions that are referred to by +/// the source expression based on the provided dependency map. +fn referred_dependencies( + dependency_map: &DependencyMap, + source: &Arc, +) -> Vec { + // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: + let mut expr_to_sort_exprs = HashMap::::new(); + for sort_expr in dependency_map + .keys() + .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) + { + let key = ExprWrapper(sort_expr.expr.clone()); + expr_to_sort_exprs + .entry(key) + .or_default() + .insert(sort_expr.clone()); + } + + // Generate all valid dependencies for the source. For example, if the source + // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get + // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. + expr_to_sort_exprs + .values() + .multi_cartesian_product() + .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .collect() +} + +/// This function retrieves the dependencies of the given relevant sort expression +/// from the given dependency map. It then constructs prefix orderings by recursively +/// analyzing the dependencies and include them in the orderings. +/// +/// # Parameters +/// +/// - `relevant_sort_expr`: A reference to the relevant sort expression +/// (`PhysicalSortExpr`) for which prefix orderings are to be constructed. +/// - `dependency_map`: A reference to the `DependencyMap` containing dependencies. +/// +/// # Returns +/// +/// A vector of prefix orderings (`Vec`) based on the given relevant +/// sort expression and its dependencies. +fn construct_prefix_orderings( + relevant_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + dependency_map[relevant_sort_expr] + .dependencies + .iter() + .flat_map(|dep| construct_orderings(dep, dependency_map)) + .collect() +} + +/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies +/// (`dependency_map`), this function generates all possible prefix orderings +/// based on the given dependencies. +/// +/// # Parameters +/// +/// * `dependencies` - A reference to the dependencies. +/// * `dependency_map` - A reference to the map of dependencies for expressions. +/// +/// # Returns +/// +/// A vector of lexical orderings (`Vec`) representing all valid orderings +/// based on the given dependencies. +fn generate_dependency_orderings( + dependencies: &Dependencies, + dependency_map: &DependencyMap, +) -> Vec { + // Construct all the valid prefix orderings for each expression appearing + // in the projection: + let relevant_prefixes = dependencies + .iter() + .flat_map(|dep| { + let prefixes = construct_prefix_orderings(dep, dependency_map); + (!prefixes.is_empty()).then_some(prefixes) + }) + .collect::>(); + + // No dependency, dependent is a leading ordering. + if relevant_prefixes.is_empty() { + // Return an empty ordering: + return vec![vec![]]; + } + + // Generate all possible orderings where dependencies are satisfied for the + // current projection expression. For example, if expression is `a + b ASC`, + // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` + // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and + // `[d DESC, c ASC, a + b ASC]`. + relevant_prefixes + .into_iter() + .multi_cartesian_product() + .flat_map(|prefix_orderings| { + prefix_orderings + .iter() + .permutations(prefix_orderings.len()) + .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) + .collect::>() + }) + .collect() +} + +/// This function examines the given expression and the sort expressions it +/// refers to determine the ordering properties of the expression. +/// +/// # Parameters +/// +/// - `expr`: A reference to the source expression (`Arc`) for +/// which ordering properties need to be determined. +/// - `dependencies`: A reference to `Dependencies`, containing sort expressions +/// referred to by `expr`. +/// +/// # Returns +/// +/// A `SortProperties` indicating the ordering information of the given expression. +fn get_expr_ordering( + expr: &Arc, + dependencies: &Dependencies, +) -> SortProperties { + if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { + // If exact match is found, return its ordering. + SortProperties::Ordered(column_order.options) + } else { + // Find orderings of its children + let child_states = expr + .children() + .iter() + .map(|child| get_expr_ordering(child, dependencies)) + .collect::>(); + // Calculate expression ordering using ordering of its children. + expr.get_ordering(&child_states) + } +} + +/// Represents a node in the dependency map used to construct projected orderings. +/// +/// A `DependencyNode` contains information about a particular sort expression, +/// including its target sort expression and a set of dependencies on other sort +/// expressions. +/// +/// # Fields +/// +/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target +/// sort expression associated with the node. It is `None` if the sort expression +/// cannot be projected. +/// - `dependencies`: A [`Dependencies`] containing dependencies on other sort +/// expressions that are referred to by the target sort expression. +#[derive(Debug, Clone, PartialEq, Eq)] +struct DependencyNode { + target_sort_expr: Option, + dependencies: Dependencies, +} + +impl DependencyNode { + // Insert dependency to the state (if exists). + fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { + if let Some(dep) = dependency { + self.dependencies.insert(dep.clone()); + } + } +} + +type DependencyMap = HashMap; +type Dependencies = HashSet; + +/// This function recursively analyzes the dependencies of the given sort +/// expression within the given dependency map to construct lexicographical +/// orderings that include the sort expression and its dependencies. +/// +/// # Parameters +/// +/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) +/// for which lexicographical orderings satisfying its dependencies are to be +/// constructed. +/// - `dependency_map`: A reference to the `DependencyMap` that contains +/// dependencies for different `PhysicalSortExpr`s. +/// +/// # Returns +/// +/// A vector of lexicographical orderings (`Vec`) based on the given +/// sort expression and its dependencies. +fn construct_orderings( + referred_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + // We are sure that `referred_sort_expr` is inside `dependency_map`. + let node = &dependency_map[referred_sort_expr]; + // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. + let target_sort_expr = node.target_sort_expr.clone().unwrap(); + if node.dependencies.is_empty() { + vec![vec![target_sort_expr]] + } else { + node.dependencies + .iter() + .flat_map(|dep| { + let mut orderings = construct_orderings(dep, dependency_map); + for ordering in orderings.iter_mut() { + ordering.push(target_sort_expr.clone()) + } + orderings + }) + .collect() + } +} + +/// Calculate ordering equivalence properties for the given join operation. +pub fn join_equivalence_properties( + left: EquivalenceProperties, + right: EquivalenceProperties, + join_type: &JoinType, + join_schema: SchemaRef, + maintains_input_order: &[bool], + probe_side: Option, + on: &[(Column, Column)], +) -> EquivalenceProperties { + let left_size = left.schema.fields.len(); + let mut result = EquivalenceProperties::new(join_schema); + result.add_equivalence_group(left.eq_group().join( + right.eq_group(), + join_type, + left_size, + on, + )); + + let left_oeq_class = left.oeq_class; + let mut right_oeq_class = right.oeq_class; + match maintains_input_order { + [true, false] => { + // In this special case, right side ordering can be prefixed with + // the left side ordering. + if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + + // Right side ordering equivalence properties should be prepended + // with those of the left side while constructing output ordering + // equivalence properties since stream side is the left side. + // + // For example, if the right side ordering equivalences contain + // `b ASC`, and the left side ordering equivalences contain `a ASC`, + // then we should add `a ASC, b ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(left_oeq_class); + } + } + [false, true] => { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + // In this special case, left side ordering can be prefixed with + // the right side ordering. + if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { + // Left side ordering equivalence properties should be prepended + // with those of the right side while constructing output ordering + // equivalence properties since stream side is the right side. + // + // For example, if the left side ordering equivalences contain + // `a ASC`, and the right side ordering equivalences contain `b ASC`, + // then we should add `b ASC, a ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(right_oeq_class); + } + } + [false, false] => {} + [true, true] => unreachable!("Cannot maintain ordering of both sides"), + _ => unreachable!("Join operators can not have more than two children"), + } + result +} + +/// In the context of a join, update the right side `OrderingEquivalenceClass` +/// so that they point to valid indices in the join output schema. +/// +/// To do so, we increment column indices by the size of the left table when +/// join schema consists of a combination of the left and right schemas. This +/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, +/// indices do not change. +fn updated_right_ordering_equivalence_class( + right_oeq_class: &mut OrderingEquivalenceClass, + join_type: &JoinType, + left_size: usize, +) { + if matches!( + join_type, + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right + ) { + right_oeq_class.add_offset(left_size); + } +} + +/// Wrapper struct for `Arc` to use them as keys in a hash map. +#[derive(Debug, Clone)] +struct ExprWrapper(Arc); + +impl PartialEq for ExprWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl Eq for ExprWrapper {} + +impl Hash for ExprWrapper { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + use std::sync::Arc; + + use super::*; + use crate::equivalence::add_offset_to_expr; + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, + create_random_schema, create_test_params, create_test_schema, + generate_table_for_eq_properties, is_table_same_after_sort, output_schema, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Column}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + + #[test] + fn project_equivalence_properties_test() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + ])); + + let input_properties = EquivalenceProperties::new(input_schema.clone()); + let col_a = col("a", &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let out_schema = output_schema(&projection_mapping, &input_schema)?; + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let col_a1 = &col("a1", &out_schema)?; + let col_a2 = &col("a2", &out_schema)?; + let col_a3 = &col("a3", &out_schema)?; + let col_a4 = &col("a4", &out_schema)?; + let out_properties = input_properties.project(&projection_mapping, out_schema); + + // At the output a1=a2=a3=a4 + assert_eq!(out_properties.eq_group().len(), 1); + let eq_class = &out_properties.eq_group().classes[0]; + assert_eq!(eq_class.len(), 4); + assert!(eq_class.contains(col_a1)); + assert!(eq_class.contains(col_a2)); + assert!(eq_class.contains(col_a3)); + assert!(eq_class.contains(col_a4)); + + Ok(()) + } + + #[test] + fn test_join_equivalence_properties() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let offset = schema.fields.len(); + let col_a2 = &add_offset_to_expr(col_a.clone(), offset); + let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let test_cases = vec![ + // ------- TEST CASE 1 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + ], + ), + // ------- TEST CASE 2 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC], [c ASC] + vec![ + vec![(col_a, option_asc)], + vec![(col_b, option_asc)], + vec![(col_c, option_asc)], + ], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC], [c ASC, a2 ASC], [c ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + vec![(col_c, option_asc), (col_a2, option_asc)], + vec![(col_c, option_asc), (col_b2, option_asc)], + ], + ), + ]; + for (left_orderings, right_orderings, expected) in test_cases { + let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let left_orderings = convert_to_orderings(&left_orderings); + let right_orderings = convert_to_orderings(&right_orderings); + let expected = convert_to_orderings(&expected); + left_eq_properties.add_new_orderings(left_orderings); + right_eq_properties.add_new_orderings(right_orderings); + let join_eq = join_equivalence_properties( + left_eq_properties, + right_eq_properties, + &JoinType::Inner, + Arc::new(Schema::empty()), + &[true, false], + Some(JoinSide::Left), + &[], + ); + let orderings = &join_eq.oeq_class.orderings; + let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); + assert_eq!( + join_eq.oeq_class.orderings.len(), + expected.len(), + "{}", + err_msg + ); + for ordering in orderings { + assert!( + expected.contains(ordering), + "{}, ordering: {:?}", + err_msg, + ordering + ); + } + } + Ok(()) + } + + #[test] + fn test_expr_consists_of_constants() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_d = col("d", &schema)?; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let constants = vec![col_a.clone(), col_b.clone()]; + let expr = b_plus_d.clone(); + assert!(!is_constant_recurse(&constants, &expr)); + + let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; + let expr = b_plus_d.clone(); + assert!(is_constant_recurse(&constants, &expr)); + Ok(()) + } + + #[test] + fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { + let join_type = JoinType::Inner; + // Join right child schema + let child_fields: Fields = ["x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + let child_schema = Schema::new(child_fields); + let col_x = &col("x", &child_schema)?; + let col_y = &col("y", &child_schema)?; + let col_z = &col("z", &child_schema)?; + let col_w = &col("w", &child_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + // Right child ordering equivalences + let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + + let left_columns_len = 4; + + let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + + // Join Schema + let schema = Schema::new(fields); + let col_a = &col("a", &schema)?; + let col_d = &col("d", &schema)?; + let col_x = &col("x", &schema)?; + let col_y = &col("y", &schema)?; + let col_z = &col("z", &schema)?; + let col_w = &col("w", &schema)?; + + let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); + // a=x and d=w + join_eq_properties.add_equal_conditions(col_a, col_x); + join_eq_properties.add_equal_conditions(col_d, col_w); + + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + &join_type, + left_columns_len, + ); + join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + let result = join_eq_properties.oeq_class().clone(); + + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + let expected = OrderingEquivalenceClass::new(orderings); + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_normalize_ordering_equivalence_classes() -> Result<()> { + let sort_options = SortOptions::default(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a_expr = col("a", &schema)?; + let col_b_expr = col("b", &schema)?; + let col_c_expr = col("c", &schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + + eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); + let others = vec![ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]; + eq_properties.add_new_orderings(others); + + let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); + expected_eqs.add_new_orderings([ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]); + + let oeq_class = eq_properties.oeq_class().clone(); + let expected = expected_eqs.oeq_class(); + assert!(oeq_class.eq(expected)); + + Ok(()) + } + + #[test] + fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { + let sort_options = SortOptions::default(); + let sort_options_not = SortOptions::default().not(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([ + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }], + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ], + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let required_columns = [ + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("a", 0)) as _, + ]; + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + + // not satisfied orders + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0]); + + Ok(()) + } + + #[test] + fn test_update_ordering() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ]); + + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // b=a (e.g they are aliases) + eq_properties.add_equal_conditions(col_b, col_a); + // [b ASC], [d ASC] + eq_properties.add_new_orderings(vec![ + vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_asc, + }], + vec![PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }], + ]); + + let test_cases = vec![ + // d + b + ( + Arc::new(BinaryExpr::new( + col_d.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc, + SortProperties::Ordered(option_asc), + ), + // b + (col_b.clone(), SortProperties::Ordered(option_asc)), + // a + (col_a.clone(), SortProperties::Ordered(option_asc)), + // a + c + ( + Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_c.clone(), + )), + SortProperties::Unordered, + ), + ]; + for (expr, expected) in test_cases { + let leading_orderings = eq_properties + .oeq_class() + .iter() + .flat_map(|ordering| ordering.first().cloned()) + .collect::>(); + let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); + let err_msg = format!( + "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", + expr, expected, expr_ordering.state + ); + assert_eq!(expr_ordering.state, expected, "{}", err_msg); + } + + Ok(()) + } + + #[test] + fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: exprs[idx].clone(), + options: sort_expr.options, + }) + .collect::>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + #[test] + fn test_find_longest_permutation() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + // At below we add [d ASC, h DESC] also, for test purposes + let (test_schema, mut eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_h = &col("h", &test_schema)?; + // a + d + let a_plus_d = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // [d ASC, h ASC] also satisfies schema. + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }, + PhysicalSortExpr { + expr: col_h.clone(), + options: option_desc, + }, + ]]); + let test_cases = vec![ + // TEST CASE 1 + (vec![col_a], vec![(col_a, option_asc)]), + // TEST CASE 2 + (vec![col_c], vec![(col_c, option_asc)]), + // TEST CASE 3 + ( + vec![col_d, col_e, col_b], + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + ), + // TEST CASE 4 + (vec![col_b], vec![]), + // TEST CASE 5 + (vec![col_d], vec![(col_d, option_asc)]), + // TEST CASE 5 + (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), + // TEST CASE 6 + ( + vec![col_b, col_d], + vec![(col_d, option_asc), (col_b, option_asc)], + ), + // TEST CASE 6 + ( + vec![col_c, col_e], + vec![(col_c, option_asc), (col_e, option_desc)], + ), + ]; + for (exprs, expected) in test_cases { + let exprs = exprs.into_iter().cloned().collect::>(); + let expected = convert_to_sort_exprs(&expected); + let (actual, _) = eq_properties.find_longest_permutation(&exprs); + assert_eq!(actual, expected); + } + + Ok(()) + } + #[test] + fn test_get_meet_ordering() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let tests_cases = vec![ + // Get meet ordering between [a ASC] and [a ASC, b ASC] + // result should be [a ASC] + ( + vec![(col_a, option_asc)], + vec![(col_a, option_asc), (col_b, option_asc)], + Some(vec![(col_a, option_asc)]), + ), + // Get meet ordering between [a ASC] and [a DESC] + // result should be None. + (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), + // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] + // result should be [a ASC]. + ( + vec![(col_a, option_asc), (col_b, option_asc)], + vec![(col_a, option_asc), (col_b, option_desc)], + Some(vec![(col_a, option_asc)]), + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_exprs(&lhs); + let rhs = convert_to_sort_exprs(&rhs); + let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); + let finer = eq_properties.get_meet_ordering(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_get_finer() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. + // Third entry is the expected result. + let tests_cases = vec![ + // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC)] + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, None), (col_b, Some(option_asc))], + Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] + ( + vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ], + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + Some(vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] + // result should be None + ( + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], + None, + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_reqs(&lhs); + let rhs = convert_to_sort_reqs(&rhs); + let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); + let finer = eq_properties.get_finer_requirement(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_normalize_sort_reqs() -> Result<()> { + // Schema satisfies following properties + // a=c + // and following orderings are valid + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + ( + vec![(col_a, Some(option_desc))], + vec![(col_a, Some(option_desc))], + ), + (vec![(col_a, None)], vec![(col_a, None)]), + // Test whether equivalence works as expected + ( + vec![(col_c, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + (vec![(col_c, None)], vec![(col_a, None)]), + // Test whether ordering equivalence works as expected + ( + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + ), + ( + vec![(col_d, None), (col_b, None)], + vec![(col_d, None), (col_b, None)], + ), + ( + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + ), + // We should be able to normalize in compatible requirements also (not exactly equal) + ( + vec![(col_e, Some(option_desc)), (col_f, None)], + vec![(col_e, Some(option_desc)), (col_f, None)], + ), + ( + vec![(col_e, None), (col_f, None)], + vec![(col_e, None), (col_f, None)], + ), + ]; + + for (reqs, expected_normalized) in requirements.into_iter() { + let req = convert_to_sort_reqs(&reqs); + let expected_normalized = convert_to_sort_reqs(&expected_normalized); + + assert_eq!( + eq_properties.normalize_sort_requirements(&req), + expected_normalized + ); + } + + Ok(()) + } + + #[test] + fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { + let option1 = SortOptions { + descending: false, + nulls_first: false, + }; + // Assume that column a and c are aliases. + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + + // Test cases for equivalence normalization + // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is + // expected PhysicalSortRequirement after normalization. + let test_cases = vec![ + (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), + // In the normalized version column c should be replace with column a + (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), + (vec![(col_c, None)], vec![(col_a, None)]), + (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), + ]; + for (reqs, expected) in test_cases.into_iter() { + let reqs = convert_to_sort_reqs(&reqs); + let expected = convert_to_sort_reqs(&expected); + + let normalized = eq_properties.normalize_sort_requirements(&reqs); + assert!( + expected.eq(&normalized), + "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" + ); + } + + Ok(()) + } +} From 78832f11a45dd47e5490583c2f0e90aef20b073f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Dec 2023 06:59:40 -0500 Subject: [PATCH 32/63] Move parquet_schema.rs from sql to parquet tests (#8644) --- datafusion/core/tests/parquet/mod.rs | 1 + .../parquet_schema.rs => parquet/schema.rs} | 17 +++++++++++++++-- datafusion/core/tests/sql/mod.rs | 1 - 3 files changed, 16 insertions(+), 3 deletions(-) rename datafusion/core/tests/{sql/parquet_schema.rs => parquet/schema.rs} (95%) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 3f003c077d6a..943f7fdbf4ac 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -44,6 +44,7 @@ mod file_statistics; mod filter_pushdown; mod page_pruning; mod row_group_pruning; +mod schema; mod schema_coercion; #[cfg(test)] diff --git a/datafusion/core/tests/sql/parquet_schema.rs b/datafusion/core/tests/parquet/schema.rs similarity index 95% rename from datafusion/core/tests/sql/parquet_schema.rs rename to datafusion/core/tests/parquet/schema.rs index bc1578da2c58..30d4e1193022 100644 --- a/datafusion/core/tests/sql/parquet_schema.rs +++ b/datafusion/core/tests/parquet/schema.rs @@ -22,6 +22,7 @@ use ::parquet::arrow::ArrowWriter; use tempfile::TempDir; use super::*; +use datafusion_common::assert_batches_sorted_eq; #[tokio::test] async fn schema_merge_ignores_metadata_by_default() { @@ -90,7 +91,13 @@ async fn schema_merge_ignores_metadata_by_default() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_no_metadata(&actual); } @@ -151,7 +158,13 @@ async fn schema_merge_can_preserve_metadata() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_metadata(&actual, &expected_metadata); } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index a3d5e32097c6..849d85dec6bf 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -79,7 +79,6 @@ pub mod expr; pub mod group_by; pub mod joins; pub mod order; -pub mod parquet_schema; pub mod partitioned_csv; pub mod predicates; pub mod references; From 26a8000fe2343e6a187dcd6e4e8fc037d55e213f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Dec 2023 07:04:43 -0500 Subject: [PATCH 33/63] Fix group by aliased expression in LogicalPLanBuilder::aggregate (#8629) --- datafusion/core/src/dataframe/mod.rs | 36 ++++++++++++- datafusion/expr/src/logical_plan/builder.rs | 58 ++++++++++++++------- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2ae4a7c21a9c..3c3bcd497b7f 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1769,8 +1769,8 @@ mod tests { let df_results = df.collect().await?; #[rustfmt::skip] - assert_batches_sorted_eq!( - [ "+----+", + assert_batches_sorted_eq!([ + "+----+", "| id |", "+----+", "| 1 |", @@ -1781,6 +1781,38 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_alias() -> Result<()> { + let df = test_table().await?; + + let df = df + // GROUP BY `c2 + 1` + .aggregate(vec![col("c2") + lit(1)], vec![])? + // SELECT `c2 + 1` as c2 + .select(vec![(col("c2") + lit(1)).alias("c2")])? + // GROUP BY c2 as "c2" (alias in expr is not supported by SQL) + .aggregate(vec![col("c2").alias("c2")], vec![])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| c2 |", + "+----+", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "| 6 |", + "+----+", + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 88310dab82a2..549c25f89bae 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -904,27 +904,11 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - let mut group_expr = normalize_cols(group_expr, &self.plan)?; + let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - // Rewrite groupby exprs according to functional dependencies - let group_by_expr_names = group_expr - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - let schema = self.plan.schema(); - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - for idx in target_indices { - let field = schema.field(idx); - let expr = - Expr::Column(Column::new(field.qualifier().cloned(), field.name())); - if !group_expr.contains(&expr) { - group_expr.push(expr); - } - } - } + let group_expr = + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) @@ -1189,6 +1173,42 @@ pub fn build_join_schema( schema.with_functional_dependencies(func_dependencies) } +/// Add additional "synthetic" group by expressions based on functional +/// dependencies. +/// +/// For example, if we are grouping on `[c1]`, and we know from +/// functional dependencies that column `c1` determines `c2`, this function +/// adds `c2` to the group by list. +/// +/// This allows MySQL style selects like +/// `SELECT col FROM t WHERE pk = 5` if col is unique +fn add_group_by_exprs_from_dependencies( + mut group_expr: Vec, + schema: &DFSchemaRef, +) -> Result> { + // Names of the fields produced by the GROUP BY exprs for example, `GROUP BY + // c1 + 1` produces an output field named `"c1 + 1"` + let mut group_by_field_names = group_expr + .iter() + .map(|e| e.display_name()) + .collect::>>()?; + + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_field_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + let expr_name = expr.display_name()?; + if !group_by_field_names.contains(&expr_name) { + group_by_field_names.push(expr_name); + group_expr.push(expr); + } + } + } + Ok(group_expr) +} /// Errors if one or more expressions have equal names. pub(crate) fn validate_unique_names<'a>( node_name: &str, From 58b0a2bfd4ec9b671fd60b8992b111fc8acd4889 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 27 Dec 2023 13:51:14 +0100 Subject: [PATCH 34/63] Refactor `array_union` and `array_intersect` functions to one general function (#8516) * Refactor array_union and array_intersect functions * fix cli * fix ci * add tests for null * modify the return type * update tests * fix clippy * fix clippy * add tests for largelist * fix clippy * Add field parameter to generic_set_lists() function * Add large array drop statements * fix clippy --- datafusion/expr/src/built_in_function.rs | 13 +- .../physical-expr/src/array_expressions.rs | 283 +++++++++-------- datafusion/sqllogictest/test_files/array.slt | 294 +++++++++++++++++- 3 files changed, 446 insertions(+), 144 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 3818e8ee5658..c454a9781eda 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -618,7 +618,18 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), - BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => { + BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, DataType::Null) | (DataType::Null, _) => { + Ok(DataType::Null) + } + (_, DataType::Null) => { + Ok(List(Arc::new(Field::new("item", Null, true)))) + } + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::ArrayUnion => { match (input_expr_types[0].clone(), input_expr_types[1].clone()) { (DataType::Null, dt) => Ok(dt), (dt, DataType::Null) => Ok(dt), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 3ee99d7e8e55..274d1db4eb0d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -19,6 +19,7 @@ use std::any::type_name; use std::collections::HashSet; +use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::*; @@ -1777,97 +1778,173 @@ macro_rules! to_string { }}; } -fn union_generic_lists( +#[derive(Debug, PartialEq)] +enum SetOp { + Union, + Intersect, +} + +impl Display for SetOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SetOp::Union => write!(f, "array_union"), + SetOp::Intersect => write!(f, "array_intersect"), + } + } +} + +fn generic_set_lists( l: &GenericListArray, r: &GenericListArray, - field: &FieldRef, -) -> Result> { - let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; + field: Arc, + set_op: SetOp, +) -> Result { + if matches!(l.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", r.value_type(), true)); + return general_array_distinct::(r, &field); + } else if matches!(r.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", l.value_type(), true)); + return general_array_distinct::(l, &field); + } - let nulls = NullBuffer::union(l.nulls(), r.nulls()); - let l_values = l.values().clone(); - let r_values = r.values().clone(); - let l_values = converter.convert_columns(&[l_values])?; - let r_values = converter.convert_columns(&[r_values])?; + if l.value_type() != r.value_type() { + return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); + } - // Might be worth adding an upstream OffsetBufferBuilder - let mut offsets = Vec::::with_capacity(l.len() + 1); - offsets.push(OffsetSize::usize_as(0)); - let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows()); - let mut dedup = HashSet::new(); - for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { - let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); - let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); - for i in l_slice { - let left_row = l_values.row(i); - if dedup.insert(left_row) { - rows.push(left_row); - } - } - for i in r_slice { - let right_row = r_values.row(i); - if dedup.insert(right_row) { - rows.push(right_row); + let dt = l.value_type(); + + let mut offsets = vec![OffsetSize::usize_as(0)]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt)])?; + for (first_arr, second_arr) in l.iter().zip(r.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let l_iter = l_values.iter().sorted().dedup(); + let values_set: HashSet<_> = l_iter.clone().collect(); + let mut rows = if set_op == SetOp::Union { + l_iter.collect::>() + } else { + vec![] + }; + for r_val in r_values.iter().sorted().dedup() { + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); + } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + } } + + let last_offset = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!("{set_op}: failed to get array from rows"); + } + }; + new_arrays.push(array); } - offsets.push(OffsetSize::usize_as(rows.len())); - dedup.clear(); } - let values = converter.convert_rows(rows)?; let offsets = OffsetBuffer::new(offsets.into()); - let result = values[0].clone(); - Ok(GenericListArray::::new( - field.clone(), - offsets, - result, - nulls, - )) + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = GenericListArray::::try_new(field, offsets, values, None)?; + Ok(Arc::new(arr)) } -/// Array_union SQL function -pub fn array_union(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_union needs 2 arguments"); - } - let array1 = &args[0]; - let array2 = &args[1]; +fn general_set_op( + array1: &ArrayRef, + array2: &ArrayRef, + set_op: SetOp, +) -> Result { + match (array1.data_type(), array2.data_type()) { + (DataType::Null, DataType::List(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = as_list_array(&array2)?; + general_array_distinct::(array, field) + } - fn union_arrays( - array1: &ArrayRef, - array2: &ArrayRef, - l_field_ref: &Arc, - r_field_ref: &Arc, - ) -> Result { - match (l_field_ref.data_type(), r_field_ref.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (_, _) => { - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, l_field_ref)?; - Ok(Arc::new(result)) + (DataType::List(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); } + let array = as_list_array(&array1)?; + general_array_distinct::(array, field) } - } + (DataType::Null, DataType::LargeList(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = as_large_list_array(&array2)?; + general_array_distinct::(array, field) + } + (DataType::LargeList(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); + } + let array = as_large_list_array(&array1)?; + general_array_distinct::(array, field) + } + (DataType::Null, DataType::Null) => Ok(new_empty_array(&DataType::Null)), - match (array1.data_type(), array2.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { - union_arrays::(array1, array2, l_field_ref, r_field_ref) + (DataType::List(field), DataType::List(_)) => { + let array1 = as_list_array(&array1)?; + let array2 = as_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) } - (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { - union_arrays::(array1, array2, l_field_ref, r_field_ref) + (DataType::LargeList(field), DataType::LargeList(_)) => { + let array1 = as_large_list_array(&array1)?; + let array2 = as_large_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) } - _ => { + (data_type1, data_type2) => { internal_err!( - "array_union only support list with offsets of type int32 and int64" + "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'" ) } } } +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Union) +} + +/// array_intersect SQL function +pub fn array_intersect(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_intersect needs two arguments"); + } + + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Intersect) +} + /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { @@ -2228,7 +2305,7 @@ pub fn array_has(args: &[ArrayRef]) -> Result { DataType::LargeList(_) => { general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) } - _ => internal_err!("array_has does not support type '{array_type:?}'."), + _ => exec_err!("array_has does not support type '{array_type:?}'."), } } @@ -2359,74 +2436,6 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { - if args.len() != 2 { - return exec_err!("array_intersect needs two arguments"); - } - - let first_array = &args[0]; - let second_array = &args[1]; - - match (first_array.data_type(), second_array.data_type()) { - (DataType::Null, _) => Ok(second_array.clone()), - (_, DataType::Null) => Ok(first_array.clone()), - _ => { - let first_array = as_list_array(&first_array)?; - let second_array = as_list_array(&second_array)?; - - if first_array.value_type() != second_array.value_type() { - return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); - } - - let dt = first_array.value_type(); - - let mut offsets = vec![0]; - let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; - for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let values_set: HashSet<_> = l_values.iter().collect(); - let mut rows = Vec::with_capacity(r_values.num_rows()); - for r_val in r_values.iter().sorted().dedup() { - if values_set.contains(&r_val) { - rows.push(r_val); - } - } - - let last_offset: i32 = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + rows.len() as i32); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => array.clone(), - None => { - return internal_err!( - "array_intersect: failed to get array from rows" - ) - } - }; - new_arrays.push(array); - } - } - - let field = Arc::new(Field::new("item", dt, true)); - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = - new_arrays.iter().map(|v| v.as_ref()).collect::>(); - let values = compute::concat(&new_arrays_ref)?; - let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); - Ok(arr) - } - } -} - pub fn general_array_distinct( array: &GenericListArray, field: &FieldRef, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 283f2d67b7a0..4c4adbabfda5 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -231,6 +231,19 @@ AS VALUES (make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33)) ; +statement ok +CREATE TABLE large_array_intersect_table_1D +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') as column1, + arrow_cast(column2, 'LargeList(Int64)') as column2, + arrow_cast(column3, 'LargeList(Int64)') as column3, + arrow_cast(column4, 'LargeList(Int64)') as column4, + arrow_cast(column5, 'LargeList(Int64)') as column5, + arrow_cast(column6, 'LargeList(Int64)') as column6 +FROM array_intersect_table_1D +; + statement ok CREATE TABLE array_intersect_table_1D_Float AS VALUES @@ -238,6 +251,19 @@ AS VALUES (make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) ; +statement ok +CREATE TABLE large_array_intersect_table_1D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(Float64)') as column1, + arrow_cast(column2, 'LargeList(Float64)') as column2, + arrow_cast(column3, 'LargeList(Float64)') as column3, + arrow_cast(column4, 'LargeList(Float64)') as column4, + arrow_cast(column5, 'LargeList(Float64)') as column5, + arrow_cast(column6, 'LargeList(Float64)') as column6 +FROM array_intersect_table_1D_Float +; + statement ok CREATE TABLE array_intersect_table_1D_Boolean AS VALUES @@ -245,6 +271,19 @@ AS VALUES (make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) ; +statement ok +CREATE TABLE large_array_intersect_table_1D_Boolean +AS + SELECT + arrow_cast(column1, 'LargeList(Boolean)') as column1, + arrow_cast(column2, 'LargeList(Boolean)') as column2, + arrow_cast(column3, 'LargeList(Boolean)') as column3, + arrow_cast(column4, 'LargeList(Boolean)') as column4, + arrow_cast(column5, 'LargeList(Boolean)') as column5, + arrow_cast(column6, 'LargeList(Boolean)') as column6 +FROM array_intersect_table_1D_Boolean +; + statement ok CREATE TABLE array_intersect_table_1D_UTF8 AS VALUES @@ -252,6 +291,19 @@ AS VALUES (make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) ; +statement ok +CREATE TABLE large_array_intersect_table_1D_UTF8 +AS + SELECT + arrow_cast(column1, 'LargeList(Utf8)') as column1, + arrow_cast(column2, 'LargeList(Utf8)') as column2, + arrow_cast(column3, 'LargeList(Utf8)') as column3, + arrow_cast(column4, 'LargeList(Utf8)') as column4, + arrow_cast(column5, 'LargeList(Utf8)') as column5, + arrow_cast(column6, 'LargeList(Utf8)') as column6 +FROM array_intersect_table_1D_UTF8 +; + statement ok CREATE TABLE array_intersect_table_2D AS VALUES @@ -259,6 +311,17 @@ AS VALUES (make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) ; +statement ok +CREATE TABLE large_array_intersect_table_2D +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') as column1, + arrow_cast(column2, 'LargeList(List(Int64))') as column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(Int64))') as column4 +FROM array_intersect_table_2D +; + statement ok CREATE TABLE array_intersect_table_2D_float AS VALUES @@ -266,6 +329,15 @@ AS VALUES (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) ; +statement ok +CREATE TABLE large_array_intersect_table_2D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(List(Float64))') as column1, + arrow_cast(column2, 'LargeList(List(Float64))') as column2 +FROM array_intersect_table_2D_Float +; + statement ok CREATE TABLE array_intersect_table_3D AS VALUES @@ -273,6 +345,15 @@ AS VALUES (make_array([[1,2]]), make_array([[1,2]])) ; +statement ok +CREATE TABLE large_array_intersect_table_3D +AS + SELECT + arrow_cast(column1, 'LargeList(List(List(Int64)))') as column1, + arrow_cast(column2, 'LargeList(List(List(Int64)))') as column2 +FROM array_intersect_table_3D +; + statement ok CREATE TABLE arrays_values_without_nulls AS VALUES @@ -2589,24 +2670,44 @@ select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ---- [1, 2, 3, 4, 5, 6] +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 3, 4], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5, 6] + # array_union scalar function #2 query ? select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ---- [1, 2, 3, 4, 5, 6, 7, 8] +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 7, 8], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5, 6, 7, 8] + # array_union scalar function #3 query ? select array_union([1,2,3], []); ---- [1, 2, 3] +query ? +select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Null)')); +---- +[1, 2, 3] + # array_union scalar function #4 query ? select array_union([1, 2, 3, 4], [5, 4]); ---- [1, 2, 3, 4, 5] +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 4], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5] + # array_union scalar function #5 statement ok CREATE TABLE arrays_with_repeating_elements_for_union @@ -2623,6 +2724,13 @@ select array_union(column1, column2) from arrays_with_repeating_elements_for_uni [2, 3] [3, 4] +query ? +select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + statement ok drop table arrays_with_repeating_elements_for_union; @@ -2632,24 +2740,44 @@ select array_union([], []); ---- [] +query ? +select array_union(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + # array_union scalar function #7 query ? select array_union([[null]], []); ---- [[]] +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([], 'LargeList(Null)')); +---- +[[]] + # array_union scalar function #8 query ? select array_union([null], [null]); ---- [] +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([[null]], 'LargeList(List(Null))')); +---- +[[]] + # array_union scalar function #9 query ? select array_union(null, []); ---- [] +query ? +select array_union(null, arrow_cast([], 'LargeList(Null)')); +---- +[] + # array_union scalar function #10 query ? select array_union(null, null); @@ -2658,21 +2786,47 @@ NULL # array_union scalar function #11 query ? -select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +select array_union([1, 1, 2, 2, 3, 3], null); ---- -[1.2, 3.0, 5.7] +[1, 2, 3] -# array_union scalar function #12 query ? -select array_union(['hello'], ['hello','datafusion']); +select array_union(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[hello, datafusion] +[1, 2, 3] +# array_union scalar function #12 +query ? +select array_union(null, [1, 1, 2, 2, 3, 3]); +---- +[1, 2, 3] +query ? +select array_union(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); +---- +[1, 2, 3] +# array_union scalar function #13 +query ? +select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +---- +[1.2, 3.0, 5.7] +query ? +select array_union(arrow_cast([1.2, 3.0], 'LargeList(Float64)'), arrow_cast([1.2, 3.0, 5.7], 'LargeList(Float64)')); +---- +[1.2, 3.0, 5.7] +# array_union scalar function #14 +query ? +select array_union(['hello'], ['hello','datafusion']); +---- +[hello, datafusion] +query ? +select array_union(arrow_cast(['hello'], 'LargeList(Utf8)'), arrow_cast(['hello','datafusion'], 'LargeList(Utf8)')); +---- +[hello, datafusion] # list_to_string scalar function #4 (function alias `array_to_string`) @@ -3536,6 +3690,15 @@ from array_intersect_table_1D; [1] [1, 3] [1, 3] [11] [11, 33] [11, 33] +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), @@ -3554,6 +3717,15 @@ from array_intersect_table_1D_Boolean; [] [false, true] [false] [false] [true] [true] +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), @@ -3563,6 +3735,15 @@ from array_intersect_table_1D_UTF8; [bc] [arrow, rust] [] [] [arrow, datafusion, rust] [arrow, rust] +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + query ?? select array_intersect(column1, column2), array_intersect(column3, column4) @@ -3571,6 +3752,15 @@ from array_intersect_table_2D; [] [[4, 5], [6, 7]] [[3, 4]] [[5, 6, 7], [8, 9, 10]] +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from large_array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + + query ? select array_intersect(column1, column2) from array_intersect_table_2D_float; @@ -3578,6 +3768,13 @@ from array_intersect_table_2D_float; [[1.1, 2.2], [3.3]] [[1.1, 2.2], [3.3]] +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + query ? select array_intersect(column1, column2) from array_intersect_table_3D; @@ -3585,6 +3782,13 @@ from array_intersect_table_3D; [] [[[1, 2]]] +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_3D; +---- +[] +[[[1, 2]]] + query ?????? SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), array_intersect(make_array(1,3,5), make_array(2,4,6)), @@ -3596,21 +3800,67 @@ SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ?????? +SELECT array_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + array_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + array_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + array_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + query ? select array_intersect([], []); ---- [] +query ? +select array_intersect(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + +query ? +select array_intersect([1, 1, 2, 2, 3, 3], null); +---- +[] + +query ? +select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); +---- +[] + +query ? +select array_intersect(null, [1, 1, 2, 2, 3, 3]); +---- +NULL + +query ? +select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); +---- +NULL + query ? select array_intersect([], null); ---- [] query ? -select array_intersect(null, []); +select array_intersect(arrow_cast([], 'LargeList(Null)'), null); ---- [] +query ? +select array_intersect(null, []); +---- +NULL + +query ? +select array_intersect(null, arrow_cast([], 'LargeList(Null)')); +---- +NULL + query ? select array_intersect(null, null); ---- @@ -3627,6 +3877,17 @@ SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ?????? +SELECT list_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + list_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + list_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + list_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + query BBBB select list_has_all(make_array(1,2,3), make_array(4,5,6)), list_has_all(make_array(1,2,3), make_array(1,2)), @@ -4106,24 +4367,45 @@ drop table array_has_table_3D; statement ok drop table array_intersect_table_1D; +statement ok +drop table large_array_intersect_table_1D; + statement ok drop table array_intersect_table_1D_Float; +statement ok +drop table large_array_intersect_table_1D_Float; + statement ok drop table array_intersect_table_1D_Boolean; +statement ok +drop table large_array_intersect_table_1D_Boolean; + statement ok drop table array_intersect_table_1D_UTF8; +statement ok +drop table large_array_intersect_table_1D_UTF8; + statement ok drop table array_intersect_table_2D; +statement ok +drop table large_array_intersect_table_2D; + statement ok drop table array_intersect_table_2D_float; +statement ok +drop table large_array_intersect_table_2D_float; + statement ok drop table array_intersect_table_3D; +statement ok +drop table large_array_intersect_table_3D; + statement ok drop table arrays_values_without_nulls; From bb99d2a97df3c654ee8c1d5520ffd15ef5612193 Mon Sep 17 00:00:00 2001 From: Chih Wang Date: Wed, 27 Dec 2023 22:56:14 +0800 Subject: [PATCH 35/63] Avoid extra clone in datafusion-proto::physical_plan (#8650) --- datafusion/proto/src/physical_plan/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index df01097cfa78..24ede3fcaf62 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -486,7 +486,7 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_aggr_expr, physical_filter_expr, input, - Arc::new(input_schema.try_into()?), + physical_schema, )?)) } PhysicalPlanType::HashJoin(hashjoin) => { From 28ca6d1ad9692d0f159ed1f1f45a20c0998a47ea Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 27 Dec 2023 10:08:39 -0500 Subject: [PATCH 36/63] Minor: name some constant values in arrow writer, parquet writer (#8642) * Minor: name some constant values in arrow writer * Add constants to parquet.rs, update doc comments * fix --- .../core/src/datasource/file_format/arrow.rs | 13 ++++++++++--- .../core/src/datasource/file_format/avro.rs | 2 +- .../core/src/datasource/file_format/csv.rs | 2 +- .../core/src/datasource/file_format/json.rs | 2 +- .../src/datasource/file_format/parquet.rs | 19 +++++++++++++++---- 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 7d393d9129dd..650f8c844eda 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Apache Arrow format abstractions +//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions //! //! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) @@ -58,6 +58,13 @@ use super::file_compression_type::FileCompressionType; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// If the buffered Arrow data exceeds this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] pub struct ArrowFormat; @@ -239,7 +246,7 @@ impl DataSink for ArrowFileSink { IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? .try_with_compression(Some(CompressionType::LZ4_FRAME))?; while let Some((path, mut rx)) = file_stream_rx.recv().await { - let shared_buffer = SharedBuffer::new(1048576); + let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( shared_buffer.clone(), &self.get_writer_schema(), @@ -257,7 +264,7 @@ impl DataSink for ArrowFileSink { row_count += batch.num_rows(); arrow_writer.write(&batch)?; let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); - if buff_to_flush.len() > 1024000 { + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index a24a28ad6fdd..6d424bf0b28f 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Apache Avro format abstractions +//! [`AvroFormat`] Apache Avro [`FileFormat`] abstractions use std::any::Any; use std::sync::Arc; diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index df6689af6b73..4033bcd3b557 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! CSV format abstractions +//! [`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions use std::any::Any; use std::collections::HashSet; diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 9893a1db45de..fcb1d5f8e527 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Line delimited JSON format abstractions +//! [`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions use std::any::Any; use std::fmt; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 0c813b6ccbf0..7044acccd6dc 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Parquet format abstractions +//! [`ParquetFormat`]: Parquet [`FileFormat`] abstractions use arrow_array::RecordBatch; use async_trait::async_trait; @@ -75,6 +75,17 @@ use crate::physical_plan::{ Statistics, }; +/// Size of the buffer for [`AsyncArrowWriter`]. +const PARQUET_WRITER_BUFFER_SIZE: usize = 10485760; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// When writing parquet files in parallel, if the buffered Parquet data exceeds +/// this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + /// The Apache Parquet `FileFormat` implementation /// /// Note it is recommended these are instead configured on the [`ConfigOptions`] @@ -680,7 +691,7 @@ impl ParquetSink { let writer = AsyncArrowWriter::try_new( multipart_writer, self.get_writer_schema(), - 10485760, + PARQUET_WRITER_BUFFER_SIZE, Some(parquet_props), )?; Ok(writer) @@ -1004,7 +1015,7 @@ async fn concatenate_parallel_row_groups( writer_props: Arc, mut object_store_writer: AbortableWrite>, ) -> Result { - let merged_buff = SharedBuffer::new(1048576); + let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( @@ -1025,7 +1036,7 @@ async fn concatenate_parallel_row_groups( for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); - if buff_to_flush.len() > 1024000 { + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; From 6403222c1eda8ed3438fe2555229319b92bfa056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Wed, 27 Dec 2023 23:18:27 +0300 Subject: [PATCH 37/63] TreeNode Refactor Part 2 (#8653) * Refactor TreeNode's * Update utils.rs * Final review * Remove unnecessary clones, more idiomatic Rust --------- Co-authored-by: Mehmet Ozan Kabak --- .../enforce_distribution.rs | 767 ++++++++---------- .../src/physical_optimizer/enforce_sorting.rs | 554 +++++++------ .../physical_optimizer/output_requirements.rs | 4 + .../physical_optimizer/pipeline_checker.rs | 32 +- .../replace_with_order_preserving_variants.rs | 292 +++---- .../src/physical_optimizer/sort_pushdown.rs | 138 ++-- .../core/src/physical_optimizer/utils.rs | 69 +- .../physical-expr/src/sort_properties.rs | 10 +- datafusion/physical-plan/src/union.rs | 20 +- 9 files changed, 872 insertions(+), 1014 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 0aef126578f3..d5a086227323 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -25,11 +25,11 @@ use std::fmt; use std::fmt::Formatter; use std::sync::Arc; +use super::output_requirements::OutputRequirementExec; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::utils::{ - add_sort_above, get_children_exectrees, is_coalesce_partitions, is_repartition, - is_sort_preserving_merge, ExecTree, + is_coalesce_partitions, is_repartition, is_sort_preserving_merge, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; @@ -52,8 +52,10 @@ use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ - physical_exprs_equal, EquivalenceProperties, PhysicalExpr, + physical_exprs_equal, EquivalenceProperties, LexRequirementRef, PhysicalExpr, + PhysicalSortRequirement, }; +use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; use datafusion_physical_plan::{get_plan_string, unbounded_output}; @@ -268,11 +270,12 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 5) For other types of operators, by default, pushdown the parent requirements to children. /// fn adjust_input_keys_ordering( - requirements: PlanWithKeyRequirements, + mut requirements: PlanWithKeyRequirements, ) -> Result> { let parent_required = requirements.required_key_ordering.clone(); let plan_any = requirements.plan.as_any(); - let transformed = if let Some(HashJoinExec { + + if let Some(HashJoinExec { left, right, on, @@ -287,7 +290,7 @@ fn adjust_input_keys_ordering( PartitionMode::Partitioned => { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(HashJoinExec::try_new( + HashJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, @@ -295,15 +298,17 @@ fn adjust_input_keys_ordering( join_type, PartitionMode::Partitioned, *null_equals_null, - )?) as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, vec![], &join_constructor, - )?) + ) + .map(Transformed::Yes) } PartitionMode::CollectLeft => { let new_right_request = match join_type { @@ -321,15 +326,15 @@ fn adjust_input_keys_ordering( }; // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![None, new_right_request], - }) + requirements.children[1].required_key_ordering = + new_right_request.unwrap_or(vec![]); + Ok(Transformed::Yes(requirements)) } PartitionMode::Auto => { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } } else if let Some(CrossJoinExec { left, .. }) = @@ -337,14 +342,9 @@ fn adjust_input_keys_ordering( { let left_columns_len = left.schema().fields().len(); // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![ - None, - shift_right_required(&parent_required, left_columns_len), - ], - }) + requirements.children[1].required_key_ordering = + shift_right_required(&parent_required, left_columns_len).unwrap_or_default(); + Ok(Transformed::Yes(requirements)) } else if let Some(SortMergeJoinExec { left, right, @@ -357,35 +357,40 @@ fn adjust_input_keys_ordering( { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(SortMergeJoinExec::try_new( + SortMergeJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, *join_type, new_conditions.1, *null_equals_null, - )?) as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, sort_options.clone(), &join_constructor, - )?) + ) + .map(Transformed::Yes) } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { if !parent_required.is_empty() { match aggregate_exec.mode() { - AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( + AggregateMode::FinalPartitioned => reorder_aggregate_keys( requirements.plan.clone(), &parent_required, aggregate_exec, - )?), - _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), + ) + .map(Transformed::Yes), + _ => Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))), } } else { // Keep everything unchanged - None + Ok(Transformed::No(requirements)) } } else if let Some(proj) = plan_any.downcast_ref::() { let expr = proj.expr(); @@ -394,34 +399,28 @@ fn adjust_input_keys_ordering( // Construct a mapping from new name to the the orginal Column let new_required = map_columns_before_projection(&parent_required, expr); if new_required.len() == parent_required.len() { - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(new_required.clone())], - }) + requirements.children[0].required_key_ordering = new_required; + Ok(Transformed::Yes(requirements)) } else { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } else if plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() { - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } else { // By default, push down the parent requirements to children - let children_len = requirements.plan.children().len(); - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(parent_required.clone()); children_len], - }) - }; - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(requirements) - }) + requirements.children.iter_mut().for_each(|child| { + child.required_key_ordering = parent_required.clone(); + }); + Ok(Transformed::Yes(requirements)) + } } fn reorder_partitioned_join_keys( @@ -452,28 +451,24 @@ where for idx in 0..sort_options.len() { new_sort_options.push(sort_options[new_positions[idx]]) } - - Ok(PlanWithKeyRequirements { - plan: join_constructor((new_join_on, new_sort_options))?, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_constructor(( + new_join_on, + new_sort_options, + ))?); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![ - Some(join_key_pairs.left_keys), - Some(join_key_pairs.right_keys), - ], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = join_key_pairs.left_keys; + requirement_tree.children[1].required_key_ordering = join_key_pairs.right_keys; + Ok(requirement_tree) } } @@ -868,59 +863,24 @@ fn new_join_conditions( .collect() } -/// Updates `dist_onward` such that, to keep track of -/// `input` in the `exec_tree`. -/// -/// # Arguments -/// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until child of `input` (`input` should have single child). -/// * `input_idx`: index of the `input`, for its parent. -/// -fn update_distribution_onward( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) { - // Update the onward tree if there is an active branch - if let Some(exec_tree) = dist_onward { - // When we add a new operator to change distribution - // we add RepartitionExec, SortPreservingMergeExec, CoalescePartitionsExec - // in this case, we need to update exec tree idx such that exec tree is now child of these - // operators (change the 0, since all of the operators have single child). - exec_tree.idx = 0; - *exec_tree = ExecTree::new(input, input_idx, vec![exec_tree.clone()]); - } else { - *dist_onward = Some(ExecTree::new(input, input_idx, vec![])); - } -} - /// Adds RoundRobin repartition operator to the plan increase parallelism. /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [Result] object that contains new execution plan, where desired partition number -/// is achieved by adding RoundRobin Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// partition number is achieved by adding a RoundRobin repartition. fn add_roundrobin_on_top( - input: Arc, + input: DistributionContext, n_target: usize, - dist_onward: &mut Option, - input_idx: usize, -) -> Result> { - // Adding repartition is helpful - if input.output_partitioning().partition_count() < n_target { +) -> Result { + // Adding repartition is helpful: + if input.plan.output_partitioning().partition_count() < n_target { // When there is an existing ordering, we preserve ordering // during repartition. This will be un-done in the future // If any of the following conditions is true @@ -928,13 +888,16 @@ fn add_roundrobin_on_top( // - Usage of order preserving variants is not desirable // (determined by flag `config.optimizer.prefer_existing_sort`) let partitioning = Partitioning::RoundRobinBatch(n_target); - let repartition = - RepartitionExec::try_new(input, partitioning)?.with_preserve_order(); + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? + .with_preserve_order(); - // update distribution onward with new operator - let new_plan = Arc::new(repartition) as Arc; - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) + let new_plan = Arc::new(repartition) as _; + + Ok(DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + }) } else { // Partition is not helpful, we already have desired number of partitions. Ok(input) @@ -948,46 +911,38 @@ fn add_roundrobin_on_top( /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `hash_exprs`: Stores Physical Exprs that are used during hashing. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [`Result`] object that contains new execution plan, where desired distribution is -/// satisfied by adding Hash Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// distribution is satisfied by adding a Hash repartition. fn add_hash_on_top( - input: Arc, + mut input: DistributionContext, hash_exprs: Vec>, - // Repartition(Hash) will have `n_target` partitions at the output. n_target: usize, - // Stores executors starting from Repartition(RoundRobin) until - // current executor. When Repartition(Hash) is added, `dist_onward` - // is updated such that it stores connection from Repartition(RoundRobin) - // until Repartition(Hash). - dist_onward: &mut Option, - input_idx: usize, repartition_beneficial_stats: bool, -) -> Result> { - if n_target == input.output_partitioning().partition_count() && n_target == 1 { - // In this case adding a hash repartition is unnecessary as the hash - // requirement is implicitly satisfied. +) -> Result { + let partition_count = input.plan.output_partitioning().partition_count(); + // Early return if hash repartition is unnecessary + if n_target == partition_count && n_target == 1 { return Ok(input); } + let satisfied = input + .plan .output_partitioning() .satisfy(Distribution::HashPartitioned(hash_exprs.clone()), || { - input.equivalence_properties() + input.plan.equivalence_properties() }); + // Add hash repartitioning when: // - The hash distribution requirement is not satisfied, or // - We can increase parallelism by adding hash partitioning. - if !satisfied || n_target > input.output_partitioning().partition_count() { + if !satisfied || n_target > input.plan.output_partitioning().partition_count() { // When there is an existing ordering, we preserve ordering during // repartition. This will be rolled back in the future if any of the // following conditions is true: @@ -995,75 +950,66 @@ fn add_hash_on_top( // requirements. // - Usage of order preserving variants is not desirable (per the flag // `config.optimizer.prefer_existing_sort`). - let mut new_plan = if repartition_beneficial_stats { + if repartition_beneficial_stats { // Since hashing benefits from partitioning, add a round-robin repartition // before it: - add_roundrobin_on_top(input, n_target, dist_onward, 0)? - } else { - input - }; + input = add_roundrobin_on_top(input, n_target)?; + } + let partitioning = Partitioning::Hash(hash_exprs, n_target); - let repartition = RepartitionExec::try_new(new_plan, partitioning)? - // preserve any ordering if possible + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? .with_preserve_order(); - new_plan = Arc::new(repartition) as _; - // update distribution onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) - } else { - Ok(input) + input.children_nodes = vec![input.clone()]; + input.distribution_connection = true; + input.plan = Arc::new(repartition) as _; } + + Ok(input) } -/// Adds a `SortPreservingMergeExec` operator on top of input executor: -/// - to satisfy single distribution requirement. +/// Adds a [`SortPreservingMergeExec`] operator on top of input executor +/// to satisfy single distribution requirement. /// /// # Arguments /// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. +/// * `input`: Current node. /// /// # Returns /// -/// New execution plan, where desired single -/// distribution is satisfied by adding `SortPreservingMergeExec`. -fn add_spm_on_top( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) -> Arc { +/// Updated node with an execution plan, where desired single +/// distribution is satisfied by adding [`SortPreservingMergeExec`]. +fn add_spm_on_top(input: DistributionContext) -> DistributionContext { // Add SortPreservingMerge only when partition count is larger than 1. - if input.output_partitioning().partition_count() > 1 { + if input.plan.output_partitioning().partition_count() > 1 { // When there is an existing ordering, we preserve ordering - // during decreasıng partıtıons. This will be un-done in the future - // If any of the following conditions is true + // when decreasing partitions. This will be un-done in the future + // if any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.prefer_existing_sort`) - let should_preserve_ordering = input.output_ordering().is_some(); - let new_plan: Arc = if should_preserve_ordering { - let existing_ordering = input.output_ordering().unwrap_or(&[]); + // (determined by flag `config.optimizer.bounded_order_preserving_variants`) + let should_preserve_ordering = input.plan.output_ordering().is_some(); + + let new_plan = if should_preserve_ordering { Arc::new(SortPreservingMergeExec::new( - existing_ordering.to_vec(), - input, + input.plan.output_ordering().unwrap_or(&[]).to_vec(), + input.plan.clone(), )) as _ } else { - Arc::new(CoalescePartitionsExec::new(input)) as _ + Arc::new(CoalescePartitionsExec::new(input.plan.clone())) as _ }; - // update repartition onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - new_plan + DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + } } else { input } } -/// Updates the physical plan inside `distribution_context` so that distribution +/// Updates the physical plan inside [`DistributionContext`] so that distribution /// changing operators are removed from the top. If they are necessary, they will /// be added in subsequent stages. /// @@ -1081,48 +1027,23 @@ fn add_spm_on_top( /// "ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` fn remove_dist_changing_operators( - distribution_context: DistributionContext, + mut distribution_context: DistributionContext, ) -> Result { - let DistributionContext { - mut plan, - mut distribution_onwards, - } = distribution_context; - - // Remove any distribution changing operators at the beginning: - // Note that they will be re-inserted later on if necessary or helpful. - while is_repartition(&plan) - || is_coalesce_partitions(&plan) - || is_sort_preserving_merge(&plan) + while is_repartition(&distribution_context.plan) + || is_coalesce_partitions(&distribution_context.plan) + || is_sort_preserving_merge(&distribution_context.plan) { - // All of above operators have a single child. When we remove the top - // operator, we take the first child. - plan = plan.children().swap_remove(0); - distribution_onwards = - get_children_exectrees(plan.children().len(), &distribution_onwards[0]); + // All of above operators have a single child. First child is only child. + let child = distribution_context.children_nodes.swap_remove(0); + // Remove any distribution changing operators at the beginning: + // Note that they will be re-inserted later on if necessary or helpful. + distribution_context = child; } - // Create a plan with the updated children: - Ok(DistributionContext { - plan, - distribution_onwards, - }) + Ok(distribution_context) } -/// Updates the physical plan `input` by using `dist_onward` replace order preserving operator variants -/// with their corresponding operators that do not preserve order. It is a wrapper for `replace_order_preserving_variants_helper` -fn replace_order_preserving_variants( - input: &mut Arc, - dist_onward: &mut Option, -) -> Result<()> { - if let Some(dist_onward) = dist_onward { - *input = replace_order_preserving_variants_helper(dist_onward)?; - } - *dist_onward = None; - Ok(()) -} - -/// Updates the physical plan inside `ExecTree` if preserving ordering while changing partitioning -/// is not helpful or desirable. +/// Updates the [`DistributionContext`] if preserving ordering while changing partitioning is not helpful or desirable. /// /// Assume that following plan is given: /// ```text @@ -1132,7 +1053,7 @@ fn replace_order_preserving_variants( /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` /// -/// This function converts plan above (inside `ExecTree`) to the following: +/// This function converts plan above to the following: /// /// ```text /// "CoalescePartitionsExec" @@ -1140,30 +1061,75 @@ fn replace_order_preserving_variants( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` -fn replace_order_preserving_variants_helper( - exec_tree: &ExecTree, -) -> Result> { - let mut updated_children = exec_tree.plan.children(); - for child in &exec_tree.children { - updated_children[child.idx] = replace_order_preserving_variants_helper(child)?; - } - if is_sort_preserving_merge(&exec_tree.plan) { - return Ok(Arc::new(CoalescePartitionsExec::new( - updated_children.swap_remove(0), - ))); - } - if let Some(repartition) = exec_tree.plan.as_any().downcast_ref::() { +fn replace_order_preserving_variants( + mut context: DistributionContext, +) -> Result { + let mut updated_children = context + .children_nodes + .iter() + .map(|child| { + if child.distribution_connection { + replace_order_preserving_variants(child.clone()) + } else { + Ok(child.clone()) + } + }) + .collect::>>()?; + + if is_sort_preserving_merge(&context.plan) { + let child = updated_children.swap_remove(0); + context.plan = Arc::new(CoalescePartitionsExec::new(child.plan.clone())); + context.children_nodes = vec![child]; + return Ok(context); + } else if let Some(repartition) = + context.plan.as_any().downcast_ref::() + { if repartition.preserve_order() { - return Ok(Arc::new( - // new RepartitionExec don't preserve order - RepartitionExec::try_new( - updated_children.swap_remove(0), - repartition.partitioning().clone(), - )?, - )); + let child = updated_children.swap_remove(0); + context.plan = Arc::new(RepartitionExec::try_new( + child.plan.clone(), + repartition.partitioning().clone(), + )?); + context.children_nodes = vec![child]; + return Ok(context); + } + } + + context.plan = context + .plan + .clone() + .with_new_children(updated_children.into_iter().map(|c| c.plan).collect())?; + Ok(context) +} + +/// This utility function adds a [`SortExec`] above an operator according to the +/// given ordering requirements while preserving the original partitioning. +fn add_sort_preserving_partitions( + node: DistributionContext, + sort_requirement: LexRequirementRef, + fetch: Option, +) -> DistributionContext { + // If the ordering requirement is already satisfied, do not add a sort. + if !node + .plan + .equivalence_properties() + .ordering_satisfy_requirement(sort_requirement) + { + let sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirement.to_vec()); + let new_sort = SortExec::new(sort_expr, node.plan.clone()).with_fetch(fetch); + + DistributionContext { + plan: Arc::new(if node.plan.output_partitioning().partition_count() > 1 { + new_sort.with_preserve_partitioning(true) + } else { + new_sort + }), + distribution_connection: false, + children_nodes: vec![node], } + } else { + node } - exec_tree.plan.clone().with_new_children(updated_children) } /// This function checks whether we need to add additional data exchange @@ -1174,6 +1140,12 @@ fn ensure_distribution( dist_context: DistributionContext, config: &ConfigOptions, ) -> Result> { + let dist_context = dist_context.update_children()?; + + if dist_context.plan.children().is_empty() { + return Ok(Transformed::No(dist_context)); + } + let target_partitions = config.execution.target_partitions; // When `false`, round robin repartition will not be added to increase parallelism let enable_round_robin = config.optimizer.enable_round_robin_repartition; @@ -1186,14 +1158,11 @@ fn ensure_distribution( let order_preserving_variants_desirable = is_unbounded || config.optimizer.prefer_existing_sort; - if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); - } - // Remove unnecessary repartition from the physical plan if any let DistributionContext { mut plan, - mut distribution_onwards, + distribution_connection, + children_nodes, } = remove_dist_changing_operators(dist_context)?; if let Some(exec) = plan.as_any().downcast_ref::() { @@ -1213,33 +1182,23 @@ fn ensure_distribution( plan = updated_window; } }; - let n_children = plan.children().len(); + // This loop iterates over all the children to: // - Increase parallelism for every child if it is beneficial. // - Satisfy the distribution requirements of every child, if it is not // already satisfied. // We store the updated children in `new_children`. - let new_children = izip!( - plan.children().into_iter(), + let children_nodes = izip!( + children_nodes.into_iter(), plan.required_input_distribution().iter(), plan.required_input_ordering().iter(), - distribution_onwards.iter_mut(), plan.benefits_from_input_partitioning(), - plan.maintains_input_order(), - 0..n_children + plan.maintains_input_order() ) .map( - |( - mut child, - requirement, - required_input_ordering, - dist_onward, - would_benefit, - maintains, - child_idx, - )| { + |(mut child, requirement, required_input_ordering, would_benefit, maintains)| { // Don't need to apply when the returned row count is not greater than 1: - let num_rows = child.statistics()?.num_rows; + let num_rows = child.plan.statistics()?.num_rows; let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { num_rows .get_value() @@ -1248,45 +1207,39 @@ fn ensure_distribution( } else { true }; + if enable_round_robin // Operator benefits from partitioning (e.g. filter): && (would_benefit && repartition_beneficial_stats) // Unless partitioning doesn't increase the partition count, it is not beneficial: - && child.output_partitioning().partition_count() < target_partitions + && child.plan.output_partitioning().partition_count() < target_partitions { // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. if repartition_file_scans { if let Some(new_child) = - child.repartitioned(target_partitions, config)? + child.plan.repartitioned(target_partitions, config)? { - child = new_child; + child.plan = new_child; } } // Increase parallelism by adding round-robin repartitioning // on top of the operator. Note that we only do this if the // partition count is not already equal to the desired partition // count. - child = add_roundrobin_on_top( - child, - target_partitions, - dist_onward, - child_idx, - )?; + child = add_roundrobin_on_top(child, target_partitions)?; } // Satisfy the distribution requirement if it is unmet. match requirement { Distribution::SinglePartition => { - child = add_spm_on_top(child, dist_onward, child_idx); + child = add_spm_on_top(child); } Distribution::HashPartitioned(exprs) => { child = add_hash_on_top( child, exprs.to_vec(), target_partitions, - dist_onward, - child_idx, repartition_beneficial_stats, )?; } @@ -1299,31 +1252,38 @@ fn ensure_distribution( // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or // - using order preserving variant is not desirable. let ordering_satisfied = child + .plan .equivalence_properties() .ordering_satisfy_requirement(required_input_ordering); - if !ordering_satisfied || !order_preserving_variants_desirable { - replace_order_preserving_variants(&mut child, dist_onward)?; + if (!ordering_satisfied || !order_preserving_variants_desirable) + && child.distribution_connection + { + child = replace_order_preserving_variants(child)?; // If ordering requirements were satisfied before repartitioning, // make sure ordering requirements are still satisfied after. if ordering_satisfied { // Make sure to satisfy ordering requirement: - add_sort_above(&mut child, required_input_ordering, None); + child = add_sort_preserving_partitions( + child, + required_input_ordering, + None, + ); } } // Stop tracking distribution changing operators - *dist_onward = None; + child.distribution_connection = false; } else { // no ordering requirement match requirement { // Operator requires specific distribution. Distribution::SinglePartition | Distribution::HashPartitioned(_) => { // Since there is no ordering requirement, preserving ordering is pointless - replace_order_preserving_variants(&mut child, dist_onward)?; + child = replace_order_preserving_variants(child)?; } Distribution::UnspecifiedDistribution => { // Since ordering is lost, trying to preserve ordering is pointless - if !maintains { - replace_order_preserving_variants(&mut child, dist_onward)?; + if !maintains || plan.as_any().is::() { + child = replace_order_preserving_variants(child)?; } } } @@ -1334,7 +1294,9 @@ fn ensure_distribution( .collect::>>()?; let new_distribution_context = DistributionContext { - plan: if plan.as_any().is::() && can_interleave(&new_children) { + plan: if plan.as_any().is::() + && can_interleave(children_nodes.iter().map(|c| c.plan.clone())) + { // Add a special case for [`UnionExec`] since we want to "bubble up" // hash-partitioned data. So instead of // @@ -1358,120 +1320,91 @@ fn ensure_distribution( // - Agg: // Repartition (hash): // Data - Arc::new(InterleaveExec::try_new(new_children)?) + Arc::new(InterleaveExec::try_new( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )?) } else { - plan.with_new_children(new_children)? + plan.with_new_children( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? }, - distribution_onwards, + distribution_connection, + children_nodes, }; + Ok(Transformed::Yes(new_distribution_context)) } -/// A struct to keep track of distribution changing executors +/// A struct to keep track of distribution changing operators /// (`RepartitionExec`, `SortPreservingMergeExec`, `CoalescePartitionsExec`), /// and their associated parents inside `plan`. Using this information, /// we can optimize distribution of the plan if/when necessary. #[derive(Debug, Clone)] struct DistributionContext { plan: Arc, - /// Keep track of associations for each child of the plan. If `None`, - /// there is no distribution changing operator in its descendants. - distribution_onwards: Vec>, + /// Indicates whether this plan is connected to a distribution-changing + /// operator. + distribution_connection: bool, + children_nodes: Vec, } impl DistributionContext { - /// Creates an empty context. + /// Creates a tree according to the plan with empty states. fn new(plan: Arc) -> Self { - let length = plan.children().len(); - DistributionContext { + let children = plan.children(); + Self { plan, - distribution_onwards: vec![None; length], + distribution_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - /// Constructs a new context from children contexts. - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let distribution_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, context)| { - let DistributionContext { - plan, - // The `distribution_onwards` tree keeps track of operators - // that change distribution, or preserves the existing - // distribution (starting from an operator that change distribution). - distribution_onwards, - } = context; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if distribution_onwards[0].is_none() { - if let Some(repartition) = - plan.as_any().downcast_ref::() - { - match repartition.partitioning() { - Partitioning::RoundRobinBatch(_) - | Partitioning::Hash(_, _) => { - // Start tracking operators starting from this repartition (either roundrobin or hash): - return Some(ExecTree::new(plan, idx, vec![])); - } - _ => {} - } - } else if plan.as_any().is::() - || plan.as_any().is::() - { - // Start tracking operators starting from this sort preserving merge: - return Some(ExecTree::new(plan, idx, vec![])); - } - None - } else { - // Propagate children distribution tracking to the above - let new_distribution_onwards = izip!( - plan.required_input_distribution().iter(), - distribution_onwards.into_iter() - ) - .flat_map(|(required_dist, distribution_onwards)| { - if let Some(distribution_onwards) = distribution_onwards { - // Operator can safely propagate the distribution above. - // This is similar to maintaining order in the EnforceSorting rule. - if let Distribution::UnspecifiedDistribution = required_dist { - return Some(distribution_onwards); - } - } - None - }) - .collect::>(); - // Either: - // - None of the children has a connection to an operator that modifies distribution, or - // - The current operator requires distribution at its input so doesn't propagate it above. - if new_distribution_onwards.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, new_distribution_onwards)) - } + fn update_children(mut self) -> Result { + for child_context in self.children_nodes.iter_mut() { + child_context.distribution_connection = match child_context.plan.as_any() { + plan_any if plan_any.is::() => matches!( + plan_any + .downcast_ref::() + .unwrap() + .partitioning(), + Partitioning::RoundRobinBatch(_) | Partitioning::Hash(_, _) + ), + plan_any + if plan_any.is::() + || plan_any.is::() => + { + true } - }) - .collect(); - Ok(DistributionContext { - plan: with_new_children_if_necessary(parent_plan, children_plans)?.into(), - distribution_onwards, - }) - } + _ => { + child_context.plan.children().is_empty() + || child_context.children_nodes[0].distribution_connection + || child_context + .plan + .required_input_distribution() + .iter() + .zip(child_context.children_nodes.iter()) + .any(|(required_dist, child_context)| { + child_context.distribution_connection + && matches!( + required_dist, + Distribution::UnspecifiedDistribution + ) + }) + } + }; + } - /// Computes distribution tracking contexts for every child of the plan. - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(DistributionContext::new) - .collect() + let children_plans = self + .children_nodes + .iter() + .map(|context| context.plan.clone()) + .collect::>(); + + Ok(Self { + plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), + distribution_connection: false, + children_nodes: self.children_nodes, + }) } } @@ -1480,8 +1413,8 @@ impl TreeNode for DistributionContext { where F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -1490,20 +1423,23 @@ impl TreeNode for DistributionContext { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - DistributionContext::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -1512,11 +1448,11 @@ impl fmt::Display for DistributionContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let plan_string = get_plan_string(&self.plan); write!(f, "plan: {:?}", plan_string)?; - for (idx, child) in self.distribution_onwards.iter().enumerate() { - if let Some(child) = child { - write!(f, "idx:{:?}, exec_tree:{}", idx, child)?; - } - } + write!( + f, + "distribution_connection:{}", + self.distribution_connection, + )?; write!(f, "") } } @@ -1532,37 +1468,18 @@ struct PlanWithKeyRequirements { plan: Arc, /// Parent required key ordering required_key_ordering: Vec>, - /// The request key ordering to children - request_key_ordering: Vec>>>, + children: Vec, } impl PlanWithKeyRequirements { fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithKeyRequirements { + let children = plan.children(); + Self { plan, required_key_ordering: vec![], - request_key_ordering: vec![None; children_len], + children: children.into_iter().map(Self::new).collect(), } } - - fn children(&self) -> Vec { - let plan_children = self.plan.children(); - assert_eq!(plan_children.len(), self.request_key_ordering.len()); - plan_children - .into_iter() - .zip(self.request_key_ordering.clone()) - .map(|(child, required)| { - let from_parent = required.unwrap_or_default(); - let length = child.children().len(); - PlanWithKeyRequirements { - plan: child, - required_key_ordering: from_parent, - request_key_ordering: vec![None; length], - } - }) - .collect() - } } impl TreeNode for PlanWithKeyRequirements { @@ -1570,9 +1487,8 @@ impl TreeNode for PlanWithKeyRequirements { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -1582,28 +1498,23 @@ impl TreeNode for PlanWithKeyRequirements { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - - let children_plans = new_children? + if !self.children.is_empty() { + self.children = self + .children .into_iter() - .map(|child| child.plan) - .collect::>(); - let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithKeyRequirements { - plan: new_plan.into(), - required_key_ordering: self.required_key_ordering, - request_key_ordering: self.request_key_ordering, - }) - } else { - Ok(self) + .map(transform) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 2ecc1e11b985..77d04a61c59e 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -44,7 +44,7 @@ use crate::physical_optimizer::replace_with_order_preserving_variants::{ use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; use crate::physical_optimizer::utils::{ add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, - is_sort_preserving_merge, is_union, is_window, ExecTree, + is_sort_preserving_merge, is_union, is_window, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -81,78 +81,66 @@ impl EnforceSorting { #[derive(Debug, Clone)] struct PlanWithCorrespondingSort { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `SortExec`(s) -- could be multiple for n-ary plans like - // Union -- that determine the output ordering of the child. If the child - // has no connection to any sort, simply store None (and not a subtree). - sort_onwards: Vec>, + // For every child, track `ExecutionPlan`s starting from the child until + // the `SortExec`(s). If the child has no connection to any sort, it simply + // stores false. + sort_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingSort { fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingSort { + let children = plan.children(); + Self { plan, - sort_onwards: vec![None; length], + sort_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, + fn update_children( parent_plan: Arc, + mut children_nodes: Vec, ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect::>(); - let sort_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - let plan = &item.plan; - // Leaves of `sort_onwards` are `SortExec` operators, which impose - // an ordering. This tree collects all the intermediate executors - // that maintain this ordering. If we just saw a order imposing - // operator, we reset the tree and start accumulating. - if is_sort(plan) { - return Some(ExecTree::new(item.plan, idx, vec![])); - } else if is_limit(plan) { - // There is no sort linkage for this path, it starts at a limit. - return None; - } + for node in children_nodes.iter_mut() { + let plan = &node.plan; + // Leaves of `sort_onwards` are `SortExec` operators, which impose + // an ordering. This tree collects all the intermediate executors + // that maintain this ordering. If we just saw a order imposing + // operator, we reset the tree and start accumulating. + node.sort_connection = if is_sort(plan) { + // Initiate connection + true + } else if is_limit(plan) { + // There is no sort linkage for this path, it starts at a limit. + false + } else { let is_spm = is_sort_preserving_merge(plan); let required_orderings = plan.required_input_ordering(); let flags = plan.maintains_input_order(); - let children = izip!(flags, item.sort_onwards, required_orderings) - .filter_map(|(maintains, element, required_ordering)| { - if (required_ordering.is_none() && maintains) || is_spm { - element - } else { - None - } - }) - .collect::>(); - if !children.is_empty() { - // Add parent node to the tree if there is at least one - // child with a subtree: - Some(ExecTree::new(item.plan, idx, children)) - } else { - // There is no sort linkage for this child, do nothing. - None - } - }) - .collect(); + // Add parent node to the tree if there is at least one + // child with a sort connection: + izip!(flags, required_orderings).any(|(maintains, required_ordering)| { + let propagates_ordering = + (maintains && required_ordering.is_none()) || is_spm; + let connected_to_sort = + node.children_nodes.iter().any(|item| item.sort_connection); + propagates_ordering && connected_to_sort + }) + } + } + let children_plans = children_nodes + .iter() + .map(|item| item.plan.clone()) + .collect::>(); let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingSort { plan, sort_onwards }) - } - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(PlanWithCorrespondingSort::new) - .collect() + Ok(Self { + plan, + sort_connection: false, + children_nodes, + }) } } @@ -161,9 +149,8 @@ impl TreeNode for PlanWithCorrespondingSort { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -173,102 +160,79 @@ impl TreeNode for PlanWithCorrespondingSort { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } -/// This object is used within the [EnforceSorting] rule to track the closest +/// This object is used within the [`EnforceSorting`] rule to track the closest /// [`CoalescePartitionsExec`] descendant(s) for every child of a plan. #[derive(Debug, Clone)] struct PlanWithCorrespondingCoalescePartitions { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `CoalescePartitionsExec`(s) -- could be multiple for - // n-ary plans like Union -- that affect the output partitioning of the - // child. If the child has no connection to any `CoalescePartitionsExec`, - // simply store None (and not a subtree). - coalesce_onwards: Vec>, + // Stores whether the plan is a `CoalescePartitionsExec` or it is connected to + // a `CoalescePartitionsExec` via its children. + coalesce_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingCoalescePartitions { + /// Creates an empty tree with empty connections. fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingCoalescePartitions { + let children = plan.children(); + Self { plan, - coalesce_onwards: vec![None; length], + coalesce_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes + fn update_children(mut self) -> Result { + self.coalesce_connection = if self.plan.children().is_empty() { + // Plan has no children, it cannot be a `CoalescePartitionsExec`. + false + } else if is_coalesce_partitions(&self.plan) { + // Initiate a connection + true + } else { + self.children_nodes + .iter() + .enumerate() + .map(|(idx, node)| { + // Only consider operators that don't require a + // single partition, and connected to any coalesce + node.coalesce_connection + && !matches!( + self.plan.required_input_distribution()[idx], + Distribution::SinglePartition + ) + // If all children are None. There is nothing to track, set connection false. + }) + .any(|c| c) + }; + + let children_plans = self + .children_nodes .iter() .map(|item| item.plan.clone()) .collect(); - let coalesce_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // Leaves of the `coalesce_onwards` tree are `CoalescePartitionsExec` - // operators. This tree collects all the intermediate executors that - // maintain a single partition. If we just saw a `CoalescePartitionsExec` - // operator, we reset the tree and start accumulating. - let plan = item.plan; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if is_coalesce_partitions(&plan) { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = item - .coalesce_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that don't require a - // single partition. - !matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ) - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingCoalescePartitions { - plan, - coalesce_onwards, - }) - } - - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(PlanWithCorrespondingCoalescePartitions::new) - .collect() + self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + Ok(self) } } @@ -277,9 +241,8 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -289,23 +252,23 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingCoalescePartitions::new_from_children_nodes( - children_nodes, + .collect::>()?; + self.plan = with_new_children_if_necessary( self.plan, - ) + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -332,6 +295,7 @@ impl PhysicalOptimizerRule for EnforceSorting { } else { adjusted.plan }; + let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); let updated_plan = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { @@ -345,7 +309,8 @@ impl PhysicalOptimizerRule for EnforceSorting { // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(updated_plan.plan); + let mut sort_pushdown = SortPushDown::new(updated_plan.plan); + sort_pushdown.assign_initial_requirements(); let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; Ok(adjusted.plan) } @@ -376,16 +341,21 @@ impl PhysicalOptimizerRule for EnforceSorting { fn parallelize_sorts( requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result> { - let plan = requirements.plan; - let mut coalesce_onwards = requirements.coalesce_onwards; - if plan.children().is_empty() || coalesce_onwards[0].is_none() { + let PlanWithCorrespondingCoalescePartitions { + mut plan, + coalesce_connection, + mut children_nodes, + } = requirements.update_children()?; + + if plan.children().is_empty() || !children_nodes[0].coalesce_connection { // We only take an action when the plan is either a SortExec, a // SortPreservingMergeExec or a CoalescePartitionsExec, and they // all have a single child. Therefore, if the first child is `None`, // we can return immediately. return Ok(Transformed::No(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })); } else if (is_sort(&plan) || is_sort_preserving_merge(&plan)) && plan.output_partitioning().partition_count() <= 1 @@ -395,34 +365,30 @@ fn parallelize_sorts( // executors don't require single partition), then we can replace // the CoalescePartitionsExec + Sort cascade with a SortExec + // SortPreservingMergeExec cascade to parallelize sorting. - let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; let (sort_exprs, fetch) = get_sort_exprs(&plan)?; - add_sort_above( - &mut prev_layer, - &PhysicalSortRequirement::from_sort_exprs(sort_exprs), - fetch, - ); - let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer) - .with_fetch(fetch); - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: Arc::new(spm), - coalesce_onwards: vec![None], - })); + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs); + let sort_exprs = sort_exprs.to_vec(); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + add_sort_above(&mut plan, &sort_reqs, fetch); + let spm = SortPreservingMergeExec::new(sort_exprs, plan).with_fetch(fetch); + + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(Arc::new(spm)), + )); } else if is_coalesce_partitions(&plan) { // There is an unnecessary `CoalescePartitionsExec` in the plan. - let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; - let new_plan = plan.with_new_children(vec![prev_layer])?; - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: new_plan, - coalesce_onwards: vec![None], - })); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + + let new_plan = Arc::new(CoalescePartitionsExec::new(plan)) as _; + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(new_plan), + )); } Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })) } @@ -431,91 +397,102 @@ fn parallelize_sorts( fn ensure_sorting( requirements: PlanWithCorrespondingSort, ) -> Result> { + let requirements = PlanWithCorrespondingSort::update_children( + requirements.plan, + requirements.children_nodes, + )?; + // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.plan.children().is_empty() { return Ok(Transformed::No(requirements)); } - let plan = requirements.plan; - let mut children = plan.children(); - let mut sort_onwards = requirements.sort_onwards; - if let Some(result) = analyze_immediate_sort_removal(&plan, &sort_onwards) { + if let Some(result) = analyze_immediate_sort_removal(&requirements) { return Ok(Transformed::Yes(result)); } - for (idx, (child, sort_onwards, required_ordering)) in izip!( - children.iter_mut(), - sort_onwards.iter_mut(), - plan.required_input_ordering() - ) - .enumerate() + + let plan = requirements.plan; + let mut children_nodes = requirements.children_nodes; + + for (idx, (child_node, required_ordering)) in + izip!(children_nodes.iter_mut(), plan.required_input_ordering()).enumerate() { - let physical_ordering = child.output_ordering(); + let mut child_plan = child_node.plan.clone(); + let physical_ordering = child_plan.output_ordering(); match (required_ordering, physical_ordering) { (Some(required_ordering), Some(_)) => { - if !child + if !child_plan .equivalence_properties() .ordering_satisfy_requirement(&required_ordering) { // Make sure we preserve the ordering requirements: - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; - add_sort_above(child, &required_ordering, None); - if is_sort(child) { - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); - } else { - *sort_onwards = None; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; + add_sort_above(&mut child_plan, &required_ordering, None); + if is_sort(&child_plan) { + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } } } (Some(required), None) => { // Ordering requirement is not met, we should add a `SortExec` to the plan. - add_sort_above(child, &required, None); - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); + add_sort_above(&mut child_plan, &required, None); + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } (None, Some(_)) => { // We have a `SortExec` whose effect may be neutralized by // another order-imposing operator. Remove this sort. if !plan.maintains_input_order()[idx] || is_union(&plan) { - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; } } (None, None) => { - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; } } } // For window expressions, we can remove some sorts when we can // calculate the result in reverse: - if is_window(&plan) { - if let Some(tree) = &mut sort_onwards[0] { - if let Some(result) = analyze_window_sort_removal(tree, &plan)? { - return Ok(Transformed::Yes(result)); - } + if is_window(&plan) && children_nodes[0].sort_connection { + if let Some(result) = analyze_window_sort_removal(&mut children_nodes[0], &plan)? + { + return Ok(Transformed::Yes(result)); } } else if is_sort_preserving_merge(&plan) - && children[0].output_partitioning().partition_count() <= 1 + && children_nodes[0] + .plan + .output_partitioning() + .partition_count() + <= 1 { // This SortPreservingMergeExec is unnecessary, input already has a // single partition. - sort_onwards.truncate(1); - return Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: children.swap_remove(0), - sort_onwards, - })); + let child_node = children_nodes.swap_remove(0); + return Ok(Transformed::Yes(child_node)); } - Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: plan.with_new_children(children)?, - sort_onwards, - })) + Ok(Transformed::Yes( + PlanWithCorrespondingSort::update_children(plan, children_nodes)?, + )) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input /// already has a finer ordering than it enforces. fn analyze_immediate_sort_removal( - plan: &Arc, - sort_onwards: &[Option], + node: &PlanWithCorrespondingSort, ) -> Option { + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = node; if let Some(sort_exec) = plan.as_any().downcast_ref::() { let sort_input = sort_exec.input().clone(); - // If this sort is unnecessary, we should remove it: if sort_input .equivalence_properties() @@ -533,20 +510,33 @@ fn analyze_immediate_sort_removal( sort_exec.expr().to_vec(), sort_input, )); - let new_tree = ExecTree::new( - new_plan.clone(), - 0, - sort_onwards.iter().flat_map(|e| e.clone()).collect(), - ); PlanWithCorrespondingSort { plan: new_plan, - sort_onwards: vec![Some(new_tree)], + // SortPreservingMergeExec has single child. + sort_connection: false, + children_nodes: children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } } else { // Remove the sort: PlanWithCorrespondingSort { plan: sort_input, - sort_onwards: sort_onwards.to_vec(), + sort_connection: false, + children_nodes: children_nodes[0] + .children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } }, ); @@ -558,15 +548,15 @@ fn analyze_immediate_sort_removal( /// Analyzes a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine /// whether it may allow removing a sort. fn analyze_window_sort_removal( - sort_tree: &mut ExecTree, + sort_tree: &mut PlanWithCorrespondingSort, window_exec: &Arc, ) -> Result> { let requires_single_partition = matches!( - window_exec.required_input_distribution()[sort_tree.idx], + window_exec.required_input_distribution()[0], Distribution::SinglePartition ); - let mut window_child = - remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + let mut window_child = sort_tree.plan.clone(); let (window_expr, new_window) = if let Some(exec) = window_exec.as_any().downcast_ref::() { ( @@ -628,9 +618,9 @@ fn analyze_window_sort_removal( /// Updates child to remove the unnecessary [`CoalescePartitionsExec`] below it. fn update_child_to_remove_coalesce( child: &mut Arc, - coalesce_onwards: &mut Option, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, ) -> Result<()> { - if let Some(coalesce_onwards) = coalesce_onwards { + if coalesce_onwards.coalesce_connection { *child = remove_corresponding_coalesce_in_sub_plan(coalesce_onwards, child)?; } Ok(()) @@ -638,10 +628,10 @@ fn update_child_to_remove_coalesce( /// Removes the [`CoalescePartitionsExec`] from the plan in `coalesce_onwards`. fn remove_corresponding_coalesce_in_sub_plan( - coalesce_onwards: &mut ExecTree, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, parent: &Arc, ) -> Result> { - Ok(if is_coalesce_partitions(&coalesce_onwards.plan) { + if is_coalesce_partitions(&coalesce_onwards.plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_plan = coalesce_onwards.plan.children()[0].clone(); while new_plan.output_partitioning() == parent.output_partitioning() @@ -650,89 +640,113 @@ fn remove_corresponding_coalesce_in_sub_plan( { new_plan = new_plan.children().swap_remove(0) } - new_plan + Ok(new_plan) } else { let plan = coalesce_onwards.plan.clone(); let mut children = plan.children(); - for item in &mut coalesce_onwards.children { - children[item.idx] = remove_corresponding_coalesce_in_sub_plan(item, &plan)?; + for (idx, node) in coalesce_onwards.children_nodes.iter_mut().enumerate() { + if node.coalesce_connection { + children[idx] = remove_corresponding_coalesce_in_sub_plan(node, &plan)?; + } } - plan.with_new_children(children)? - }) + plan.with_new_children(children) + } } /// Updates child to remove the unnecessary sort below it. fn update_child_to_remove_unnecessary_sort( - child: &mut Arc, - sort_onwards: &mut Option, + child_idx: usize, + sort_onwards: &mut PlanWithCorrespondingSort, parent: &Arc, ) -> Result<()> { - if let Some(sort_onwards) = sort_onwards { + if sort_onwards.sort_connection { let requires_single_partition = matches!( - parent.required_input_distribution()[sort_onwards.idx], + parent.required_input_distribution()[child_idx], Distribution::SinglePartition ); - *child = remove_corresponding_sort_from_sub_plan( - sort_onwards, - requires_single_partition, - )?; + remove_corresponding_sort_from_sub_plan(sort_onwards, requires_single_partition)?; } - *sort_onwards = None; + sort_onwards.sort_connection = false; Ok(()) } /// Removes the sort from the plan in `sort_onwards`. fn remove_corresponding_sort_from_sub_plan( - sort_onwards: &mut ExecTree, + sort_onwards: &mut PlanWithCorrespondingSort, requires_single_partition: bool, -) -> Result> { +) -> Result<()> { // A `SortExec` is always at the bottom of the tree. - let mut updated_plan = if is_sort(&sort_onwards.plan) { - sort_onwards.plan.children().swap_remove(0) + if is_sort(&sort_onwards.plan) { + *sort_onwards = sort_onwards.children_nodes.swap_remove(0); } else { - let plan = &sort_onwards.plan; - let mut children = plan.children(); - for item in &mut sort_onwards.children { - let requires_single_partition = matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ); - children[item.idx] = - remove_corresponding_sort_from_sub_plan(item, requires_single_partition)?; + let PlanWithCorrespondingSort { + plan, + sort_connection: _, + children_nodes, + } = sort_onwards; + let mut any_connection = false; + for (child_idx, child_node) in children_nodes.iter_mut().enumerate() { + if child_node.sort_connection { + any_connection = true; + let requires_single_partition = matches!( + plan.required_input_distribution()[child_idx], + Distribution::SinglePartition + ); + remove_corresponding_sort_from_sub_plan( + child_node, + requires_single_partition, + )?; + } } + if any_connection || children_nodes.is_empty() { + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan.clone(), + children_nodes.clone(), + )?; + } + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = sort_onwards; // Replace with variants that do not preserve order. if is_sort_preserving_merge(plan) { - children.swap_remove(0) + children_nodes.swap_remove(0); + *plan = plan.children().swap_remove(0); } else if let Some(repartition) = plan.as_any().downcast_ref::() { - Arc::new( - // By default, RepartitionExec does not preserve order - RepartitionExec::try_new( - children.swap_remove(0), - repartition.partitioning().clone(), - )?, - ) - } else { - plan.clone().with_new_children(children)? + *plan = Arc::new(RepartitionExec::try_new( + children_nodes[0].plan.clone(), + repartition.output_partitioning(), + )?) as _; } }; // Deleting a merging sort may invalidate distribution requirements. // Ensure that we stay compliant with such requirements: if requires_single_partition - && updated_plan.output_partitioning().partition_count() > 1 + && sort_onwards.plan.output_partitioning().partition_count() > 1 { // If there is existing ordering, to preserve ordering use SortPreservingMergeExec // instead of CoalescePartitionsExec. - if let Some(ordering) = updated_plan.output_ordering() { - updated_plan = Arc::new(SortPreservingMergeExec::new( + if let Some(ordering) = sort_onwards.plan.output_ordering() { + let plan = Arc::new(SortPreservingMergeExec::new( ordering.to_vec(), - updated_plan, - )); + sort_onwards.plan.clone(), + )) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } else { - updated_plan = Arc::new(CoalescePartitionsExec::new(updated_plan)); + let plan = + Arc::new(CoalescePartitionsExec::new(sort_onwards.plan.clone())) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } } - Ok(updated_plan) + Ok(()) } /// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible. diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index f8bf3bb965e8..4d03840d3dd3 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -147,6 +147,10 @@ impl ExecutionPlan for OutputRequirementExec { self.input.output_ordering() } + fn maintains_input_order(&self) -> Vec { + vec![true] + } + fn children(&self) -> Vec> { vec![self.input.clone()] } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index d59248aadf05..9e9f647d073f 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -24,13 +24,13 @@ use std::sync::Arc; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::SymmetricHashJoinExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; +use datafusion_physical_plan::joins::SymmetricHashJoinExec; /// The PipelineChecker rule rejects non-runnable query plans that use /// pipeline-breaking operators on infinite input(s). @@ -70,14 +70,14 @@ impl PhysicalOptimizerRule for PipelineChecker { pub struct PipelineStatePropagator { pub(crate) plan: Arc, pub(crate) unbounded: bool, - pub(crate) children: Vec, + pub(crate) children: Vec, } impl PipelineStatePropagator { /// Constructs a new, default pipelining state. pub fn new(plan: Arc) -> Self { let children = plan.children(); - PipelineStatePropagator { + Self { plan, unbounded: false, children: children.into_iter().map(Self::new).collect(), @@ -86,10 +86,7 @@ impl PipelineStatePropagator { /// Returns the children unboundedness information. pub fn children_unbounded(&self) -> Vec { - self.children - .iter() - .map(|c| c.unbounded) - .collect::>() + self.children.iter().map(|c| c.unbounded).collect() } } @@ -109,26 +106,23 @@ impl TreeNode for PipelineStatePropagator { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { if !self.children.is_empty() { - let new_children = self + self.children = self .children .into_iter() .map(transform) - .collect::>>()?; - let children_plans = new_children.iter().map(|c| c.plan.clone()).collect(); - - Ok(PipelineStatePropagator { - plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), - unbounded: self.unbounded, - children: new_children, - }) - } else { - Ok(self) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 0ff7e9f48edc..91f3d2abc6ff 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -21,14 +21,13 @@ use std::sync::Arc; +use super::utils::is_repartition; use crate::error::Result; -use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort, ExecTree}; +use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use super::utils::is_repartition; - use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_physical_plan::unbounded_output; @@ -40,80 +39,67 @@ use datafusion_physical_plan::unbounded_output; #[derive(Debug, Clone)] pub(crate) struct OrderPreservationContext { pub(crate) plan: Arc, - ordering_onwards: Vec>, + ordering_connection: bool, + children_nodes: Vec, } impl OrderPreservationContext { - /// Creates a "default" order-preservation context. + /// Creates an empty context tree. Each node has `false` connections. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); - OrderPreservationContext { + let children = plan.children(); + Self { plan, - ordering_onwards: vec![None; length], + ordering_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } /// Creates a new order-preservation context from those of children nodes. - pub fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let ordering_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // `ordering_onwards` tree keeps track of executors that maintain - // ordering, (or that can maintain ordering with the replacement of - // its variant) - let plan = item.plan; - let children = plan.children(); - let ordering_onwards = item.ordering_onwards; - if children.is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if ordering_onwards[0].is_none() - && ((is_repartition(&plan) && !plan.maintains_input_order()[0]) - || (is_coalesce_partitions(&plan) - && children[0].output_ordering().is_some())) - { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = ordering_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that maintains ordering - plan.maintains_input_order()[item.idx] - || is_coalesce_partitions(&plan) - || is_repartition(&plan) - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(OrderPreservationContext { - plan, - ordering_onwards, - }) - } + pub fn update_children(mut self) -> Result { + for node in self.children_nodes.iter_mut() { + let plan = node.plan.clone(); + let children = plan.children(); + let maintains_input_order = plan.maintains_input_order(); + let inspect_child = |idx| { + maintains_input_order[idx] + || is_coalesce_partitions(&plan) + || is_repartition(&plan) + }; + + // We cut the path towards nodes that do not maintain ordering. + for (idx, c) in node.children_nodes.iter_mut().enumerate() { + c.ordering_connection &= inspect_child(idx); + } + + node.ordering_connection = if children.is_empty() { + false + } else if !node.children_nodes[0].ordering_connection + && ((is_repartition(&plan) && !maintains_input_order[0]) + || (is_coalesce_partitions(&plan) + && children[0].output_ordering().is_some())) + { + // We either have a RepartitionExec or a CoalescePartitionsExec + // and they lose their input ordering, so initiate connection: + true + } else { + // Maintain connection if there is a child with a connection, + // and operator can possibly maintain that connection (either + // in its current form or when we replace it with the corresponding + // order preserving operator). + node.children_nodes + .iter() + .enumerate() + .any(|(idx, c)| c.ordering_connection && inspect_child(idx)) + } + } - /// Computes order-preservation contexts for every child of the plan. - pub fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(OrderPreservationContext::new) - .collect() + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + self.ordering_connection = false; + Ok(self) } } @@ -122,8 +108,8 @@ impl TreeNode for OrderPreservationContext { where F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -132,68 +118,88 @@ impl TreeNode for OrderPreservationContext { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } -/// Calculates the updated plan by replacing executors that lose ordering -/// inside the `ExecTree` with their order-preserving variants. This will +/// Calculates the updated plan by replacing operators that lose ordering +/// inside `sort_input` with their order-preserving variants. This will /// generate an alternative plan, which will be accepted or rejected later on /// depending on whether it helps us remove a `SortExec`. fn get_updated_plan( - exec_tree: &ExecTree, + mut sort_input: OrderPreservationContext, // Flag indicating that it is desirable to replace `RepartitionExec`s with // `SortPreservingRepartitionExec`s: is_spr_better: bool, // Flag indicating that it is desirable to replace `CoalescePartitionsExec`s // with `SortPreservingMergeExec`s: is_spm_better: bool, -) -> Result> { - let plan = exec_tree.plan.clone(); +) -> Result { + let updated_children = sort_input + .children_nodes + .clone() + .into_iter() + .map(|item| { + // Update children and their descendants in the given tree if the connection is open: + if item.ordering_connection { + get_updated_plan(item, is_spr_better, is_spm_better) + } else { + Ok(item) + } + }) + .collect::>>()?; - let mut children = plan.children(); - // Update children and their descendants in the given tree: - for item in &exec_tree.children { - children[item.idx] = get_updated_plan(item, is_spr_better, is_spm_better)?; - } - // Construct the plan with updated children: - let mut plan = plan.with_new_children(children)?; + sort_input.plan = sort_input + .plan + .with_new_children(updated_children.iter().map(|c| c.plan.clone()).collect())?; + sort_input.ordering_connection = false; + sort_input.children_nodes = updated_children; // When a `RepartitionExec` doesn't preserve ordering, replace it with - // a `SortPreservingRepartitionExec` if appropriate: - if is_repartition(&plan) && !plan.maintains_input_order()[0] && is_spr_better { - let child = plan.children().swap_remove(0); - let repartition = RepartitionExec::try_new(child, plan.output_partitioning())? - .with_preserve_order(); - plan = Arc::new(repartition) as _ - } - // When the input of a `CoalescePartitionsExec` has an ordering, replace it - // with a `SortPreservingMergeExec` if appropriate: - let mut children = plan.children(); - if is_coalesce_partitions(&plan) - && children[0].output_ordering().is_some() - && is_spm_better + // a sort-preserving variant if appropriate: + if is_repartition(&sort_input.plan) + && !sort_input.plan.maintains_input_order()[0] + && is_spr_better { - let child = children.swap_remove(0); - plan = Arc::new(SortPreservingMergeExec::new( - child.output_ordering().unwrap_or(&[]).to_vec(), - child, - )) as _ + let child = sort_input.plan.children().swap_remove(0); + let repartition = + RepartitionExec::try_new(child, sort_input.plan.output_partitioning())? + .with_preserve_order(); + sort_input.plan = Arc::new(repartition) as _; + sort_input.children_nodes[0].ordering_connection = true; + } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { + // When the input of a `CoalescePartitionsExec` has an ordering, replace it + // with a `SortPreservingMergeExec` if appropriate: + if let Some(ordering) = sort_input.children_nodes[0] + .plan + .output_ordering() + .map(|o| o.to_vec()) + { + // Now we can mutate `new_node.children_nodes` safely + let child = sort_input.children_nodes.clone().swap_remove(0); + sort_input.plan = + Arc::new(SortPreservingMergeExec::new(ordering, child.plan)) as _; + sort_input.children_nodes[0].ordering_connection = true; + } } - Ok(plan) + + Ok(sort_input) } /// The `replace_with_order_preserving_variants` optimizer sub-rule tries to @@ -211,11 +217,11 @@ fn get_updated_plan( /// /// The algorithm flow is simply like this: /// 1. Visit nodes of the physical plan bottom-up and look for `SortExec` nodes. -/// 1_1. During the traversal, build an `ExecTree` to keep track of operators -/// that maintain ordering (or can maintain ordering when replaced by an -/// order-preserving variant) until a `SortExec` is found. +/// 1_1. During the traversal, keep track of operators that maintain ordering +/// (or can maintain ordering when replaced by an order-preserving variant) until +/// a `SortExec` is found. /// 2. When a `SortExec` is found, update the child of the `SortExec` by replacing -/// operators that do not preserve ordering in the `ExecTree` with their order +/// operators that do not preserve ordering in the tree with their order /// preserving variants. /// 3. Check if the `SortExec` is still necessary in the updated plan by comparing /// its input ordering with the output ordering it imposes. We do this because @@ -239,37 +245,41 @@ pub(crate) fn replace_with_order_preserving_variants( is_spm_better: bool, config: &ConfigOptions, ) -> Result> { - let plan = &requirements.plan; - let ordering_onwards = &requirements.ordering_onwards; - if is_sort(plan) { - let exec_tree = if let Some(exec_tree) = &ordering_onwards[0] { - exec_tree - } else { - return Ok(Transformed::No(requirements)); - }; - // For unbounded cases, replace with the order-preserving variant in - // any case, as doing so helps fix the pipeline. - // Also do the replacement if opted-in via config options. - let use_order_preserving_variant = - config.optimizer.prefer_existing_sort || unbounded_output(plan); - let updated_sort_input = get_updated_plan( - exec_tree, - is_spr_better || use_order_preserving_variant, - is_spm_better || use_order_preserving_variant, - )?; - // If this sort is unnecessary, we should remove it and update the plan: - if updated_sort_input - .equivalence_properties() - .ordering_satisfy(plan.output_ordering().unwrap_or(&[])) - { - return Ok(Transformed::Yes(OrderPreservationContext { - plan: updated_sort_input, - ordering_onwards: vec![None], - })); - } + let mut requirements = requirements.update_children()?; + if !(is_sort(&requirements.plan) + && requirements.children_nodes[0].ordering_connection) + { + return Ok(Transformed::No(requirements)); } - Ok(Transformed::No(requirements)) + // For unbounded cases, replace with the order-preserving variant in + // any case, as doing so helps fix the pipeline. + // Also do the replacement if opted-in via config options. + let use_order_preserving_variant = + config.optimizer.prefer_existing_sort || unbounded_output(&requirements.plan); + + let mut updated_sort_input = get_updated_plan( + requirements.children_nodes.clone().swap_remove(0), + is_spr_better || use_order_preserving_variant, + is_spm_better || use_order_preserving_variant, + )?; + + // If this sort is unnecessary, we should remove it and update the plan: + if updated_sort_input + .plan + .equivalence_properties() + .ordering_satisfy(requirements.plan.output_ordering().unwrap_or(&[])) + { + for child in updated_sort_input.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(updated_sort_input)) + } else { + for child in requirements.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(requirements)) + } } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b9502d92ac12..b0013863010a 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -36,8 +36,6 @@ use datafusion_physical_expr::{ LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::izip; - /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total /// computational cost by pushing down `SortExec`s through some executors. @@ -49,35 +47,26 @@ pub(crate) struct SortPushDown { pub plan: Arc, /// Parent required sort ordering required_ordering: Option>, - /// The adjusted request sort ordering to children. - /// By default they are the same as the plan's required input ordering, but can be adjusted based on parent required sort ordering properties. - adjusted_request_ordering: Vec>>, + children_nodes: Vec, } impl SortPushDown { - pub fn init(plan: Arc) -> Self { - let request_ordering = plan.required_input_ordering(); - SortPushDown { + /// Creates an empty tree with empty `required_ordering`'s. + pub fn new(plan: Arc) -> Self { + let children = plan.children(); + Self { plan, required_ordering: None, - adjusted_request_ordering: request_ordering, + children_nodes: children.into_iter().map(Self::new).collect(), } } - pub fn children(&self) -> Vec { - izip!( - self.plan.children().into_iter(), - self.adjusted_request_ordering.clone().into_iter(), - ) - .map(|(child, from_parent)| { - let child_request_ordering = child.required_input_ordering(); - SortPushDown { - plan: child, - required_ordering: from_parent, - adjusted_request_ordering: child_request_ordering, - } - }) - .collect() + /// Assigns the ordering requirement of the root node to the its children. + pub fn assign_initial_requirements(&mut self) { + let reqs = self.plan.required_input_ordering(); + for (child, requirement) in self.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = requirement; + } } } @@ -86,9 +75,8 @@ impl TreeNode for SortPushDown { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -97,64 +85,64 @@ impl TreeNode for SortPushDown { Ok(VisitRecursion::Continue) } - fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let children_plans = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .map(|r| r.map(|s| s.plan)) - .collect::>>()?; - - match with_new_children_if_necessary(self.plan, children_plans)? { - Transformed::Yes(plan) | Transformed::No(plan) => { - self.plan = plan; - } - } - }; + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + } Ok(self) } } pub(crate) fn pushdown_sorts( - requirements: SortPushDown, + mut requirements: SortPushDown, ) -> Result> { let plan = &requirements.plan; let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); + if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let new_plan = if !plan + if !plan .equivalence_properties() .ordering_satisfy_requirement(parent_required) { // If the current plan is a SortExec, modify it to satisfy parent requirements: let mut new_plan = sort_exec.input().clone(); add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); - new_plan - } else { - requirements.plan + requirements.plan = new_plan; }; - let required_ordering = new_plan + + let required_ordering = requirements + .plan .output_ordering() .map(PhysicalSortRequirement::from_sort_exprs) .unwrap_or_default(); // Since new_plan is a SortExec, we can safely get the 0th index. - let child = new_plan.children().swap_remove(0); + let mut child = requirements.children_nodes.swap_remove(0); if let Some(adjusted) = - pushdown_requirement_to_children(&child, &required_ordering)? + pushdown_requirement_to_children(&child.plan, &required_ordering)? { + for (c, o) in child.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } // Can push down requirements - Ok(Transformed::Yes(SortPushDown { - plan: child, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + child.required_ordering = None; + Ok(Transformed::Yes(child)) } else { // Can not push down requirements - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut empty_node = SortPushDown::new(requirements.plan); + empty_node.assign_initial_requirements(); + Ok(Transformed::Yes(empty_node)) } } else { // Executors other than SortExec @@ -163,23 +151,27 @@ pub(crate) fn pushdown_sorts( .ordering_satisfy_requirement(parent_required) { // Satisfies parent requirements, immediately return. - return Ok(Transformed::Yes(SortPushDown { - required_ordering: None, - ..requirements - })); + let reqs = requirements.plan.required_input_ordering(); + for (child, order) in requirements.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = order; + } + return Ok(Transformed::Yes(requirements)); } // Can not satisfy the parent requirements, check whether the requirements can be pushed down: if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? { - Ok(Transformed::Yes(SortPushDown { - plan: requirements.plan, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + for (c, o) in requirements.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } + requirements.required_ordering = None; + Ok(Transformed::Yes(requirements)) } else { // Can not push down requirements, add new SortExec: let mut new_plan = requirements.plan; add_sort_above(&mut new_plan, parent_required, None); - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut new_empty = SortPushDown::new(new_plan); + new_empty.assign_initial_requirements(); + // Can not push down requirements + Ok(Transformed::Yes(new_empty)) } } } @@ -297,10 +289,11 @@ fn pushdown_requirement_to_children( // TODO: Add support for Projection push down } -/// Determine the children requirements -/// If the children requirements are more specific, do not push down the parent requirements -/// If the the parent requirements are more specific, push down the parent requirements -/// If they are not compatible, need to add Sort. +/// Determine children requirements: +/// - If children requirements are more specific, do not push down parent +/// requirements. +/// - If parent requirements are more specific, push down parent requirements. +/// - If they are not compatible, need to add a sort. fn determine_children_requirement( parent_required: LexRequirementRef, request_child: LexRequirementRef, @@ -310,18 +303,15 @@ fn determine_children_requirement( .equivalence_properties() .requirements_compatible(request_child, parent_required) { - // request child requirements are more specific, no need to push down the parent requirements + // Child requirements are more specific, no need to push down. RequirementsCompatibility::Satisfy } else if child_plan .equivalence_properties() .requirements_compatible(parent_required, request_child) { - // parent requirements are more specific, adjust the request child requirements and push down the new requirements - let adjusted = if parent_required.is_empty() { - None - } else { - Some(parent_required.to_vec()) - }; + // Parent requirements are more specific, adjust child's requirements + // and push down the new requirements: + let adjusted = (!parent_required.is_empty()).then(|| parent_required.to_vec()); RequirementsCompatibility::Compatible(adjusted) } else { RequirementsCompatibility::NonCompatible diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index fccc1db0d359..f8063e969422 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -17,83 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules -use std::fmt; -use std::fmt::Formatter; use std::sync::Arc; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; -use crate::physical_plan::{get_plan_string, ExecutionPlan}; +use crate::physical_plan::ExecutionPlan; use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement}; - -/// This object implements a tree that we use while keeping track of paths -/// leading to [`SortExec`]s. -#[derive(Debug, Clone)] -pub(crate) struct ExecTree { - /// The `ExecutionPlan` associated with this node - pub plan: Arc, - /// Child index of the plan in its parent - pub idx: usize, - /// Children of the plan that would need updating if we remove leaf executors - pub children: Vec, -} - -impl fmt::Display for ExecTree { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let plan_string = get_plan_string(&self.plan); - write!(f, "\nidx: {:?}", self.idx)?; - write!(f, "\nplan: {:?}", plan_string)?; - for child in self.children.iter() { - write!(f, "\nexec_tree:{}", child)?; - } - writeln!(f) - } -} - -impl ExecTree { - /// Create new Exec tree - pub fn new( - plan: Arc, - idx: usize, - children: Vec, - ) -> Self { - ExecTree { - plan, - idx, - children, - } - } -} - -/// Get `ExecTree` for each child of the plan if they are tracked. -/// # Arguments -/// -/// * `n_children` - Children count of the plan of interest -/// * `onward` - Contains `Some(ExecTree)` of the plan tracked. -/// - Contains `None` is plan is not tracked. -/// -/// # Returns -/// -/// A `Vec>` that contains tracking information of each child. -/// If a child is `None`, it is not tracked. If `Some(ExecTree)` child is tracked also. -pub(crate) fn get_children_exectrees( - n_children: usize, - onward: &Option, -) -> Vec> { - let mut children_onward = vec![None; n_children]; - if let Some(exec_tree) = &onward { - for child in &exec_tree.children { - children_onward[child.idx] = Some(child.clone()); - } - } - children_onward -} +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; /// This utility function adds a `SortExec` above an operator according to the /// given ordering requirements while preserving the original partitioning. diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index f51374461776..91238e5b04b4 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -151,7 +151,7 @@ impl Neg for SortProperties { pub struct ExprOrdering { pub expr: Arc, pub state: SortProperties, - pub children: Vec, + pub children: Vec, } impl ExprOrdering { @@ -191,15 +191,13 @@ impl TreeNode for ExprOrdering { where F: FnMut(Self) -> Result, { - if self.children.is_empty() { - Ok(self) - } else { + if !self.children.is_empty() { self.children = self .children .into_iter() .map(transform) - .collect::>>()?; - Ok(self) + .collect::>()?; } + Ok(self) } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 14ef9c2ec27b..d01ea5507449 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -21,6 +21,7 @@ //! The Union operator combines multiple inputs with the same schema +use std::borrow::Borrow; use std::pin::Pin; use std::task::{Context, Poll}; use std::{any::Any, sync::Arc}; @@ -336,7 +337,7 @@ impl InterleaveExec { pub fn try_new(inputs: Vec>) -> Result { let schema = union_schema(&inputs); - if !can_interleave(&inputs) { + if !can_interleave(inputs.iter()) { return internal_err!( "Not all InterleaveExec children have a consistent hash partitioning" ); @@ -474,17 +475,18 @@ impl ExecutionPlan for InterleaveExec { /// It might be too strict here in the case that the input partition specs are compatible but not exactly the same. /// For example one input partition has the partition spec Hash('a','b','c') and /// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c'). -pub fn can_interleave(inputs: &[Arc]) -> bool { - if inputs.is_empty() { +pub fn can_interleave>>( + mut inputs: impl Iterator, +) -> bool { + let Some(first) = inputs.next() else { return false; - } + }; - let first_input_partition = inputs[0].output_partitioning(); - matches!(first_input_partition, Partitioning::Hash(_, _)) + let reference = first.borrow().output_partitioning(); + matches!(reference, Partitioning::Hash(_, _)) && inputs - .iter() - .map(|plan| plan.output_partitioning()) - .all(|partition| partition == first_input_partition) + .map(|plan| plan.borrow().output_partitioning()) + .all(|partition| partition == reference) } fn union_schema(inputs: &[Arc]) -> SchemaRef { From 1737d49185e9e37c15aa432342604ee559a1069d Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Thu, 28 Dec 2023 20:12:49 +0800 Subject: [PATCH 38/63] feat: support inlist in LiteralGurantee for pruning (#8654) * support inlist in LiteralGuarantee for pruning. * add more tests * rm useless notes * Apply suggestions from code review Co-authored-by: Huaijin * add tests in row_groups * Apply suggestions from code review Co-authored-by: Ruihang Xia Co-authored-by: Andrew Lamb * update comment & add more tests --------- Co-authored-by: Huaijin Co-authored-by: Ruihang Xia Co-authored-by: Andrew Lamb --- .../physical_plan/parquet/row_groups.rs | 121 +-------- .../physical-expr/src/utils/guarantee.rs | 257 ++++++++++++++---- 2 files changed, 216 insertions(+), 162 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 8a1abb7d965f..5d18eac7d9fb 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -293,15 +293,10 @@ mod tests { use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; - use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; - use datafusion_common::{DataFusionError, Result}; - use datafusion_expr::{ - builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, - TableSource, WindowUDF, - }; + use datafusion_common::{Result, ToDFSchema}; + use datafusion_expr::{cast, col, lit, Expr}; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; - use datafusion_sql::planner::ContextProvider; use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; @@ -1105,13 +1100,18 @@ mod tests { let data = bytes::Bytes::from(std::fs::read(path).unwrap()); // generate pruning predicate - let schema = Schema::new(vec![ - Field::new("String", DataType::Utf8, false), - Field::new("String3", DataType::Utf8, false), - ]); - let sql = - "SELECT * FROM tbl WHERE \"String\" IN ('Hello_Not_Exists', 'Hello_Not_Exists2')"; - let expr = sql_to_physical_plan(sql).unwrap(); + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + + let expr = col(r#""String""#).in_list( + vec![ + lit("Hello_Not_Exists"), + lit("Hello_Not_Exists2"), + lit("Hello_Not_Exists3"), + lit("Hello_Not_Exist4"), + ], + false, + ); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); @@ -1312,97 +1312,4 @@ mod tests { Ok(pruned_row_group) } - - fn sql_to_physical_plan(sql: &str) -> Result> { - use datafusion_optimizer::{ - analyzer::Analyzer, optimizer::Optimizer, OptimizerConfig, OptimizerContext, - }; - use datafusion_sql::{ - planner::SqlToRel, - sqlparser::{ast::Statement, parser::Parser}, - }; - use sqlparser::dialect::GenericDialect; - - // parse the SQL - let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... - let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); - let statement = &ast[0]; - - // create a logical query plan - let schema_provider = TestSchemaProvider::new(); - let sql_to_rel = SqlToRel::new(&schema_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); - - // hard code the return value of now() - let config = OptimizerContext::new().with_skip_failing_rules(false); - let analyzer = Analyzer::new(); - let optimizer = Optimizer::new(); - // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - let plan = optimizer.optimize(&plan, &config, |_, _| {})?; - // convert the logical plan into a physical plan - let exprs = plan.expressions(); - let expr = &exprs[0]; - let df_schema = plan.schema().as_ref().to_owned(); - let tb_schema: Schema = df_schema.clone().into(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &tb_schema, &execution_props) - } - - struct TestSchemaProvider { - options: ConfigOptions, - tables: HashMap>, - } - - impl TestSchemaProvider { - pub fn new() -> Self { - let mut tables = HashMap::new(); - tables.insert( - "tbl".to_string(), - create_table_source(vec![Field::new( - "String".to_string(), - DataType::Utf8, - false, - )]), - ); - - Self { - options: Default::default(), - tables, - } - } - } - - impl ContextProvider for TestSchemaProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - match self.tables.get(name.table()) { - Some(table) => Ok(table.clone()), - _ => datafusion_common::plan_err!("Table not found: {}", name.table()), - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - } - - fn create_table_source(fields: Vec) -> Arc { - Arc::new(LogicalTableSource::new(Arc::new(Schema::new(fields)))) - } } diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 59ec255754c0..0aee2af67fdd 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -77,7 +77,7 @@ pub struct LiteralGuarantee { } /// What is guaranteed about the values for a [`LiteralGuarantee`]? -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Guarantee { /// Guarantee that the expression is `true` if `column` is one of the values. If /// `column` is not one of the values, the expression can not be `true`. @@ -94,15 +94,9 @@ impl LiteralGuarantee { /// create these structures from an predicate (boolean expression). fn try_new<'a>( column_name: impl Into, - op: Operator, + guarantee: Guarantee, literals: impl IntoIterator, ) -> Option { - let guarantee = match op { - Operator::Eq => Guarantee::In, - Operator::NotEq => Guarantee::NotIn, - _ => return None, - }; - let literals: HashSet<_> = literals.into_iter().cloned().collect(); Some(Self { @@ -120,7 +114,7 @@ impl LiteralGuarantee { /// expression is guaranteed to be `null` or `false`. /// /// # Notes: - /// 1. `expr` must be a boolean expression. + /// 1. `expr` must be a boolean expression or inlist expression. /// 2. `expr` is not simplified prior to analysis. pub fn analyze(expr: &Arc) -> Vec { // split conjunction: AND AND ... @@ -130,6 +124,39 @@ impl LiteralGuarantee { .fold(GuaranteeBuilder::new(), |builder, expr| { if let Some(cel) = ColOpLit::try_new(expr) { return builder.aggregate_conjunct(cel); + } else if let Some(inlist) = expr + .as_any() + .downcast_ref::() + { + // Only support single-column inlist currently, multi-column inlist is not supported + let col = inlist + .expr() + .as_any() + .downcast_ref::(); + let Some(col) = col else { + return builder; + }; + + let literals = inlist + .list() + .iter() + .map(|e| e.as_any().downcast_ref::()) + .collect::>>(); + let Some(literals) = literals else { + return builder; + }; + + let guarantee = if inlist.negated() { + Guarantee::NotIn + } else { + Guarantee::In + }; + + builder.aggregate_multi_conjunct( + col, + guarantee, + literals.iter().map(|e| e.value()), + ) } else { // split disjunction: OR OR ... let disjunctions = split_disjunction(expr); @@ -168,14 +195,21 @@ impl LiteralGuarantee { // if all terms are 'col literal' with the same column // and operation we can infer any guarantees + // + // For those like (a != foo AND (a != bar OR a != baz)). + // We can't combine the (a != bar OR a != baz) part, but + // it also doesn't invalidate our knowledge that a != + // foo is required for the expression to be true. + // So we can only create a multi value guarantee for `=` + // (or a single value). (e.g. ignore `a != foo OR a != bar`) let first_term = &terms[0]; if terms.iter().all(|term| { term.col.name() == first_term.col.name() - && term.op == first_term.op + && term.guarantee == Guarantee::In }) { builder.aggregate_multi_conjunct( first_term.col, - first_term.op, + Guarantee::In, terms.iter().map(|term| term.lit.value()), ) } else { @@ -197,9 +231,9 @@ struct GuaranteeBuilder<'a> { /// e.g. `a = foo AND a = bar` then the relevant guarantee will be None guarantees: Vec>, - /// Key is the (column name, operator type) + /// Key is the (column name, guarantee type) /// Value is the index into `guarantees` - map: HashMap<(&'a crate::expressions::Column, Operator), usize>, + map: HashMap<(&'a crate::expressions::Column, Guarantee), usize>, } impl<'a> GuaranteeBuilder<'a> { @@ -216,7 +250,7 @@ impl<'a> GuaranteeBuilder<'a> { fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self { self.aggregate_multi_conjunct( col_op_lit.col, - col_op_lit.op, + col_op_lit.guarantee, [col_op_lit.lit.value()], ) } @@ -233,10 +267,10 @@ impl<'a> GuaranteeBuilder<'a> { fn aggregate_multi_conjunct( mut self, col: &'a crate::expressions::Column, - op: Operator, + guarantee: Guarantee, new_values: impl IntoIterator, ) -> Self { - let key = (col, op); + let key = (col, guarantee); if let Some(index) = self.map.get(&key) { // already have a guarantee for this column let entry = &mut self.guarantees[*index]; @@ -257,26 +291,20 @@ impl<'a> GuaranteeBuilder<'a> { // another `AND a != 6` we know that a must not be either 5 or 6 // for the expression to be true Guarantee::NotIn => { - // can extend if only single literal, otherwise invalidate let new_values: HashSet<_> = new_values.into_iter().collect(); - if new_values.len() == 1 { - existing.literals.extend(new_values.into_iter().cloned()) - } else { - // this is like (a != foo AND (a != bar OR a != baz)). - // We can't combine the (a != bar OR a != baz) part, but - // it also doesn't invalidate our knowledge that a != - // foo is required for the expression to be true - } + existing.literals.extend(new_values.into_iter().cloned()); } Guarantee::In => { - // for an IN guarantee, it is ok if the value is the same - // e.g. `a = foo AND a = foo` but not if the value is different - // e.g. `a = foo AND a = bar` - if new_values + let intersection = new_values .into_iter() - .all(|new_value| existing.literals.contains(new_value)) - { - // all values are already in the set + .filter(|new_value| existing.literals.contains(*new_value)) + .collect::>(); + // for an In guarantee, if the intersection is not empty, we can extend the guarantee + // e.g. `a IN (1,2,3) AND a IN (2,3,4)` is `a IN (2,3)` + // otherwise, we invalidate the guarantee + // e.g. `a IN (1,2,3) AND a IN (4,5,6)` is `a IN ()`, which is invalid + if !intersection.is_empty() { + existing.literals = intersection.into_iter().cloned().collect(); } else { // at least one was not, so invalidate the guarantee *entry = None; @@ -287,17 +315,12 @@ impl<'a> GuaranteeBuilder<'a> { // This is a new guarantee let new_values: HashSet<_> = new_values.into_iter().collect(); - // new_values are combined with OR, so we can only create a - // multi-column guarantee for `=` (or a single value). - // (e.g. ignore `a != foo OR a != bar`) - if op == Operator::Eq || new_values.len() == 1 { - if let Some(guarantee) = - LiteralGuarantee::try_new(col.name(), op, new_values) - { - // add it to the list of guarantees - self.guarantees.push(Some(guarantee)); - self.map.insert(key, self.guarantees.len() - 1); - } + if let Some(guarantee) = + LiteralGuarantee::try_new(col.name(), guarantee, new_values) + { + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); } } @@ -311,10 +334,10 @@ impl<'a> GuaranteeBuilder<'a> { } } -/// Represents a single `col literal` expression +/// Represents a single `col [not]in literal` expression struct ColOpLit<'a> { col: &'a crate::expressions::Column, - op: Operator, + guarantee: Guarantee, lit: &'a crate::expressions::Literal, } @@ -322,7 +345,7 @@ impl<'a> ColOpLit<'a> { /// Returns Some(ColEqLit) if the expression is either: /// 1. `col literal` /// 2. `literal col` - /// + /// 3. operator is `=` or `!=` /// Returns None otherwise fn try_new(expr: &'a Arc) -> Option { let binary_expr = expr @@ -334,21 +357,32 @@ impl<'a> ColOpLit<'a> { binary_expr.op(), binary_expr.right().as_any(), ); - + let guarantee = match op { + Operator::Eq => Guarantee::In, + Operator::NotEq => Guarantee::NotIn, + _ => return None, + }; // col literal if let (Some(col), Some(lit)) = ( left.downcast_ref::(), right.downcast_ref::(), ) { - Some(Self { col, op: *op, lit }) + Some(Self { + col, + guarantee, + lit, + }) } // literal col else if let (Some(lit), Some(col)) = ( left.downcast_ref::(), right.downcast_ref::(), ) { - // Used swapped operator operator, if possible - op.swap().map(|op| Self { col, op, lit }) + Some(Self { + col, + guarantee, + lit, + }) } else { None } @@ -645,9 +679,122 @@ mod test { ); } - // TODO https://github.com/apache/arrow-datafusion/issues/8436 - // a IN (...) - // b NOT IN (...) + #[test] + fn test_single_inlist() { + // b IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], false), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b NOT IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], true), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + #[test] + fn test_inlist_conjunction() { + // b IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![in_guarantee("b", [2, 3])], + ); + // b NOT IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![ + not_in_guarantee("b", [1, 2, 3]), + in_guarantee("b", [2, 3, 4]), + ], + ); + // b NOT IN (1, 2, 3) AND b NOT IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], true)), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b IN (1, 2, 3) AND b = 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4))), + vec![], + ); + // b IN (1, 2, 3) AND b = 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(2))), + vec![in_guarantee("b", [2])], + ); + // b IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").not_eq(lit(2))), + vec![in_guarantee("b", [1, 2, 3]), not_in_guarantee("b", [2])], + ); + // b NOT IN (1, 2, 3) AND b != 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(4))), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b NOT IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + #[test] + fn test_inlist_with_disjunction() { + // b IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![in_guarantee("b", [3])], + ); + // b IN (1, 2, 3) AND (b = 4 OR b = 5) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4)).or(col("b").eq(lit(5)))), + vec![], + ); + // b NOT IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])], + ); + // b IN (1, 2, 3) OR b = 2 + // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to anylize this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").eq(lit(2))), + vec![], + ); + // b IN (1, 2, 3) OR b != 3 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").not_eq(lit(3))), + vec![], + ); + } /// Tests that analyzing expr results in the expected guarantees fn test_analyze(expr: Expr, expected: Vec) { @@ -673,7 +820,7 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Operator::Eq, literals.iter()).unwrap() + LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap() } /// Guarantee that the expression is true if the column is NOT any of the specified values @@ -683,7 +830,7 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Operator::NotEq, literals.iter()).unwrap() + LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() } /// Convert a logical expression to a physical expression (without any simplification, etc) From fba5cc0b9062297e38cbe388d7f1b13debe8ba92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 28 Dec 2023 15:27:21 +0300 Subject: [PATCH 39/63] Streaming CLI support (#8651) * Streaming CLI support * Update Cargo.toml * Remove duplications * Clean up * Stream test will be added * Update print_format.rs * Address feedback * Final fix --------- Co-authored-by: Mehmet Ozan Kabak --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 1 + datafusion-cli/Cargo.toml | 1 + datafusion-cli/src/exec.rs | 66 +++-- datafusion-cli/src/main.rs | 19 +- datafusion-cli/src/print_format.rs | 278 +++++++++++------- datafusion-cli/src/print_options.rs | 74 ++++- .../core/src/datasource/physical_plan/mod.rs | 15 + 8 files changed, 295 insertions(+), 161 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a698fbf471f9..4ee29ea6298c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ arrow = { version = "49.0.0", features = ["prettyprint"] } arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } arrow-buffer = { version = "49.0.0", default-features = false } arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } -arrow-ipc = { version = "49.0.0", default-features = false, features=["lz4"] } +arrow-ipc = { version = "49.0.0", default-features = false, features = ["lz4"] } arrow-ord = { version = "49.0.0", default-features = false } arrow-schema = { version = "49.0.0", default-features = false } async-trait = "0.1.73" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9f75013c86dc..8e9bbd8a0dfd 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1160,6 +1160,7 @@ dependencies = [ "datafusion-common", "dirs", "env_logger", + "futures", "mimalloc", "object_store", "parking_lot", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index f57097683698..e1ddba4cad1a 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -38,6 +38,7 @@ datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avr datafusion-common = { path = "../datafusion/common" } dirs = "4.0.0" env_logger = "0.9" +futures = "0.3" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 8af534cd1375..ba9aa2e69aa6 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -17,6 +17,12 @@ //! Execution functions +use std::io::prelude::*; +use std::io::BufReader; +use std::time::Instant; +use std::{fs::File, sync::Arc}; + +use crate::print_format::PrintFormat; use crate::{ command::{Command, OutputFormat}, helper::{unescape_input, CliHelper}, @@ -26,21 +32,19 @@ use crate::{ }, print_options::{MaxRows, PrintOptions}, }; -use datafusion::common::plan_datafusion_err; + +use datafusion::common::{exec_datafusion_err, plan_datafusion_err}; +use datafusion::datasource::listing::ListingTableUrl; +use datafusion::datasource::physical_plan::is_plan_streaming; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{CreateExternalTable, DdlStatement, LogicalPlan}; +use datafusion::physical_plan::{collect, execute_stream}; +use datafusion::prelude::SessionContext; use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str}; -use datafusion::{ - datasource::listing::ListingTableUrl, - error::{DataFusionError, Result}, - logical_expr::{CreateExternalTable, DdlStatement}, -}; -use datafusion::{logical_expr::LogicalPlan, prelude::SessionContext}; + use object_store::ObjectStore; use rustyline::error::ReadlineError; use rustyline::Editor; -use std::io::prelude::*; -use std::io::BufReader; -use std::time::Instant; -use std::{fs::File, sync::Arc}; use url::Url; /// run and execute SQL statements and commands, against a context with the given print options @@ -125,8 +129,6 @@ pub async fn exec_from_repl( ))); rl.load_history(".history").ok(); - let mut print_options = print_options.clone(); - loop { match rl.readline("❯ ") { Ok(line) if line.starts_with('\\') => { @@ -138,9 +140,7 @@ pub async fn exec_from_repl( Command::OutputFormat(subcommand) => { if let Some(subcommand) = subcommand { if let Ok(command) = subcommand.parse::() { - if let Err(e) = - command.execute(&mut print_options).await - { + if let Err(e) = command.execute(print_options).await { eprintln!("{e}") } } else { @@ -154,7 +154,7 @@ pub async fn exec_from_repl( } } _ => { - if let Err(e) = cmd.execute(ctx, &mut print_options).await { + if let Err(e) = cmd.execute(ctx, print_options).await { eprintln!("{e}") } } @@ -165,7 +165,7 @@ pub async fn exec_from_repl( } Ok(line) => { rl.add_history_entry(line.trim_end())?; - match exec_and_print(ctx, &print_options, line).await { + match exec_and_print(ctx, print_options, line).await { Ok(_) => {} Err(err) => eprintln!("{err}"), } @@ -198,7 +198,6 @@ async fn exec_and_print( sql: String, ) -> Result<()> { let now = Instant::now(); - let sql = unescape_input(&sql)?; let task_ctx = ctx.task_ctx(); let dialect = &task_ctx.session_config().options().sql_parser.dialect; @@ -227,18 +226,24 @@ async fn exec_and_print( if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(ctx, cmd).await?; } + let df = ctx.execute_logical_plan(plan).await?; - let results = df.collect().await?; + let physical_plan = df.create_physical_plan().await?; - let print_options = if should_ignore_maxrows { - PrintOptions { - maxrows: MaxRows::Unlimited, - ..print_options.clone() - } + if is_plan_streaming(&physical_plan)? { + let stream = execute_stream(physical_plan, task_ctx.clone())?; + print_options.print_stream(stream, now).await?; } else { - print_options.clone() - }; - print_options.print_batches(&results, now)?; + let mut print_options = print_options.clone(); + if should_ignore_maxrows { + print_options.maxrows = MaxRows::Unlimited; + } + if print_options.format == PrintFormat::Automatic { + print_options.format = PrintFormat::Table; + } + let results = collect(physical_plan, task_ctx.clone()).await?; + print_options.print_batches(&results, now)?; + } } Ok(()) @@ -272,10 +277,7 @@ async fn create_external_table( .object_store_registry .get_store(url) .map_err(|_| { - DataFusionError::Execution(format!( - "Unsupported object store scheme: {}", - scheme - )) + exec_datafusion_err!("Unsupported object store scheme: {}", scheme) })? } }; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 8b74a797b57b..563d172f2c95 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -15,7 +15,12 @@ // specific language governing permissions and limitations // under the License. -use clap::Parser; +use std::collections::HashMap; +use std::env; +use std::path::Path; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; @@ -29,12 +34,9 @@ use datafusion_cli::{ print_options::{MaxRows, PrintOptions}, DATAFUSION_CLI_VERSION, }; + +use clap::Parser; use mimalloc::MiMalloc; -use std::collections::HashMap; -use std::env; -use std::path::Path; -use std::str::FromStr; -use std::sync::{Arc, OnceLock}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -111,7 +113,7 @@ struct Args { )] rc: Option>, - #[clap(long, arg_enum, default_value_t = PrintFormat::Table)] + #[clap(long, arg_enum, default_value_t = PrintFormat::Automatic)] format: PrintFormat, #[clap( @@ -331,9 +333,8 @@ fn extract_memory_pool_size(size: &str) -> Result { #[cfg(test)] mod tests { - use datafusion::assert_batches_eq; - use super::*; + use datafusion::assert_batches_eq; fn assert_conversion(input: &str, expected: Result) { let result = extract_memory_pool_size(input); diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 0738bf6f9b47..ea418562495d 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -16,23 +16,27 @@ // under the License. //! Print format variants + +use std::str::FromStr; + use crate::print_options::MaxRows; + use arrow::csv::writer::WriterBuilder; use arrow::json::{ArrayWriter, LineDelimitedWriter}; +use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion::error::{DataFusionError, Result}; -use std::str::FromStr; +use datafusion::error::Result; /// Allow records to be printed in different formats -#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)] +#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone, Copy)] pub enum PrintFormat { Csv, Tsv, Table, Json, NdJson, + Automatic, } impl FromStr for PrintFormat { @@ -44,31 +48,44 @@ impl FromStr for PrintFormat { } macro_rules! batches_to_json { - ($WRITER: ident, $batches: expr) => {{ - let mut bytes = vec![]; + ($WRITER: ident, $writer: expr, $batches: expr) => {{ { - let mut writer = $WRITER::new(&mut bytes); - $batches.iter().try_for_each(|batch| writer.write(batch))?; - writer.finish()?; + if !$batches.is_empty() { + let mut json_writer = $WRITER::new(&mut *$writer); + for batch in $batches { + json_writer.write(batch)?; + } + json_writer.finish()?; + json_finish!($WRITER, $writer); + } } - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))? + Ok(()) as Result<()> }}; } -fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result { - let mut bytes = vec![]; - { - let builder = WriterBuilder::new() - .with_header(true) - .with_delimiter(delimiter); - let mut writer = builder.build(&mut bytes); - for batch in batches { - writer.write(batch)?; - } +macro_rules! json_finish { + (ArrayWriter, $writer: expr) => {{ + writeln!($writer)?; + }}; + (LineDelimitedWriter, $writer: expr) => {{}}; +} + +fn print_batches_with_sep( + writer: &mut W, + batches: &[RecordBatch], + delimiter: u8, + with_header: bool, +) -> Result<()> { + let builder = WriterBuilder::new() + .with_header(with_header) + .with_delimiter(delimiter); + let mut csv_writer = builder.build(writer); + + for batch in batches { + csv_writer.write(batch)?; } - let formatted = - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))?; - Ok(formatted) + + Ok(()) } fn keep_only_maxrows(s: &str, maxrows: usize) -> String { @@ -88,97 +105,118 @@ fn keep_only_maxrows(s: &str, maxrows: usize) -> String { result.join("\n") } -fn format_batches_with_maxrows( +fn format_batches_with_maxrows( + writer: &mut W, batches: &[RecordBatch], maxrows: MaxRows, -) -> Result { +) -> Result<()> { match maxrows { MaxRows::Limited(maxrows) => { - // Only format enough batches for maxrows + // Filter batches to meet the maxrows condition let mut filtered_batches = Vec::new(); - let mut batches = batches; - let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - if row_count > maxrows { - let mut accumulated_rows = 0; - - for batch in batches { + let mut row_count: usize = 0; + let mut over_limit = false; + for batch in batches { + if row_count + batch.num_rows() > maxrows { + // If adding this batch exceeds maxrows, slice the batch + let limit = maxrows - row_count; + let sliced_batch = batch.slice(0, limit); + filtered_batches.push(sliced_batch); + over_limit = true; + break; + } else { filtered_batches.push(batch.clone()); - if accumulated_rows + batch.num_rows() > maxrows { - break; - } - accumulated_rows += batch.num_rows(); + row_count += batch.num_rows(); } - - batches = &filtered_batches; } - let mut formatted = format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - ); - - if row_count > maxrows { - formatted = keep_only_maxrows(&formatted, maxrows); + let formatted = pretty_format_batches_with_options( + &filtered_batches, + &DEFAULT_FORMAT_OPTIONS, + )?; + if over_limit { + let mut formatted_str = format!("{}", formatted); + formatted_str = keep_only_maxrows(&formatted_str, maxrows); + writeln!(writer, "{}", formatted_str)?; + } else { + writeln!(writer, "{}", formatted)?; } - - Ok(formatted) } MaxRows::Unlimited => { - // maxrows not specified, print all rows - Ok(format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - )) + let formatted = + pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?; + writeln!(writer, "{}", formatted)?; } } + + Ok(()) } impl PrintFormat { - /// print the batches to stdout using the specified format - /// `maxrows` option is only used for `Table` format: - /// If `maxrows` is Some(n), then at most n rows will be displayed - /// If `maxrows` is None, then every row will be displayed - pub fn print_batches(&self, batches: &[RecordBatch], maxrows: MaxRows) -> Result<()> { - if batches.is_empty() { + /// Print the batches to a writer using the specified format + pub fn print_batches( + &self, + writer: &mut W, + batches: &[RecordBatch], + maxrows: MaxRows, + with_header: bool, + ) -> Result<()> { + if batches.is_empty() || batches[0].num_rows() == 0 { return Ok(()); } match self { - Self::Csv => println!("{}", print_batches_with_sep(batches, b',')?), - Self::Tsv => println!("{}", print_batches_with_sep(batches, b'\t')?), + Self::Csv | Self::Automatic => { + print_batches_with_sep(writer, batches, b',', with_header) + } + Self::Tsv => print_batches_with_sep(writer, batches, b'\t', with_header), Self::Table => { if maxrows == MaxRows::Limited(0) { return Ok(()); } - println!("{}", format_batches_with_maxrows(batches, maxrows)?,) - } - Self::Json => println!("{}", batches_to_json!(ArrayWriter, batches)), - Self::NdJson => { - println!("{}", batches_to_json!(LineDelimitedWriter, batches)) + format_batches_with_maxrows(writer, batches, maxrows) } + Self::Json => batches_to_json!(ArrayWriter, writer, batches), + Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, batches), } - Ok(()) } } #[cfg(test)] mod tests { + use std::io::{Cursor, Read, Write}; + use std::sync::Arc; + use super::*; + use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; + use datafusion::error::Result; + + fn run_test(batches: &[RecordBatch], test_fn: F) -> Result + where + F: Fn(&mut Cursor>, &[RecordBatch]) -> Result<()>, + { + let mut buffer = Cursor::new(Vec::new()); + test_fn(&mut buffer, batches)?; + buffer.set_position(0); + let mut contents = String::new(); + buffer.read_to_string(&mut contents)?; + Ok(contents) + } #[test] - fn test_print_batches_with_sep() { - let batches = vec![]; - assert_eq!("", print_batches_with_sep(&batches, b',').unwrap()); + fn test_print_batches_with_sep() -> Result<()> { + let contents = run_test(&[], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -186,29 +224,33 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); + )?; - let batches = vec![batch]; - let r = print_batches_with_sep(&batches, b',').unwrap(); - assert_eq!("a,b,c\n1,4,7\n2,5,8\n3,6,9\n", r); + let contents = run_test(&[batch], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, "a,b,c\n1,4,7\n2,5,8\n3,6,9\n"); + + Ok(()) } #[test] fn test_print_batches_to_json_empty() -> Result<()> { - let batches = vec![]; - let r = batches_to_json!(ArrayWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, ""); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -216,25 +258,29 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); - + )?; let batches = vec![batch]; - let r = batches_to_json!(ArrayWriter, &batches); - assert_eq!("[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]", r); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n", r); + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, "[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]\n"); + + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, "{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n"); + Ok(()) } #[test] fn test_format_batches_with_maxrows() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - let batch = - RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]) - .unwrap(); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; #[rustfmt::skip] let all_rows_expected = [ @@ -244,7 +290,7 @@ mod tests { "| 1 |", "| 2 |", "| 3 |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -256,7 +302,7 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -272,26 +318,36 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); - let no_limit = format_batches_with_maxrows(&[batch.clone()], MaxRows::Unlimited)?; - assert_eq!(all_rows_expected, no_limit); - - let maxrows_less_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(1))?; - assert_eq!(one_row_expected, maxrows_less_than_actual); - let maxrows_more_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(5))?; - assert_eq!(all_rows_expected, maxrows_more_than_actual); - let maxrows_equals_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(3))?; - assert_eq!(all_rows_expected, maxrows_equals_actual); - let multi_batches = format_batches_with_maxrows( + let no_limit = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Unlimited) + })?; + assert_eq!(no_limit, all_rows_expected); + + let maxrows_less_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(1)) + })?; + assert_eq!(maxrows_less_than_actual, one_row_expected); + + let maxrows_more_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + })?; + assert_eq!(maxrows_more_than_actual, all_rows_expected); + + let maxrows_equals_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(3)) + })?; + assert_eq!(maxrows_equals_actual, all_rows_expected); + + let multi_batches = run_test( &[batch.clone(), batch.clone(), batch.clone()], - MaxRows::Limited(5), + |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + }, )?; - assert_eq!(multi_batches_expected, multi_batches); + assert_eq!(multi_batches, multi_batches_expected); Ok(()) } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 0a6c8d4c36fc..b8594352b585 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -15,13 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::print_format::PrintFormat; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::error::Result; use std::fmt::{Display, Formatter}; +use std::io::Write; +use std::pin::Pin; use std::str::FromStr; use std::time::Instant; +use crate::print_format::PrintFormat; + +use arrow::record_batch::RecordBatch; +use datafusion::common::DataFusionError; +use datafusion::error::Result; +use datafusion::physical_plan::RecordBatchStream; + +use futures::StreamExt; + #[derive(Debug, Clone, PartialEq, Copy)] pub enum MaxRows { /// show all rows in the output @@ -85,20 +93,70 @@ fn get_timing_info_str( } impl PrintOptions { - /// print the batches to stdout using the specified format + /// Print the batches to stdout using the specified format pub fn print_batches( &self, batches: &[RecordBatch], query_start_time: Instant, ) -> Result<()> { + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + self.format + .print_batches(&mut writer, batches, self.maxrows, true)?; + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - // Elapsed time should not count time for printing batches - let timing_info = get_timing_info_str(row_count, self.maxrows, query_start_time); + let timing_info = get_timing_info_str( + row_count, + if self.format == PrintFormat::Table { + self.maxrows + } else { + MaxRows::Unlimited + }, + query_start_time, + ); + + if !self.quiet { + writeln!(writer, "{timing_info}")?; + } + + Ok(()) + } + + /// Print the stream to stdout using the specified format + pub async fn print_stream( + &self, + mut stream: Pin>, + query_start_time: Instant, + ) -> Result<()> { + if self.format == PrintFormat::Table { + return Err(DataFusionError::External( + "PrintFormat::Table is not implemented".to_string().into(), + )); + }; + + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + let mut row_count = 0_usize; + let mut with_header = true; + + while let Some(Ok(batch)) = stream.next().await { + row_count += batch.num_rows(); + self.format.print_batches( + &mut writer, + &[batch], + MaxRows::Unlimited, + with_header, + )?; + with_header = false; + } - self.format.print_batches(batches, self.maxrows)?; + let timing_info = + get_timing_info_str(row_count, MaxRows::Unlimited, query_start_time); if !self.quiet { - println!("{timing_info}"); + writeln!(writer, "{timing_info}")?; } Ok(()) diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 4a6ebeab09e1..5583991355c6 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -69,6 +69,7 @@ use arrow::{ use datafusion_common::{file_options::FileTypeWriterOptions, plan_err}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_plan::ExecutionPlan; use log::debug; use object_store::path::Path; @@ -507,6 +508,20 @@ fn get_projected_output_ordering( all_orderings } +/// Get output (un)boundedness information for the given `plan`. +pub fn is_plan_streaming(plan: &Arc) -> Result { + if plan.children().is_empty() { + plan.unbounded_output(&[]) + } else { + let children_unbounded_output = plan + .children() + .iter() + .map(is_plan_streaming) + .collect::>>(); + plan.unbounded_output(&children_unbounded_output?) + } +} + #[cfg(test)] mod tests { use arrow_array::cast::AsArray; From f39c040ace0b34b0775827907aa01d6bb71cbb14 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 28 Dec 2023 11:38:16 -0700 Subject: [PATCH 40/63] Add serde support for CSV FileTypeWriterOptions (#8641) --- datafusion/proto/proto/datafusion.proto | 18 ++ datafusion/proto/src/generated/pbjson.rs | 213 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 29 ++- datafusion/proto/src/logical_plan/mod.rs | 74 ++++++ .../proto/src/physical_plan/from_proto.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 64 +++++- 6 files changed, 406 insertions(+), 4 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d02fc8e91b41..59b82efcbb43 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1207,6 +1207,7 @@ message FileTypeWriterOptions { oneof FileType { JsonWriterOptions json_options = 1; ParquetWriterOptions parquet_options = 2; + CsvWriterOptions csv_options = 3; } } @@ -1218,6 +1219,23 @@ message ParquetWriterOptions { WriterProperties writer_properties = 1; } +message CsvWriterOptions { + // Optional column delimiter. Defaults to `b','` + string delimiter = 1; + // Whether to write column names as file headers. Defaults to `true` + bool has_header = 2; + // Optional date format for date arrays + string date_format = 3; + // Optional datetime format for datetime arrays + string datetime_format = 4; + // Optional timestamp format for timestamp arrays + string timestamp_format = 5; + // Optional time format for time arrays + string time_format = 6; + // Optional value to represent null + string null_value = 7; +} + message WriterProperties { uint64 data_page_size_limit = 1; uint64 dictionary_page_size_limit = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f860b1f1e6a0..956244ffdbc2 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5151,6 +5151,205 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CsvWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.delimiter.is_empty() { + len += 1; + } + if self.has_header { + len += 1; + } + if !self.date_format.is_empty() { + len += 1; + } + if !self.datetime_format.is_empty() { + len += 1; + } + if !self.timestamp_format.is_empty() { + len += 1; + } + if !self.time_format.is_empty() { + len += 1; + } + if !self.null_value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.date_format.is_empty() { + struct_ser.serialize_field("dateFormat", &self.date_format)?; + } + if !self.datetime_format.is_empty() { + struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + } + if !self.timestamp_format.is_empty() { + struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + } + if !self.time_format.is_empty() { + struct_ser.serialize_field("timeFormat", &self.time_format)?; + } + if !self.null_value.is_empty() { + struct_ser.serialize_field("nullValue", &self.null_value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "delimiter", + "has_header", + "hasHeader", + "date_format", + "dateFormat", + "datetime_format", + "datetimeFormat", + "timestamp_format", + "timestampFormat", + "time_format", + "timeFormat", + "null_value", + "nullValue", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Delimiter, + HasHeader, + DateFormat, + DatetimeFormat, + TimestampFormat, + TimeFormat, + NullValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "delimiter" => Ok(GeneratedField::Delimiter), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), + "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), + "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), + "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut delimiter__ = None; + let mut has_header__ = None; + let mut date_format__ = None; + let mut datetime_format__ = None; + let mut timestamp_format__ = None; + let mut time_format__ = None; + let mut null_value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); + } + GeneratedField::DateFormat => { + if date_format__.is_some() { + return Err(serde::de::Error::duplicate_field("dateFormat")); + } + date_format__ = Some(map_.next_value()?); + } + GeneratedField::DatetimeFormat => { + if datetime_format__.is_some() { + return Err(serde::de::Error::duplicate_field("datetimeFormat")); + } + datetime_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampFormat => { + if timestamp_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampFormat")); + } + timestamp_format__ = Some(map_.next_value()?); + } + GeneratedField::TimeFormat => { + if time_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timeFormat")); + } + time_format__ = Some(map_.next_value()?); + } + GeneratedField::NullValue => { + if null_value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + null_value__ = Some(map_.next_value()?); + } + } + } + Ok(CsvWriterOptions { + delimiter: delimiter__.unwrap_or_default(), + has_header: has_header__.unwrap_or_default(), + date_format: date_format__.unwrap_or_default(), + datetime_format: datetime_format__.unwrap_or_default(), + timestamp_format: timestamp_format__.unwrap_or_default(), + time_format: time_format__.unwrap_or_default(), + null_value: null_value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.CsvWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CubeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -7893,6 +8092,9 @@ impl serde::Serialize for FileTypeWriterOptions { file_type_writer_options::FileType::ParquetOptions(v) => { struct_ser.serialize_field("parquetOptions", v)?; } + file_type_writer_options::FileType::CsvOptions(v) => { + struct_ser.serialize_field("csvOptions", v)?; + } } } struct_ser.end() @@ -7909,12 +8111,15 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { "jsonOptions", "parquet_options", "parquetOptions", + "csv_options", + "csvOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { JsonOptions, ParquetOptions, + CsvOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7938,6 +8143,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { match value { "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), + "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7972,6 +8178,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { return Err(serde::de::Error::duplicate_field("parquetOptions")); } file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ParquetOptions) +; + } + GeneratedField::CsvOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 459d5a965cd3..32e892e663ef 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1642,7 +1642,7 @@ pub struct PartitionColumn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileTypeWriterOptions { - #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2")] + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")] pub file_type: ::core::option::Option, } /// Nested message and enum types in `FileTypeWriterOptions`. @@ -1654,6 +1654,8 @@ pub mod file_type_writer_options { JsonOptions(super::JsonWriterOptions), #[prost(message, tag = "2")] ParquetOptions(super::ParquetWriterOptions), + #[prost(message, tag = "3")] + CsvOptions(super::CsvWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1670,6 +1672,31 @@ pub struct ParquetWriterOptions { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvWriterOptions { + /// Optional column delimiter. Defaults to `b','` + #[prost(string, tag = "1")] + pub delimiter: ::prost::alloc::string::String, + /// Whether to write column names as file headers. Defaults to `true` + #[prost(bool, tag = "2")] + pub has_header: bool, + /// Optional date format for date arrays + #[prost(string, tag = "3")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format for datetime arrays + #[prost(string, tag = "4")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format for timestamp arrays + #[prost(string, tag = "5")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional time format for time arrays + #[prost(string, tag = "6")] + pub time_format: ::prost::alloc::string::String, + /// Optional value to represent null + #[prost(string, tag = "7")] + pub null_value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct WriterProperties { #[prost(uint64, tag = "1")] pub data_page_size_limit: u64, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index d137a41fa19b..e997bcde426e 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -64,6 +65,7 @@ use datafusion_expr::{ }; use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; @@ -846,6 +848,20 @@ impl AsLogicalPlan for LogicalPlanNode { Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { match &opt.file_type { Some(ft) => match ft { + file_type_writer_options::FileType::CsvOptions( + writer_options, + ) => { + let writer_builder = + csv_writer_options_from_proto(writer_options)?; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_builder, + CompressionTypeVariant::UNCOMPRESSED, + ), + ), + )) + } file_type_writer_options::FileType::ParquetOptions( writer_options, ) => { @@ -1630,6 +1646,40 @@ impl AsLogicalPlan for LogicalPlanNode { } CopyOptions::WriterOptions(opt) => { match opt.as_ref() { + FileTypeWriterOptions::CSV(csv_opts) => { + let csv_options = &csv_opts.writer_options; + let csv_writer_options = protobuf::CsvWriterOptions { + delimiter: (csv_options.delimiter() as char) + .to_string(), + has_header: csv_options.header(), + date_format: csv_options + .date_format() + .unwrap_or("") + .to_owned(), + datetime_format: csv_options + .datetime_format() + .unwrap_or("") + .to_owned(), + timestamp_format: csv_options + .timestamp_format() + .unwrap_or("") + .to_owned(), + time_format: csv_options + .time_format() + .unwrap_or("") + .to_owned(), + null_value: csv_options.null().to_owned(), + }; + let csv_options = + file_type_writer_options::FileType::CsvOptions( + csv_writer_options, + ); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(csv_options), + }, + )) + } FileTypeWriterOptions::Parquet(parquet_opts) => { let parquet_writer_options = protobuf::ParquetWriterOptions { @@ -1674,6 +1724,30 @@ impl AsLogicalPlan for LogicalPlanNode { } } +pub(crate) fn csv_writer_options_from_proto( + writer_options: &protobuf::CsvWriterOptions, +) -> Result { + let mut builder = WriterBuilder::new(); + if !writer_options.delimiter.is_empty() { + if let Some(delimiter) = writer_options.delimiter.chars().next() { + if delimiter.is_ascii() { + builder = builder.with_delimiter(delimiter as u8); + } else { + return Err(proto_error("CSV Delimiter is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Delimiter")); + } + } + Ok(builder + .with_header(writer_options.has_header) + .with_date_format(writer_options.date_format.clone()) + .with_datetime_format(writer_options.datetime_format.clone()) + .with_timestamp_format(writer_options.timestamp_format.clone()) + .with_time_format(writer_options.time_format.clone()) + .with_null(writer_options.null_value.clone())) +} + pub(crate) fn writer_properties_to_proto( props: &WriterProperties, ) -> protobuf::WriterProperties { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 824eb60a5715..6f1e811510c6 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -39,6 +39,7 @@ use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; @@ -53,7 +54,7 @@ use crate::logical_plan; use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use crate::logical_plan::writer_properties_from_proto; +use crate::logical_plan::{csv_writer_options_from_proto, writer_properties_from_proto}; use chrono::{TimeZone, Utc}; use object_store::path::Path; use object_store::ObjectMeta; @@ -766,11 +767,18 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { let file_type = value .file_type .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))?; + .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; match file_type { protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( Self::JSON(JsonWriterOptions::new(opts.compression().into())), ), + protobuf::file_type_writer_options::FileType::CsvOptions(opt) => { + let write_options = csv_writer_options_from_proto(opt)?; + Ok(Self::CSV(CsvWriterOptions::new( + write_options, + CompressionTypeVariant::UNCOMPRESSED, + ))) + } protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { let props = opt.writer_properties.clone().unwrap_or_default(); let writer_properties = writer_properties_from_proto(&props)?; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3eeae01a643e..2d7d85abda96 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -20,6 +20,7 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::csv::WriterBuilder; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, @@ -35,8 +36,10 @@ use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::file_options::StatementOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{internal_err, not_impl_err, plan_err, FileTypeWriterOptions}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_common::{FileType, Result}; @@ -386,10 +389,69 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { } _ => panic!(), } - Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterBuilder::new() + .with_delimiter(b'*') + .with_date_format("dd/MM/yyyy".to_string()) + .with_datetime_format("dd/MM/yyyy HH:mm:ss".to_string()) + .with_timestamp_format("HH:mm:ss.SSSSSS".to_string()) + .with_time_format("HH:mm:ss".to_string()) + .with_null("NIL".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_properties, + CompressionTypeVariant::UNCOMPRESSED, + ), + ))), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.csv", copy_to.output_url); + assert_eq!(FileType::CSV, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::CSV(p) => { + let props = &p.writer_options; + assert_eq!(b'*', props.delimiter()); + assert_eq!("dd/MM/yyyy", props.date_format().unwrap()); + assert_eq!( + "dd/MM/yyyy HH:mm:ss", + props.datetime_format().unwrap() + ); + assert_eq!("HH:mm:ss.SSSSSS", props.timestamp_format().unwrap()); + assert_eq!("HH:mm:ss", props.time_format().unwrap()); + assert_eq!("NIL", props.null()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + + Ok(()) +} async fn create_csv_scan(ctx: &SessionContext) -> Result { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; From b2cbc7809ee0656099169307a73aadff23ab1030 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 28 Dec 2023 15:07:32 -0500 Subject: [PATCH 41/63] Add trait based ScalarUDF API (#8578) * Introduce new trait based ScalarUDF API * change name to `Self::new_from_impl` * Improve documentation, add link to advanced_udf.rs in the user guide * typo * Improve docs for aliases * Apply suggestions from code review Co-authored-by: Liang-Chi Hsieh * improve docs --------- Co-authored-by: Liang-Chi Hsieh --- datafusion-examples/README.md | 3 +- datafusion-examples/examples/advanced_udf.rs | 243 ++++++++++++++++++ datafusion-examples/examples/simple_udf.rs | 6 + datafusion/expr/src/expr.rs | 55 ++-- datafusion/expr/src/expr_fn.rs | 85 +++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udf.rs | 169 +++++++++++- .../optimizer/src/analyzer/type_coercion.rs | 64 ++--- docs/source/library-user-guide/adding-udfs.md | 9 +- 9 files changed, 562 insertions(+), 74 deletions(-) create mode 100644 datafusion-examples/examples/advanced_udf.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 057cdd475273..1296c74ea277 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -59,8 +59,9 @@ cargo run --example csv_sql - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass +- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) +- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs new file mode 100644 index 000000000000..6ebf88a0b671 --- /dev/null +++ b/datafusion-examples/examples/advanced_udf.rs @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{ + arrow::{ + array::{ArrayRef, Float32Array, Float64Array}, + datatypes::DataType, + record_batch::RecordBatch, + }, + logical_expr::Volatility, +}; +use std::any::Any; + +use arrow::array::{new_null_array, Array, AsArray}; +use arrow::compute; +use arrow::datatypes::Float64Type; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{internal_err, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use std::sync::Arc; + +/// This example shows how to use the full ScalarUDFImpl API to implement a user +/// defined function. As in the `simple_udf.rs` example, this struct implements +/// a function that takes two arguments and returns the first argument raised to +/// the power of the second argument `a^b`. +/// +/// To do so, we must implement the `ScalarUDFImpl` trait. +struct PowUdf { + signature: Signature, + aliases: Vec, +} + +impl PowUdf { + /// Create a new instance of the `PowUdf` struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take two arguments of type f64 + vec![DataType::Float64, DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + // we will also add an alias of "my_pow" + aliases: vec!["my_pow".to_string()], + } + } +} + +impl ScalarUDFImpl for PowUdf { + /// We implement as_any so that we can downcast the ScalarUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "pow" + } + + /// Return the "signature" of this function -- namely what types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function? In + /// this case it will always be a constant value, but it could also be a + /// function of the input types. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// This is the function that actually calculates the results. + /// + /// This is the same way that functions built into DataFusion are invoked, + /// which permits important special cases when one or both of the arguments + /// are single values (constants). For example `pow(a, 2)` + /// + /// However, it also means the implementation is more complex than when + /// using `create_udf`. + fn invoke(&self, args: &[ColumnarValue]) -> Result { + // DataFusion has arranged for the correct inputs to be passed to this + // function, but we check again to make sure + assert_eq!(args.len(), 2); + let (base, exp) = (&args[0], &args[1]); + assert_eq!(base.data_type(), DataType::Float64); + assert_eq!(exp.data_type(), DataType::Float64); + + match (base, exp) { + // For demonstration purposes we also implement the scalar / scalar + // case here, but it is not typically required for high performance. + // + // For performance it is most important to optimize cases where at + // least one argument is an array. If all arguments are constants, + // the DataFusion expression simplification logic will often invoke + // this path once during planning, and simply use the result during + // execution. + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + // compute the output. Note DataFusion treats `None` as NULL. + let res = match (base, exp) { + (Some(base), Some(exp)) => Some(base.powf(*exp)), + // one or both arguments were NULL + _ => None, + }; + Ok(ColumnarValue::Scalar(ScalarValue::from(res))) + } + // special case if the exponent is a constant + ( + ColumnarValue::Array(base_array), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + let result_array = match exp { + // a ^ null = null + None => new_null_array(base_array.data_type(), base_array.len()), + // a ^ exp + Some(exp) => { + // DataFusion has ensured both arguments are Float64: + let base_array = base_array.as_primitive::(); + // calculate the result for every row. The `unary` + // kernel creates very fast "vectorized" code and + // handles things like null values for us. + let res: Float64Array = + compute::unary(base_array, |base| base.powf(*exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(result_array)) + } + + // special case if the base is a constant (note this code is quite + // similar to the previous case, so we omit comments) + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Array(exp_array), + ) => { + let res = match base { + None => new_null_array(exp_array.data_type(), exp_array.len()), + Some(base) => { + let exp_array = exp_array.as_primitive::(); + let res: Float64Array = + compute::unary(exp_array, |exp| base.powf(exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(res)) + } + // Both arguments are arrays so we have to perform the calculation for every row + (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { + let res: Float64Array = compute::binary( + base_array.as_primitive::(), + exp_array.as_primitive::(), + |base, exp| base.powf(exp), + )?; + Ok(ColumnarValue::Array(Arc::new(res))) + } + // if the types were not float, it is a bug in DataFusion + _ => { + use datafusion_common::DataFusionError; + internal_err!("Invalid argument types to pow function") + } + } + } + + /// We will also add an alias of "my_pow" + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// In this example we register `PowUdf` as a user defined function +/// and invoke it via the DataFrame API and SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + // create the UDF + let pow = ScalarUDF::from(PowUdf::new()); + + // register the UDF with the context so it can be invoked by name and from SQL + ctx.register_udf(pow.clone()); + + // get a DataFrame from the context for scanning the "t" table + let df = ctx.table("t").await?; + + // Call pow(a, 10) using the DataFrame API + let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?; + + // note that the second argument is passed as an i32, not f64. DataFusion + // automatically coerces the types to match the UDF's defined signature. + + // print the results + df.show().await?; + + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; + sql_df.show().await?; + + Ok(()) +} + +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` +fn create_context() -> Result { + // define data. + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + // declare a new context. In Spark API, this corresponds to a new SparkSession + let ctx = SessionContext::new(); + + // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + Ok(ctx) +} diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 591991786515..39e1e13ce39a 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -140,5 +140,11 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // Given that `pow` is registered in the context, we can also use it in SQL: + let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?; + + // print the results + sql_df.show().await?; + Ok(()) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b46e9ec8f69d..0ec19bcadbf6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1724,13 +1724,13 @@ mod test { use crate::expr::Cast; use crate::expr_fn::col; use crate::{ - case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction, - ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature, - Volatility, + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::any::Any; use std::sync::Arc; #[test] @@ -1848,24 +1848,41 @@ mod test { ); // UDF - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), - &return_type, - &fun, - )); + struct TestScalarUDF { + signature: Signature, + } + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "TestScalarUDF" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + })); assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); - let udf = Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile), - &return_type, - &fun, - )); + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform( + 1, + vec![DataType::Float32], + Volatility::Volatile, + ), + })); assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); // Unresolved function diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cedf1d845137..eed41d97ccba 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -22,15 +22,16 @@ use crate::expr::{ Placeholder, ScalarFunction, TryCast, }; use crate::function::PartitionEvaluatorFactory; -use crate::WindowUDF; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; +use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; +use std::any::Any; use std::ops::Not; use std::sync::Arc; @@ -944,11 +945,18 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { CaseBuilder::new(None, vec![when], vec![then], None) } -/// Creates a new UDF with a specific signature and specific return type. -/// This is a helper function to create a new UDF. -/// The function `create_udf` returns a subset of all possible `ScalarFunction`: -/// * the UDF has a fixed return type -/// * the UDF has a fixed signature (e.g. [f64, f64]) +/// Convenience method to create a new user defined scalar function (UDF) with a +/// specific signature and specific return type. +/// +/// Note this function does not expose all available features of [`ScalarUDF`], +/// such as +/// +/// * computing return types based on input types +/// * multiple [`Signature`]s +/// * aliases +/// +/// See [`ScalarUDF`] for details and examples on how to use the full +/// functionality. pub fn create_udf( name: &str, input_types: Vec, @@ -956,13 +964,66 @@ pub fn create_udf( volatility: Volatility, fun: ScalarFunctionImplementation, ) -> ScalarUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - ScalarUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + ScalarUDF::from(SimpleScalarUDF::new( name, - &Signature::exact(input_types, volatility), - &return_type, - &fun, - ) + input_types, + return_type, + volatility, + fun, + )) +} + +/// Implements [`ScalarUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleScalarUDF { + name: String, + signature: Signature, + return_type: DataType, + fun: ScalarFunctionImplementation, +} + +impl SimpleScalarUDF { + /// Create a new `SimpleScalarUDF` from a name, input types, return type and + /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_types: Vec, + return_type: DataType, + volatility: Volatility, + fun: ScalarFunctionImplementation, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_types, volatility); + Self { + name, + signature, + return_type, + fun, + } + } +} + +impl ScalarUDFImpl for SimpleScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } } /// Creates a new UDAF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 48532e13dcd7..bf8e9e2954f4 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -80,7 +80,7 @@ pub use signature::{ }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; -pub use udf::ScalarUDF; +pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3a18ca2d25e8..2ec80a4a9ea1 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,9 +17,12 @@ //! [`ScalarUDF`]: Scalar User Defined Functions -use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use crate::{ + ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, +}; use arrow::datatypes::DataType; use datafusion_common::Result; +use std::any::Any; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; @@ -27,11 +30,19 @@ use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. /// -/// A scalar function produces a single row output for each row of input. +/// A scalar function produces a single row output for each row of input. This +/// struct contains the information DataFusion needs to plan and invoke +/// functions you supply such name, type signature, return type, and actual +/// implementation. /// -/// This struct contains the information DataFusion needs to plan and invoke -/// functions such name, type signature, return type, and actual implementation. /// +/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. +/// +/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. +/// +/// [`create_udf`]: crate::expr_fn::create_udf +/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs #[derive(Clone)] pub struct ScalarUDF { /// The name of the function @@ -79,7 +90,11 @@ impl std::hash::Hash for ScalarUDF { } impl ScalarUDF { - /// Create a new ScalarUDF + /// Create a new ScalarUDF from low level details. + /// + /// See [`ScalarUDFImpl`] for a more convenient way to create a + /// `ScalarUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -95,6 +110,34 @@ impl ScalarUDF { } } + /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`ScalarUDF::from`) + pub fn new_from_impl(fun: F) -> ScalarUDF + where + F: ScalarUDFImpl + Send + Sync + 'static, + { + // TODO change the internal implementation to use the trait object + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let func: ScalarFunctionImplementation = + Arc::new(move |args| captured_self.invoke(args)); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + fun: func, + aliases: arc_fun.aliases().to_vec(), + } + } + /// Adds additional names that can be used to invoke this function, in addition to `name` pub fn with_aliases( mut self, @@ -105,7 +148,9 @@ impl ScalarUDF { self } - /// creates a logical expression with a call of the UDF + /// Returns a [`Expr`] logical expression to call this UDF with specified + /// arguments. + /// /// This utility allows using the UDF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( @@ -124,22 +169,126 @@ impl ScalarUDF { &self.aliases } - /// Returns this function's signature (what input types are accepted) + /// Returns this function's [`Signature`] (what input types are accepted) pub fn signature(&self) -> &Signature { &self.signature } - /// Return the type of the function given its input types + /// The datatype this function returns given the input argument input types pub fn return_type(&self, args: &[DataType]) -> Result { // Old API returns an Arc of the datatype for some reason let res = (self.return_type)(args)?; Ok(res.as_ref().clone()) } - /// Return the actual implementation + /// Return an [`Arc`] to the function implementation pub fn fun(&self) -> ScalarFunctionImplementation { self.fun.clone() } +} - // TODO maybe add an invoke() method that runs the actual function? +impl From for ScalarUDF +where + F: ScalarUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`ScalarUDF`]. +/// +/// This trait exposes the full API for implementing user defined functions and +/// can be used to implement any function. +/// +/// See [`advanced_udf.rs`] for a full example with complete implementation and +/// [`ScalarUDF`] for other available options. +/// +/// +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// struct AddOne { +/// signature: Signature +/// }; +/// +/// impl AddOne { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the ScalarUDFImpl trait for AddOne +/// impl ScalarUDFImpl for AddOne { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "add_one" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("add_one only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let add_one = ScalarUDF::from(AddOne::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = add_one.call(vec![col("a")]); +/// ``` +pub trait ScalarUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function on `args`, returning the appropriate result + /// + /// The function will be invoked passed with the slice of [`ColumnarValue`] + /// (either scalar or array). + /// + /// # Zero Argument Functions + /// If the function has zero parameters (e.g. `now()`) it will be passed a + /// single element slice which is a a null array to indicate the batch's row + /// count (so the function can know the resulting array size). + /// + /// # Performance + /// + /// For the best performance, the implementations of `invoke` should handle + /// the common case when one or more of their arguments are constant values + /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`] + /// and treating all arguments as arrays will work, but will be slower. + fn invoke(&self, args: &[ColumnarValue]) -> Result; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c5e1180b9f97..b6298f5b552f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -738,7 +738,8 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { #[cfg(test)] mod test { - use std::sync::Arc; + use std::any::Any; + use std::sync::{Arc, OnceLock}; use arrow::array::{FixedSizeListArray, Int32Array}; use arrow::datatypes::{DataType, TimeUnit}; @@ -750,13 +751,13 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction, + Subquery, }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -808,22 +809,36 @@ mod test { assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } + static TEST_SIGNATURE: OnceLock = OnceLock::new(); + + struct TestScalarUDF {} + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "TestScalarUDF" + } + fn signature(&self) -> &Signature { + TEST_SIGNATURE.get_or_init(|| { + Signature::uniform(1, vec![DataType::Float32], Volatility::Stable) + }) + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + #[test] fn scalar_udf() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), - &return_type, - &fun, - )), - vec![lit(123_i32)], - )); + + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; @@ -833,24 +848,13 @@ mod test { #[test] fn scalar_udf_invalid_input() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); - let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable), - &return_type, - &fun, - )), - vec![lit("Apple")], - )); + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") .err() .unwrap(); assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.", err.strip_backtrace() ); Ok(()) diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 11cf52eb3fcf..c51e4de3236c 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -76,7 +76,9 @@ The challenge however is that DataFusion doesn't know about this function. We ne ### Registering a Scalar UDF -To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier. +To register a Scalar UDF, you need to wrap the function implementation in a [`ScalarUDF`] struct and then register it with the `SessionContext`. +DataFusion provides the [`create_udf`] and helper functions to make this easier. +There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udf.rs`]. ```rust use datafusion::logical_expr::{Volatility, create_udf}; @@ -93,6 +95,11 @@ let udf = create_udf( ); ``` +[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html +[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html +[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html +[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs + A few things to note: - The first argument is the name of the function. This is the name that will be used in SQL queries. From 06ed3dd1ac01b1bd6a70b93b56cb72cb40777690 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 28 Dec 2023 23:34:40 +0300 Subject: [PATCH 42/63] Handle ordering of first last aggregation inside aggregator (#8662) * Initial commit * Update tests in distinct_on * Update group by joins slt * Remove unused code * Minor changes * Minor changes * Simplifications * Update comments * Review * Fix clippy --------- Co-authored-by: Mehmet Ozan Kabak --- datafusion-cli/src/functions.rs | 2 +- datafusion/common/src/error.rs | 1 - .../physical_optimizer/projection_pushdown.rs | 4 + .../src/simplify_expressions/guarantees.rs | 4 + .../physical-expr/src/aggregate/first_last.rs | 131 +++-- datafusion/physical-expr/src/aggregate/mod.rs | 30 +- .../physical-expr/src/aggregate/utils.rs | 18 +- .../physical-expr/src/array_expressions.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 461 ++++++++---------- .../src/engines/datafusion_engine/mod.rs | 1 - .../sqllogictest/test_files/distinct_on.slt | 9 +- .../sqllogictest/test_files/groupby.slt | 82 ++-- datafusion/sqllogictest/test_files/joins.slt | 4 +- 13 files changed, 373 insertions(+), 376 deletions(-) diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index f8d9ed238be4..5390fa9f2271 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -297,7 +297,7 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let filename = match exprs.get(0) { + let filename = match exprs.first() { Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") _ => { diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 515acc6d1c47..e58faaa15096 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -558,7 +558,6 @@ macro_rules! arrow_err { // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths -pub use exec_err as _exec_err; pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 7e1312dad23e..d237a3e8607e 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -990,6 +990,10 @@ fn update_join_on( proj_right_exprs: &[(Column, String)], hash_join_on: &[(Column, Column)], ) -> Option> { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. + #[allow(clippy::map_identity)] let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on .iter() .map(|(left, right)| (left, right)) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 860dc326b9b0..aa7bb4f78a93 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -47,6 +47,10 @@ impl<'a> GuaranteeRewriter<'a> { guarantees: impl IntoIterator, ) -> Self { Self { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. + #[allow(clippy::map_identity)] guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), } } diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index c009881d8918..c7032e601cf8 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::sync::Arc; -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; +use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; use crate::expressions::format_state_name; use crate::{ reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, @@ -29,9 +29,10 @@ use crate::{ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; -use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression @@ -211,10 +212,45 @@ impl FirstValueAccumulator { } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.first = row[0].clone(); - self.orderings = row[1..].to_vec(); - self.is_set = true; + fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { + let [value, orderings @ ..] = row else { + return internal_err!("Empty row in FIRST_VALUE"); + }; + // Update when there is no entry in the state, or we have an "earlier" + // entry according to sort requirements. + if !self.is_set + || compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_gt() + { + self.first = value.clone(); + self.orderings = orderings.to_vec(); + self.is_set = true; + } + Ok(()) + } + + fn get_first_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] = values else { + return internal_err!("Empty row in FIRST_VALUE"); + }; + if self.ordering_req.is_empty() { + // Get first entry according to receive order (0th index) + return Ok((!value.is_empty()).then_some(0)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| SortColumn { + values: values.clone(), + options: Some(req.options), + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } } @@ -227,11 +263,9 @@ impl Accumulator for FirstValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // If we have seen first value, we shouldn't update it - if !values[0].is_empty() && !self.is_set { - let row = get_row_at_idx(values, 0)?; - // Update with first value in the array. - self.update_with_new_row(&row); + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + self.update_with_new_row(&row)?; } Ok(()) } @@ -265,7 +299,7 @@ impl Accumulator for FirstValueAccumulator { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&first_row[0..is_set_idx]); + self.update_with_new_row(&first_row[0..is_set_idx])?; } } Ok(()) @@ -459,10 +493,50 @@ impl LastValueAccumulator { } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.last = row[0].clone(); - self.orderings = row[1..].to_vec(); - self.is_set = true; + fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { + let [value, orderings @ ..] = row else { + return internal_err!("Empty row in LAST_VALUE"); + }; + // Update when there is no entry in the state, or we have a "later" + // entry (either according to sort requirements or the order of execution). + if !self.is_set + || self.orderings.is_empty() + || compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() + { + self.last = value.clone(); + self.orderings = orderings.to_vec(); + self.is_set = true; + } + Ok(()) + } + + fn get_last_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] = values else { + return internal_err!("Empty row in LAST_VALUE"); + }; + if self.ordering_req.is_empty() { + // Get last entry according to the order of data: + return Ok((!value.is_empty()).then_some(value.len() - 1)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| { + // Take the reverse ordering requirement. This enables us to + // use "fetch = 1" to get the last value. + SortColumn { + values: values.clone(), + options: Some(!req.options), + } + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } } @@ -475,10 +549,9 @@ impl Accumulator for LastValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !values[0].is_empty() { - let row = get_row_at_idx(values, values[0].len() - 1)?; - // Update with last value in the array. - self.update_with_new_row(&row); + if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + self.update_with_new_row(&row)?; } Ok(()) } @@ -515,7 +588,7 @@ impl Accumulator for LastValueAccumulator { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&last_row[0..is_set_idx]); + self.update_with_new_row(&last_row[0..is_set_idx])?; } } Ok(()) @@ -559,26 +632,18 @@ fn convert_to_sort_cols( .collect::>() } -/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. -fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { - ordering_req - .iter() - .map(|item| item.options) - .collect::>() -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow::compute::concat; use arrow_array::{ArrayRef, Int64Array}; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; - use arrow::compute::concat; - use std::sync::Arc; - #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 329bb1e6415e..5bd1fca385b1 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,16 +15,20 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::{FirstValue, LastValue, OrderSensitiveArrayAgg}; -use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -use datafusion_expr::Accumulator; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; use self::groups_accumulator::GroupsAccumulator; +use crate::expressions::OrderSensitiveArrayAgg; +use crate::{PhysicalExpr, PhysicalSortExpr}; + +use arrow::datatypes::Field; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_expr::Accumulator; + +mod hyperloglog; +mod tdigest; pub(crate) mod approx_distinct; pub(crate) mod approx_median; @@ -46,19 +50,18 @@ pub(crate) mod median; pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; -pub mod build_in; pub(crate) mod groups_accumulator; -mod hyperloglog; -pub mod moving_min_max; pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod sum; pub(crate) mod sum_distinct; -mod tdigest; -pub mod utils; pub(crate) mod variance; +pub mod build_in; +pub mod moving_min_max; +pub mod utils; + /// An aggregate expression that: /// * knows its resulting field /// * knows how to create its accumulator @@ -134,10 +137,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// Checks whether the given aggregate expression is order-sensitive. /// For instance, a `SUM` aggregation doesn't depend on the order of its inputs. -/// However, a `FirstValue` depends on the input ordering (if the order changes, -/// the first value in the list would change). +/// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering. pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { - aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() + aggr_expr.as_any().is::() } diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index e5421ef5ab7e..9777158da133 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -17,20 +17,21 @@ //! Utilities used in aggregates +use std::any::Any; +use std::sync::Arc; + use crate::{AggregateExpr, PhysicalSortExpr}; -use arrow::array::ArrayRef; + +use arrow::array::{ArrayRef, ArrowNativeTypeOp}; use arrow_array::cast::AsArray; use arrow_array::types::{ Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow_array::ArrowNativeTypeOp; use arrow_buffer::ArrowNativeType; -use arrow_schema::{DataType, Field}; +use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( @@ -40,7 +41,7 @@ pub fn get_accum_scalar_values_as_arrays( .state()? .iter() .map(|s| s.to_array_of_size(1)) - .collect::>>() + .collect() } /// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow @@ -205,3 +206,8 @@ pub(crate) fn ordering_fields( }) .collect() } + +/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. +pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { + ordering_req.iter().map(|item| item.options).collect() +} diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 274d1db4eb0d..7a986810bad2 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2453,7 +2453,7 @@ pub fn general_array_distinct( let last_offset: OffsetSize = offsets.last().copied().unwrap(); offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { + let array = match arrays.first() { Some(array) => array.clone(), None => { return internal_err!("array_distinct: failed to get array from rows") diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f779322456ca..f5bb4fe59b5d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,7 +27,7 @@ use crate::aggregates::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::windows::{get_ordered_partition_by_indices, get_window_mode}; +use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, @@ -45,11 +45,11 @@ use datafusion_physical_expr::{ aggregate::is_order_sensitive, equivalence::{collapse_lex_req, ProjectionMapping}, expressions::{Column, Max, Min, UnKnownColumn}, - physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, - LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + physical_exprs_contains, AggregateExpr, EquivalenceProperties, LexOrdering, + LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::{izip, Itertools}; +use itertools::Itertools; mod group_values; mod no_grouping; @@ -277,159 +277,6 @@ pub struct AggregateExec { output_ordering: Option, } -/// This function returns the ordering requirement of the first non-reversible -/// order-sensitive aggregate function such as ARRAY_AGG. This requirement serves -/// as the initial requirement while calculating the finest requirement among all -/// aggregate functions. If this function returns `None`, it means there is no -/// hard ordering requirement for the aggregate functions (in terms of direction). -/// Then, we can generate two alternative requirements with opposite directions. -fn get_init_req( - aggr_expr: &[Arc], - order_by_expr: &[Option], -) -> Option { - for (aggr_expr, fn_reqs) in aggr_expr.iter().zip(order_by_expr.iter()) { - // If the aggregation function is a non-reversible order-sensitive function - // and there is a hard requirement, choose first such requirement: - if is_order_sensitive(aggr_expr) - && aggr_expr.reverse_expr().is_none() - && fn_reqs.is_some() - { - return fn_reqs.clone(); - } - } - None -} - -/// This function gets the finest ordering requirement among all the aggregation -/// functions. If requirements are conflicting, (i.e. we can not compute the -/// aggregations in a single [`AggregateExec`]), the function returns an error. -fn get_finest_requirement( - aggr_expr: &mut [Arc], - order_by_expr: &mut [Option], - eq_properties: &EquivalenceProperties, -) -> Result> { - // First, we check if all the requirements are satisfied by the existing - // ordering. If so, we return `None` to indicate this. - let mut all_satisfied = true; - for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - if eq_properties.ordering_satisfy(fn_req.as_deref().unwrap_or(&[])) { - continue; - } - if let Some(reverse) = aggr_expr.reverse_expr() { - let reverse_req = fn_req.as_ref().map(|item| reverse_order_bys(item)); - if eq_properties.ordering_satisfy(reverse_req.as_deref().unwrap_or(&[])) { - // We need to update `aggr_expr` with its reverse since only its - // reverse requirement is compatible with the existing requirements: - *aggr_expr = reverse; - *fn_req = reverse_req; - continue; - } - } - // Requirement is not satisfied: - all_satisfied = false; - } - if all_satisfied { - // All of the requirements are already satisfied. - return Ok(None); - } - let mut finest_req = get_init_req(aggr_expr, order_by_expr); - for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - let Some(fn_req) = fn_req else { - continue; - }; - - if let Some(finest_req) = &mut finest_req { - if let Some(finer) = eq_properties.get_finer_ordering(finest_req, fn_req) { - *finest_req = finer; - continue; - } - // If an aggregate function is reversible, analyze whether its reverse - // direction is compatible with existing requirements: - if let Some(reverse) = aggr_expr.reverse_expr() { - let fn_req_reverse = reverse_order_bys(fn_req); - if let Some(finer) = - eq_properties.get_finer_ordering(finest_req, &fn_req_reverse) - { - // We need to update `aggr_expr` with its reverse, since only its - // reverse requirement is compatible with existing requirements: - *aggr_expr = reverse; - *finest_req = finer; - *fn_req = fn_req_reverse; - continue; - } - } - // If neither of the requirements satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); - } else { - finest_req = Some(fn_req.clone()); - } - } - Ok(finest_req) -} - -/// Calculates search_mode for the aggregation -fn get_aggregate_search_mode( - group_by: &PhysicalGroupBy, - input: &Arc, - aggr_expr: &mut [Arc], - order_by_expr: &mut [Option], - ordering_req: &mut Vec, -) -> InputOrderMode { - let groupby_exprs = group_by - .expr - .iter() - .map(|(item, _)| item.clone()) - .collect::>(); - let mut input_order_mode = InputOrderMode::Linear; - if !group_by.is_single() || groupby_exprs.is_empty() { - return input_order_mode; - } - - if let Some((should_reverse, mode)) = - get_window_mode(&groupby_exprs, ordering_req, input) - { - let all_reversible = aggr_expr - .iter() - .all(|expr| !is_order_sensitive(expr) || expr.reverse_expr().is_some()); - if should_reverse && all_reversible { - izip!(aggr_expr.iter_mut(), order_by_expr.iter_mut()).for_each( - |(aggr, order_by)| { - if let Some(reverse) = aggr.reverse_expr() { - *aggr = reverse; - } else { - unreachable!(); - } - *order_by = order_by.as_ref().map(|ob| reverse_order_bys(ob)); - }, - ); - *ordering_req = reverse_order_bys(ordering_req); - } - input_order_mode = mode; - } - input_order_mode -} - -/// Check whether group by expression contains all of the expression inside `requirement` -// As an example Group By (c,b,a) contains all of the expressions in the `requirement`: (a ASC, b DESC) -fn group_by_contains_all_requirements( - group_by: &PhysicalGroupBy, - requirement: &LexOrdering, -) -> bool { - let physical_exprs = group_by.input_exprs(); - // When we have multiple groups (grouping set) - // since group by may be calculated on the subset of the group_by.expr() - // it is not guaranteed to have all of the requirements among group by expressions. - // Hence do the analysis: whether group by contains all requirements in the single group case. - group_by.is_single() - && requirement - .iter() - .all(|req| physical_exprs_contains(&physical_exprs, &req.expr)) -} - impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( @@ -477,50 +324,14 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec>, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, schema: SchemaRef, original_schema: SchemaRef, ) -> Result { - // Reset ordering requirement to `None` if aggregator is not order-sensitive - let mut order_by_expr = aggr_expr - .iter() - .map(|aggr_expr| { - let fn_reqs = aggr_expr.order_bys().map(|ordering| ordering.to_vec()); - // If - // - aggregation function is order-sensitive and - // - aggregation is performing a "first stage" calculation, and - // - at least one of the aggregate function requirement is not inside group by expression - // keep the ordering requirement as is; otherwise ignore the ordering requirement. - // In non-first stage modes, we accumulate data (using `merge_batch`) - // from different partitions (i.e. merge partial results). During - // this merge, we consider the ordering of each partial result. - // Hence, we do not need to use the ordering requirement in such - // modes as long as partial results are generated with the - // correct ordering. - fn_reqs.filter(|req| { - is_order_sensitive(aggr_expr) - && mode.is_first_stage() - && !group_by_contains_all_requirements(&group_by, req) - }) - }) - .collect::>(); - let requirement = get_finest_requirement( - &mut aggr_expr, - &mut order_by_expr, - &input.equivalence_properties(), - )?; - let mut ordering_req = requirement.unwrap_or(vec![]); - let input_order_mode = get_aggregate_search_mode( - &group_by, - &input, - &mut aggr_expr, - &mut order_by_expr, - &mut ordering_req, - ); - + let input_eq_properties = input.equivalence_properties(); // Get GROUP BY expressions: let groupby_exprs = group_by.input_exprs(); // If existing ordering satisfies a prefix of the GROUP BY expressions, @@ -528,17 +339,31 @@ impl AggregateExec { // work more efficiently. let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); let mut new_requirement = indices - .into_iter() - .map(|idx| PhysicalSortRequirement { + .iter() + .map(|&idx| PhysicalSortRequirement { expr: groupby_exprs[idx].clone(), options: None, }) .collect::>(); - // Postfix ordering requirement of the aggregation to the requirement. - let req = PhysicalSortRequirement::from_sort_exprs(&ordering_req); + + let req = get_aggregate_exprs_requirement( + &aggr_expr, + &group_by, + &input_eq_properties, + &mode, + )?; new_requirement.extend(req); new_requirement = collapse_lex_req(new_requirement); + let input_order_mode = + if indices.len() == groupby_exprs.len() && !indices.is_empty() { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; + // construct a map from the input expression to the output expression of the Aggregation group by let projection_mapping = ProjectionMapping::try_new(&group_by.expr, &input.schema())?; @@ -546,9 +371,8 @@ impl AggregateExec { let required_input_ordering = (!new_requirement.is_empty()).then_some(new_requirement); - let aggregate_eqs = input - .equivalence_properties() - .project(&projection_mapping, schema.clone()); + let aggregate_eqs = + input_eq_properties.project(&projection_mapping, schema.clone()); let output_ordering = aggregate_eqs.oeq_class().output_ordering(); Ok(AggregateExec { @@ -998,6 +822,121 @@ fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { Arc::new(Schema::new(group_fields)) } +/// Determines the lexical ordering requirement for an aggregate expression. +/// +/// # Parameters +/// +/// - `aggr_expr`: A reference to an `Arc` representing the +/// aggregate expression. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexOrdering` instance indicating the lexical ordering requirement for +/// the aggregate expression. +fn get_aggregate_expr_req( + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + agg_mode: &AggregateMode, +) -> LexOrdering { + // If the aggregation function is not order sensitive, or the aggregation + // is performing a "second stage" calculation, or all aggregate function + // requirements are inside the GROUP BY expression, then ignore the ordering + // requirement. + if !is_order_sensitive(aggr_expr) || !agg_mode.is_first_stage() { + return vec![]; + } + + let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec(); + + // In non-first stage modes, we accumulate data (using `merge_batch`) from + // different partitions (i.e. merge partial results). During this merge, we + // consider the ordering of each partial result. Hence, we do not need to + // use the ordering requirement in such modes as long as partial results are + // generated with the correct ordering. + if group_by.is_single() { + // Remove all orderings that occur in the group by. These requirements + // will definitely be satisfied -- Each group by expression will have + // distinct values per group, hence all requirements are satisfied. + let physical_exprs = group_by.input_exprs(); + req.retain(|sort_expr| { + !physical_exprs_contains(&physical_exprs, &sort_expr.expr) + }); + } + req +} + +/// Computes the finer ordering for between given existing ordering requirement +/// of aggregate expression. +/// +/// # Parameters +/// +/// * `existing_req` - The existing lexical ordering that needs refinement. +/// * `aggr_expr` - A reference to an aggregate expression trait object. +/// * `group_by` - Information about the physical grouping (e.g group by expression). +/// * `eq_properties` - Equivalence properties relevant to the computation. +/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). +/// +/// # Returns +/// +/// An `Option` representing the computed finer lexical ordering, +/// or `None` if there is no finer ordering; e.g. the existing requirement and +/// the aggregator requirement is incompatible. +fn finer_ordering( + existing_req: &LexOrdering, + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Option { + let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); + eq_properties.get_finer_ordering(existing_req, &aggr_req) +} + +/// Get the common requirement that satisfies all the aggregate expressions. +/// +/// # Parameters +/// +/// - `aggr_exprs`: A slice of `Arc` containing all the +/// aggregate expressions. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `eq_properties`: A reference to an `EquivalenceProperties` instance +/// representing equivalence properties for ordering. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexRequirement` instance, which is the requirement that satisfies all the +/// aggregate requirements. Returns an error in case of conflicting requirements. +fn get_aggregate_exprs_requirement( + aggr_exprs: &[Arc], + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Result { + let mut requirement = vec![]; + for aggr_expr in aggr_exprs.iter() { + if let Some(finer_ordering) = + finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) + { + requirement = finer_ordering; + } else { + // If neither of the requirements satisfy the other, this means + // requirements are conflicting. Currently, we do not support + // conflicting requirements. + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); + } + } + Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) +} + /// returns physical expressions for arguments to evaluate against a batch /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions @@ -1013,33 +952,27 @@ fn aggregate_expressions( | AggregateMode::SinglePartitioned => Ok(aggr_expr .iter() .map(|agg| { - let mut result = agg.expressions().clone(); - // In partial mode, append ordering requirements to expressions' results. - // Ordering requirements are used by subsequent executors to satisfy the required - // ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes. - if matches!(mode, AggregateMode::Partial) { - if let Some(ordering_req) = agg.order_bys() { - let ordering_exprs = ordering_req - .iter() - .map(|item| item.expr.clone()) - .collect::>(); - result.extend(ordering_exprs); - } + let mut result = agg.expressions(); + // Append ordering requirements to expressions' results. This + // way order sensitive aggregators can satisfy requirement + // themselves. + if let Some(ordering_req) = agg.order_bys() { + result.extend(ordering_req.iter().map(|item| item.expr.clone())); } result }) .collect()), - // in this mode, we build the merge expressions of the aggregation + // In this mode, we build the merge expressions of the aggregation. AggregateMode::Final | AggregateMode::FinalPartitioned => { let mut col_idx_base = col_idx_base; - Ok(aggr_expr + aggr_expr .iter() .map(|agg| { let exprs = merge_expressions(col_idx_base, agg)?; col_idx_base += exprs.len(); Ok(exprs) }) - .collect::>>()?) + .collect() } } } @@ -1052,14 +985,13 @@ fn merge_expressions( index_base: usize, expr: &Arc, ) -> Result>> { - Ok(expr - .state_fields()? - .iter() - .enumerate() - .map(|(idx, f)| { - Arc::new(Column::new(f.name(), index_base + idx)) as Arc - }) - .collect::>()) + expr.state_fields().map(|fields| { + fields + .iter() + .enumerate() + .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _) + .collect() + }) } pub(crate) type AccumulatorItem = Box; @@ -1070,7 +1002,7 @@ fn create_accumulators( aggr_expr .iter() .map(|expr| expr.create_accumulator()) - .collect::>>() + .collect() } /// returns a vector of ArrayRefs, where each entry corresponds to either the @@ -1081,8 +1013,8 @@ fn finalize_aggregation( ) -> Result> { match mode { AggregateMode::Partial => { - // build the vector of states - let a = accumulators + // Build the vector of states + accumulators .iter() .map(|accumulator| { accumulator.state().and_then(|e| { @@ -1091,18 +1023,18 @@ fn finalize_aggregation( .collect::>>() }) }) - .collect::>>()?; - Ok(a.iter().flatten().cloned().collect::>()) + .flatten_ok() + .collect() } AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::Single | AggregateMode::SinglePartitioned => { - // merge the state to the final value + // Merge the state to the final value accumulators .iter() .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) - .collect::>>() + .collect() } } } @@ -1125,9 +1057,7 @@ pub(crate) fn evaluate_many( expr: &[Vec>], batch: &RecordBatch, ) -> Result>> { - expr.iter() - .map(|expr| evaluate(expr, batch)) - .collect::>>() + expr.iter().map(|expr| evaluate(expr, batch)).collect() } fn evaluate_optional( @@ -1143,7 +1073,7 @@ fn evaluate_optional( }) .transpose() }) - .collect::>>() + .collect() } /// Evaluate a group by expression against a `RecordBatch` @@ -1204,9 +1134,7 @@ mod tests { use std::task::{Context, Poll}; use super::*; - use crate::aggregates::{ - get_finest_requirement, AggregateExec, AggregateMode, PhysicalGroupBy, - }; + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; @@ -1228,15 +1156,16 @@ mod tests { Result, ScalarValue, }; use datafusion_execution::config::SessionConfig; + use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, + lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{ - AggregateExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr, + PhysicalSortExpr, }; - use datafusion_execution::memory_pool::FairSpillPool; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -2093,11 +2022,6 @@ mod tests { descending: false, nulls_first: false, }; - // This is the reverse requirement of options1 - let options2 = SortOptions { - descending: true, - nulls_first: true, - }; let col_a = &col("a", &test_schema)?; let col_b = &col("b", &test_schema)?; let col_c = &col("c", &test_schema)?; @@ -2106,7 +2030,7 @@ mod tests { eq_properties.add_equal_conditions(col_a, col_b); // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively - let mut order_by_exprs = vec![ + let order_by_exprs = vec![ None, Some(vec![PhysicalSortExpr { expr: col_a.clone(), @@ -2136,14 +2060,8 @@ mod tests { options: options1, }, ]), - // Since aggregate expression is reversible (FirstValue), we should be able to resolve below - // contradictory requirement by reversing it. - Some(vec![PhysicalSortExpr { - expr: col_b.clone(), - options: options2, - }]), ]; - let common_requirement = Some(vec![ + let common_requirement = vec![ PhysicalSortExpr { expr: col_a.clone(), options: options1, @@ -2152,17 +2070,28 @@ mod tests { expr: col_c.clone(), options: options1, }, - ]); - let aggr_expr = Arc::new(FirstValue::new( - col_a.clone(), - "first1", - DataType::Int32, - vec![], - vec![], - )) as _; - let mut aggr_exprs = vec![aggr_expr; order_by_exprs.len()]; - let res = - get_finest_requirement(&mut aggr_exprs, &mut order_by_exprs, &eq_properties)?; + ]; + let aggr_exprs = order_by_exprs + .into_iter() + .map(|order_by_expr| { + Arc::new(OrderSensitiveArrayAgg::new( + col_a.clone(), + "array_agg", + DataType::Int32, + false, + vec![], + order_by_expr.unwrap_or_default(), + )) as _ + }) + .collect::>(); + let group_by = PhysicalGroupBy::new_single(vec![]); + let res = get_aggregate_exprs_requirement( + &aggr_exprs, + &group_by, + &eq_properties, + &AggregateMode::Partial, + )?; + let res = PhysicalSortRequirement::to_sort_exprs(res); assert_eq!(res, common_requirement); Ok(()) } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs index 663bbdd5a3c7..8e2bbbfe4f69 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs @@ -21,5 +21,4 @@ mod normalize; mod runner; pub use error::*; -pub use normalize::*; pub use runner::*; diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt index 9a7117b69b99..3f609e254839 100644 --- a/datafusion/sqllogictest/test_files/distinct_on.slt +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -78,7 +78,7 @@ c 4 query I SELECT DISTINCT ON (c1) c2 FROM aggregate_test_100 ORDER BY c1, c3; ---- -5 +4 4 2 1 @@ -100,10 +100,9 @@ ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_tes ------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)], ordering_mode=Sorted ---------------SortExec: expr=[c1@0 ASC NULLS LAST,c3@2 ASC NULLS LAST] -----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true # ON expressions are not a sub-set of the ORDER BY expressions query error SELECT DISTINCT ON expressions must match initial ORDER BY expressions diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index f1b6a57287b5..bbf21e135fe4 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2019,17 +2019,16 @@ SortPreservingMergeExec: [col0@0 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallySorted([0]) ---------------SortExec: expr=[col0@3 ASC NULLS LAST] -----------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] -------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] -----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[3] -----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[3] +------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] +--------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] +----------------CoalesceBatchesExec: target_batch_size=8192 +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2209,7 +2208,7 @@ ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III -SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c +SELECT a, b, LAST_VALUE(c ORDER BY a DESC, c ASC) as last_c FROM annotated_data_infinite2 GROUP BY a, b ---- @@ -2509,7 +2508,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2540,7 +2539,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2572,7 +2571,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2637,9 +2636,8 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ------TableScan: sales_global projection=[country, ts, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] +----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, @@ -2672,8 +2670,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] --AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, @@ -2709,12 +2706,11 @@ physical_plan SortExec: expr=[sn@2 ASC NULLS LAST] --ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] ----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] -------SortExec: expr=[sn@5 ASC NULLS LAST] ---------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] -----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] +--------CoalesceBatchesExec: target_batch_size=8192 +----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------------MemoryExec: partitions=1, partition_sizes=[1] query ITIPTR rowsort SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate @@ -2759,8 +2755,7 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 --------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------MemoryExec: partitions=1, partition_sizes=[1] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2791,13 +2786,12 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------MemoryExec: partitions=1, partition_sizes=[1] +--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2831,16 +2825,15 @@ ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts --AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, LAST_VALUE(amount ORDER BY ts ASC) AS fv2 FROM sales_global ---- -30 80 +30 100 # Conversion in between FIRST_VALUE and LAST_VALUE to resolve # contradictory requirements should work in multi partitions. @@ -2855,12 +2848,11 @@ Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS ----TableScan: sales_global projection=[ts, amount] physical_plan ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] ---AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec -------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2993,10 +2985,10 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------------SortExec: expr=[amount@1 DESC] ----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3639,10 +3631,10 @@ Projection: FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_tab ----TableScan: multiple_ordered_table projection=[a, c, d] physical_plan ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] ---AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] ----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 ---------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 9a349f600091..a7146a5a91c4 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3454,7 +3454,7 @@ SortPreservingMergeExec: [a@0 ASC] ------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] --------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 2), input_partitions=2 -------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)], ordering_mode=PartiallySorted([0]) +------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] --------------CoalesceBatchesExec: target_batch_size=2 ----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] ------------------CoalesceBatchesExec: target_batch_size=2 @@ -3462,7 +3462,7 @@ SortPreservingMergeExec: [a@0 ASC] ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true ------------------CoalesceBatchesExec: target_batch_size=2 ---------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC NULLS LAST +--------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true From 8284371cb5dbeb5d0b1d50c420affb9be86b1599 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 28 Dec 2023 22:08:09 +0100 Subject: [PATCH 43/63] feat: support 'LargeList' in `array_pop_front` and `array_pop_back` (#8569) * support largelist in pop back * support largelist in pop front * add function comment * use execution error * use execution error * spilit the general code --- .../physical-expr/src/array_expressions.rs | 90 ++++++++++++++----- datafusion/sqllogictest/test_files/array.slt | 75 ++++++++++++++++ 2 files changed, 141 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7a986810bad2..250250630eff 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -743,22 +743,78 @@ where )?)) } -/// array_pop_back SQL function -pub fn array_pop_back(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_pop_back needs one argument"); - } +fn general_pop_front_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![2; array.len()]); + let to_array = Int64Array::from( + array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) + .collect::>(), + ); + general_array_slice::(array, &from_array, &to_array) +} - let list_array = as_list_array(&args[0])?; - let from_array = Int64Array::from(vec![1; list_array.len()]); +fn general_pop_back_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![1; array.len()]); let to_array = Int64Array::from( - list_array + array .iter() .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) .collect::>(), ); - let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; - array_slice(args.as_slice()) + general_array_slice::(array, &from_array, &to_array) +} + +/// array_pop_front SQL function +pub fn array_pop_front(args: &[ArrayRef]) -> Result { + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_front_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_front_list::(array) + } + _ => exec_err!( + "array_pop_front does not support type: {:?}", + array_data_type + ), + } +} + +/// array_pop_back SQL function +pub fn array_pop_back(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_pop_back needs one argument"); + } + + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_back_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_back_list::(array) + } + _ => exec_err!( + "array_pop_back does not support type: {:?}", + array_data_type + ), + } } /// Appends or prepends elements to a ListArray. @@ -882,20 +938,6 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { Ok(arr) } -/// array_pop_front SQL function -pub fn array_pop_front(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let from_array = Int64Array::from(vec![2; list_array.len()]); - let to_array = Int64Array::from( - list_array - .iter() - .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) - .collect::>(), - ); - let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; - array_slice(args.as_slice()) -} - /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { if args.len() != 2 { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4c4adbabfda5..b8d89edb49b1 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -994,18 +994,33 @@ select array_pop_back(make_array(1, 2, 3, 4, 5)), array_pop_back(make_array('h', ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_pop_back(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_back(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_pop_back scalar function #2 (after array_pop_back, array is empty) query ? select array_pop_back(make_array(1)); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + # array_pop_back scalar function #3 (array_pop_back the empty array) query ? select array_pop_back(array_pop_back(make_array(1))); ---- [] +query ? +select array_pop_back(array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + # array_pop_back scalar function #4 (array_pop_back the arrays which have NULL) query ?? select array_pop_back(make_array(1, 2, 3, 4, NULL)), array_pop_back(make_array(NULL, 'e', 'l', NULL, 'o')); @@ -1018,24 +1033,44 @@ select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_ ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #6 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL)); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #7 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4))); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], ] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], ] + # array_pop_back scalar function #8 (after array_pop_back, nested array is empty) query ? select array_pop_back(make_array(make_array(1, 2, 3))); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + # array_pop_back with columns query ? select array_pop_back(column1) from arrayspop; @@ -1047,6 +1082,16 @@ select array_pop_back(column1) from arrayspop; [] [, 10, 11] +query ? +select array_pop_back(arrow_cast(column1, 'LargeList(Int64)')) from arrayspop; +---- +[1, 2] +[3, 4, 5] +[6, 7, 8, ] +[, ] +[] +[, 10, 11] + ## array_pop_front (aliases: `list_pop_front`) # array_pop_front scalar function #1 @@ -1055,36 +1100,66 @@ select array_pop_front(make_array(1, 2, 3, 4, 5)), array_pop_front(make_array('h ---- [2, 3, 4, 5] [e, l, l, o] +query ?? +select array_pop_front(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_front(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[2, 3, 4, 5] [e, l, l, o] + # array_pop_front scalar function #2 (after array_pop_front, array is empty) query ? select array_pop_front(make_array(1)); ---- [] +query ? +select array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + # array_pop_front scalar function #3 (array_pop_front the empty array) query ? select array_pop_front(array_pop_front(make_array(1))); ---- [] +query ? +select array_pop_front(array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + # array_pop_front scalar function #5 (array_pop_front the nested arrays) query ? select array_pop_front(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6))); ---- [[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + # array_pop_front scalar function #6 (array_pop_front the nested arrays with NULL) query ? select array_pop_front(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4))); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_front(arrow_cast(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_front scalar function #8 (after array_pop_front, nested array is empty) query ? select array_pop_front(make_array(make_array(1, 2, 3))); ---- [] +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + ## array_slice (aliases: list_slice) # array_slice scalar function #1 (with positive indexes) From 673f0e17ace7e7a08474c26be50038cf0e251477 Mon Sep 17 00:00:00 2001 From: Ruixiang Tan Date: Fri, 29 Dec 2023 19:27:39 +0800 Subject: [PATCH 44/63] chore: rename ceresdb to apache horaedb (#8674) --- docs/source/user-guide/introduction.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 6c1e54c2b701..b737c3bab266 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -75,7 +75,7 @@ latency). Here are some example systems built using DataFusion: -- Specialized Analytical Database systems such as [CeresDB] and more general Apache Spark like system such a [Ballista]. +- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such a [Ballista]. - New query language engines such as [prql-query] and accelerators such as [VegaFusion] - Research platform for new Database Systems, such as [Flock] - SQL support to another library, such as [dask sql] @@ -96,7 +96,6 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine -- [CeresDB](https://github.com/CeresDB/ceresdb) Distributed Time-Series Database - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python @@ -104,6 +103,7 @@ Here are some active projects using DataFusion: - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. +- [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline - [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. @@ -128,7 +128,6 @@ Here are some less active projects that used DataFusion: [ballista]: https://github.com/apache/arrow-ballista [blaze]: https://github.com/blaze-init/blaze -[ceresdb]: https://github.com/CeresDB/ceresdb [cloudfuse buzz]: https://github.com/cloudfuse-io/buzz-rust [cnosdb]: https://github.com/cnosdb/cnosdb [cube store]: https://github.com/cube-js/cube.js/tree/master/rust @@ -138,6 +137,7 @@ Here are some less active projects that used DataFusion: [flock]: https://github.com/flock-lab/flock [kamu]: https://github.com/kamu-data/kamu-cli [greptime db]: https://github.com/GreptimeTeam/greptimedb +[horaedb]: https://github.com/apache/incubator-horaedb [influxdb iox]: https://github.com/influxdata/influxdb_iox [parseable]: https://github.com/parseablehq/parseable [prql-query]: https://github.com/prql/prql-query From d515c68da6e9795271c54a2f4b7853ca25cc90da Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 29 Dec 2023 12:44:07 +0100 Subject: [PATCH 45/63] clean code (#8671) --- datafusion/proto/src/logical_plan/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e997bcde426e..dbed0252d051 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1765,8 +1765,8 @@ pub(crate) fn writer_properties_to_proto( pub(crate) fn writer_properties_from_proto( props: &protobuf::WriterProperties, ) -> Result { - let writer_version = WriterVersion::from_str(&props.writer_version) - .map_err(|e| proto_error(e.to_string()))?; + let writer_version = + WriterVersion::from_str(&props.writer_version).map_err(proto_error)?; Ok(WriterProperties::builder() .set_created_by(props.created_by.clone()) .set_writer_version(writer_version) From 8ced56e418a50456cc8193547683bfcceb063f0d Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Fri, 29 Dec 2023 14:37:25 +0200 Subject: [PATCH 46/63] remove tz with modified offset from tests (#8677) --- datafusion/sqllogictest/test_files/timestamps.slt | 3 --- 1 file changed, 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 2b3b4bf2e45b..c84e46c965fa 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1730,14 +1730,11 @@ SELECT TIMESTAMPTZ '2022-01-01 01:10:00 AEST' query P rowsort SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Australia/Sydney' as ts_geo UNION ALL -SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Antarctica/Vostok' as ts_geo - UNION ALL SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Africa/Johannesburg' as ts_geo UNION ALL SELECT TIMESTAMPTZ '2022-01-01 01:10:00 America/Los_Angeles' as ts_geo ---- 2021-12-31T14:10:00Z -2021-12-31T19:10:00Z 2021-12-31T23:10:00Z 2022-01-01T09:10:00Z From b85a39739e754576723ff4b1691c518a86335769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Fri, 29 Dec 2023 15:51:02 +0300 Subject: [PATCH 47/63] Make the BatchSerializer behind Arc to avoid unnecessary struct creation (#8666) * Make the BatchSerializer behind Arc * Commenting * Review * Incorporate review suggestions * Use old names --------- Co-authored-by: Mehmet Ozan Kabak --- .../core/src/datasource/file_format/csv.rs | 69 +++++++---------- .../core/src/datasource/file_format/json.rs | 77 ++++++++----------- .../src/datasource/file_format/write/mod.rs | 16 +--- .../file_format/write/orchestration.rs | 74 ++++++++---------- .../datasource/physical_plan/file_stream.rs | 12 ++- 5 files changed, 98 insertions(+), 150 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 4033bcd3b557..d4e63904bdd4 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -19,21 +19,9 @@ use std::any::Any; use std::collections::HashSet; -use std::fmt; -use std::fmt::Debug; +use std::fmt::{self, Debug}; use std::sync::Arc; -use arrow_array::RecordBatch; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; - -use bytes::{Buf, Bytes}; -use datafusion_physical_plan::metrics::MetricsSet; -use futures::stream::BoxStream; -use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; -use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; - use super::write::orchestration::stateless_multipart_put; use super::{FileFormat, DEFAULT_SCHEMA_INFER_MAX_RECORD}; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -47,11 +35,20 @@ use crate::physical_plan::insert::{DataSink, FileSinkExec}; use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field, Fields, Schema}; use arrow::{self, datatypes::SchemaRef}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use futures::stream::BoxStream; +use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; +use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; /// Character Separated Value `FileFormat` implementation. #[derive(Debug)] @@ -400,8 +397,6 @@ impl Default for CsvSerializer { pub struct CsvSerializer { // CSV writer builder builder: WriterBuilder, - // Inner buffer for avoiding reallocation - buffer: Vec, // Flag to indicate whether there will be a header header: bool, } @@ -412,7 +407,6 @@ impl CsvSerializer { Self { builder: WriterBuilder::new(), header: true, - buffer: Vec::with_capacity(4096), } } @@ -431,21 +425,14 @@ impl CsvSerializer { #[async_trait] impl BatchSerializer for CsvSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); let builder = self.builder.clone(); - let mut writer = builder.with_header(self.header).build(&mut self.buffer); + let header = self.header && initial; + let mut writer = builder.with_header(header).build(&mut buffer); writer.write(&batch)?; drop(writer); - self.header = false; - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - let new_self = CsvSerializer::new() - .with_builder(self.builder.clone()) - .with_header(self.header); - self.header = false; - Ok(Box::new(new_self)) + Ok(Bytes::from(buffer)) } } @@ -488,13 +475,11 @@ impl CsvSink { let builder_clone = builder.clone(); let options_clone = writer_options.clone(); let get_serializer = move || { - let inner_clone = builder_clone.clone(); - let serializer: Box = Box::new( + Arc::new( CsvSerializer::new() - .with_builder(inner_clone) + .with_builder(builder_clone.clone()) .with_header(options_clone.writer_options.header()), - ); - serializer + ) as _ }; stateless_multipart_put( @@ -541,15 +526,15 @@ mod tests { use crate::physical_plan::collect; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::arrow_test_data; + use arrow::compute::concat_batches; - use bytes::Bytes; - use chrono::DateTime; use datafusion_common::cast::as_string_array; - use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_common::FileType; - use datafusion_common::GetExt; + use datafusion_common::{internal_err, FileType, GetExt}; use datafusion_expr::{col, lit}; + + use bytes::Bytes; + use chrono::DateTime; use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -836,8 +821,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new(); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new(); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() @@ -860,8 +845,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new().with_header(false); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new().with_header(false); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index fcb1d5f8e527..3d437bc5fe68 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -23,40 +23,34 @@ use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; -use super::{FileFormat, FileScanConfig}; -use arrow::datatypes::Schema; -use arrow::datatypes::SchemaRef; -use arrow::json; -use arrow::json::reader::infer_json_schema_from_iterator; -use arrow::json::reader::ValueIter; -use arrow_array::RecordBatch; -use async_trait::async_trait; -use bytes::Buf; - -use bytes::Bytes; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::PhysicalSortRequirement; -use datafusion_physical_plan::ExecutionPlan; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; - -use crate::datasource::physical_plan::FileGroupDisplay; -use crate::physical_plan::insert::DataSink; -use crate::physical_plan::insert::FileSinkExec; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; - use super::write::orchestration::stateless_multipart_put; - +use super::{FileFormat, FileScanConfig}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; +use crate::datasource::physical_plan::FileGroupDisplay; use crate::datasource::physical_plan::{FileSinkConfig, NdJsonExec}; use crate::error::Result; use crate::execution::context::SessionState; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; +use crate::physical_plan::{ + DisplayAs, DisplayFormatType, SendableRecordBatchStream, Statistics, +}; +use arrow::datatypes::Schema; +use arrow::datatypes::SchemaRef; +use arrow::json; +use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; +use arrow_array::RecordBatch; use datafusion_common::{not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::ExecutionPlan; + +use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; /// New line delimited JSON `FileFormat` implementation. #[derive(Debug)] @@ -201,31 +195,22 @@ impl Default for JsonSerializer { } /// Define a struct for serializing Json records to a stream -pub struct JsonSerializer { - // Inner buffer for avoiding reallocation - buffer: Vec, -} +pub struct JsonSerializer {} impl JsonSerializer { /// Constructor for the JsonSerializer object pub fn new() -> Self { - Self { - buffer: Vec::with_capacity(4096), - } + Self {} } } #[async_trait] impl BatchSerializer for JsonSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { - let mut writer = json::LineDelimitedWriter::new(&mut self.buffer); + async fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); + let mut writer = json::LineDelimitedWriter::new(&mut buffer); writer.write(&batch)?; - //drop(writer); - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - Ok(Box::new(JsonSerializer::new())) + Ok(Bytes::from(buffer)) } } @@ -272,10 +257,7 @@ impl JsonSink { let writer_options = self.config.file_type_writer_options.try_into_json()?; let compression = &writer_options.compression; - let get_serializer = move || { - let serializer: Box = Box::new(JsonSerializer::new()); - serializer - }; + let get_serializer = move || Arc::new(JsonSerializer::new()) as _; stateless_multipart_put( data, @@ -312,16 +294,17 @@ impl DataSink for JsonSink { #[cfg(test)] mod tests { use super::super::test_util::scan_format; - use datafusion_common::cast::as_int64_array; - use datafusion_common::stats::Precision; - use futures::StreamExt; - use object_store::local::LocalFileSystem; - use super::*; use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; + use datafusion_common::cast::as_int64_array; + use datafusion_common::stats::Precision; + + use futures::StreamExt; + use object_store::local::LocalFileSystem; + #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 68fe81ce91fa..c481f2accf19 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -24,20 +24,16 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::datasource::file_format::file_compression_type::FileCompressionType; - use crate::error::Result; use arrow_array::RecordBatch; - use datafusion_common::DataFusionError; use async_trait::async_trait; use bytes::Bytes; - use futures::future::BoxFuture; use object_store::path::Path; use object_store::{MultipartId, ObjectStore}; - use tokio::io::AsyncWrite; pub(crate) mod demux; @@ -149,15 +145,11 @@ impl AsyncWrite for AbortableWrite { /// A trait that defines the methods required for a RecordBatch serializer. #[async_trait] -pub trait BatchSerializer: Unpin + Send { +pub trait BatchSerializer: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. - async fn serialize(&mut self, batch: RecordBatch) -> Result; - /// Duplicates self to support serializing multiple batches in parallel on multiple cores - fn duplicate(&mut self) -> Result> { - Err(DataFusionError::NotImplemented( - "Parallel serialization is not implemented for this file type".into(), - )) - } + /// Parameter `initial` signals whether the given batch is the first batch. + /// This distinction is important for certain serializers (like CSV). + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } /// Returns an [`AbortableWrite`] which writes to the given object store location diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 120e27ecf669..9b820a15b280 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -21,28 +21,25 @@ use std::sync::Arc; +use super::demux::start_demuxer_task; +use super::{create_writer, AbortableWrite, BatchSerializer}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; use crate::physical_plan::SendableRecordBatchStream; use arrow_array::RecordBatch; - -use datafusion_common::DataFusionError; - -use bytes::Bytes; +use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; use datafusion_execution::TaskContext; +use bytes::Bytes; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::{JoinHandle, JoinSet}; use tokio::try_join; -use super::demux::start_demuxer_task; -use super::{create_writer, AbortableWrite, BatchSerializer}; - type WriterType = AbortableWrite>; -type SerializerType = Box; +type SerializerType = Arc; /// Serializes a single data stream in parallel and writes to an ObjectStore /// concurrently. Data order is preserved. In the event of an error, @@ -50,33 +47,28 @@ type SerializerType = Box; /// so that the caller may handle aborting failed writes. pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, - mut serializer: Box, + serializer: Arc, mut writer: AbortableWrite>, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); - let serialize_task = tokio::spawn(async move { + // Some serializers (like CSV) handle the first batch differently than + // subsequent batches, so we track that here. + let mut initial = true; while let Some(batch) = data_rx.recv().await { - match serializer.duplicate() { - Ok(mut serializer_clone) => { - let handle = tokio::spawn(async move { - let num_rows = batch.num_rows(); - let bytes = serializer_clone.serialize(batch).await?; - Ok((num_rows, bytes)) - }); - tx.send(handle).await.map_err(|_| { - DataFusionError::Internal( - "Unknown error writing to object store".into(), - ) - })?; - } - Err(_) => { - return Err(DataFusionError::Internal( - "Unknown error writing to object store".into(), - )) - } + let serializer_clone = serializer.clone(); + let handle = tokio::spawn(async move { + let num_rows = batch.num_rows(); + let bytes = serializer_clone.serialize(batch, initial).await?; + Ok((num_rows, bytes)) + }); + if initial { + initial = false; } + tx.send(handle).await.map_err(|_| { + internal_datafusion_err!("Unknown error writing to object store") + })?; } Ok(()) }); @@ -120,7 +112,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( Err(_) => { return Err(( writer, - DataFusionError::Internal("Unknown error writing to object store".into()), + internal_datafusion_err!("Unknown error writing to object store"), )) } }; @@ -171,9 +163,9 @@ pub(crate) async fn stateless_serialize_and_write_files( // this thread, so we cannot clean it up (hence any_abort_errors is true) any_errors = true; any_abort_errors = true; - triggering_error = Some(DataFusionError::Internal(format!( + triggering_error = Some(internal_datafusion_err!( "Unexpected join error while serializing file {e}" - ))); + )); } } } @@ -190,24 +182,24 @@ pub(crate) async fn stateless_serialize_and_write_files( false => { writer.shutdown() .await - .map_err(|_| DataFusionError::Internal("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!".into()))?; + .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"))?; } } } if any_errors { match any_abort_errors{ - true => return Err(DataFusionError::Internal("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written.".into())), + true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."), false => match triggering_error { Some(e) => return Err(e), - None => return Err(DataFusionError::Internal("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.".into())) + None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.") } } } tx.send(row_count).map_err(|_| { - DataFusionError::Internal( - "Error encountered while sending row count back to file sink!".into(), + internal_datafusion_err!( + "Error encountered while sending row count back to file sink!" ) })?; Ok(()) @@ -220,7 +212,7 @@ pub(crate) async fn stateless_multipart_put( data: SendableRecordBatchStream, context: &Arc, file_extension: String, - get_serializer: Box Box + Send>, + get_serializer: Box Arc + Send>, config: &FileSinkConfig, compression: FileCompressionType, ) -> Result { @@ -264,8 +256,8 @@ pub(crate) async fn stateless_multipart_put( .send((rb_stream, serializer, writer)) .await .map_err(|_| { - DataFusionError::Internal( - "Writer receive file bundle channel closed unexpectedly!".into(), + internal_datafusion_err!( + "Writer receive file bundle channel closed unexpectedly!" ) })?; } @@ -288,9 +280,7 @@ pub(crate) async fn stateless_multipart_put( } let total_count = rx_row_cnt.await.map_err(|_| { - DataFusionError::Internal( - "Did not receieve row count from write coordinater".into(), - ) + internal_datafusion_err!("Did not receieve row count from write coordinater") })?; Ok(total_count) diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 99fb088b66f4..bb4c8313642c 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -518,10 +518,8 @@ impl RecordBatchStream for FileStream { #[cfg(test)] mod tests { - use arrow_schema::Schema; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - use datafusion_common::Statistics; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use super::*; use crate::datasource::file_format::write::BatchSerializer; @@ -534,8 +532,8 @@ mod tests { test::{make_partition, object_store::register_test_store}, }; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; + use arrow_schema::Schema; + use datafusion_common::{internal_err, DataFusionError, Statistics}; use async_trait::async_trait; use bytes::Bytes; @@ -993,7 +991,7 @@ mod tests { #[async_trait] impl BatchSerializer for TestSerializer { - async fn serialize(&mut self, _batch: RecordBatch) -> Result { + async fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { Ok(self.bytes.clone()) } } From 7fc663c2e40be2928778102386bbf76962dd2cdc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 29 Dec 2023 16:53:31 -0700 Subject: [PATCH 48/63] Implement serde for CSV and Parquet FileSinkExec (#8646) * Add serde for Csv and Parquet sink * Add tests * parquet test passes * save progress * add compression type to csv serde * remove hard-coded compression from CSV serde --- .../core/src/datasource/file_format/csv.rs | 11 +- .../src/datasource/file_format/parquet.rs | 9 +- datafusion/proto/proto/datafusion.proto | 40 +- datafusion/proto/src/generated/pbjson.rs | 517 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 59 +- datafusion/proto/src/logical_plan/mod.rs | 43 +- .../proto/src/physical_plan/from_proto.rs | 38 +- datafusion/proto/src/physical_plan/mod.rs | 91 +++ .../proto/src/physical_plan/to_proto.rs | 46 +- .../tests/cases/roundtrip_physical_plan.rs | 125 ++++- 10 files changed, 922 insertions(+), 57 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index d4e63904bdd4..7a0af3ff0809 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -437,7 +437,7 @@ impl BatchSerializer for CsvSerializer { } /// Implements [`DataSink`] for writing to a CSV file. -struct CsvSink { +pub struct CsvSink { /// Config options for writing data config: FileSinkConfig, } @@ -461,9 +461,16 @@ impl DisplayAs for CsvSink { } impl CsvSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + async fn multipartput_all( &self, data: SendableRecordBatchStream, diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 7044acccd6dc..9729bfa163af 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -621,7 +621,7 @@ async fn fetch_statistics( } /// Implements [`DataSink`] for writing to a parquet file. -struct ParquetSink { +pub struct ParquetSink { /// Config options for writing data config: FileSinkConfig, } @@ -645,10 +645,15 @@ impl DisplayAs for ParquetSink { } impl ParquetSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } /// Converts table schema to writer schema, which may differ in the case /// of hive style partitioning where some columns are removed from the /// underlying files. diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 59b82efcbb43..d5f8397aa30c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1187,6 +1187,8 @@ message PhysicalPlanNode { SymmetricHashJoinExecNode symmetric_hash_join = 25; InterleaveExecNode interleave = 26; PlaceholderRowExecNode placeholder_row = 27; + CsvSinkExecNode csv_sink = 28; + ParquetSinkExecNode parquet_sink = 29; } } @@ -1220,20 +1222,22 @@ message ParquetWriterOptions { } message CsvWriterOptions { + // Compression type + CompressionTypeVariant compression = 1; // Optional column delimiter. Defaults to `b','` - string delimiter = 1; + string delimiter = 2; // Whether to write column names as file headers. Defaults to `true` - bool has_header = 2; + bool has_header = 3; // Optional date format for date arrays - string date_format = 3; + string date_format = 4; // Optional datetime format for datetime arrays - string datetime_format = 4; + string datetime_format = 5; // Optional timestamp format for timestamp arrays - string timestamp_format = 5; + string timestamp_format = 6; // Optional time format for time arrays - string time_format = 6; + string time_format = 7; // Optional value to represent null - string null_value = 7; + string null_value = 8; } message WriterProperties { @@ -1270,6 +1274,28 @@ message JsonSinkExecNode { PhysicalSortExprNodeCollection sort_order = 4; } +message CsvSink { + FileSinkConfig config = 1; +} + +message CsvSinkExecNode { + PhysicalPlanNode input = 1; + CsvSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + +message ParquetSink { + FileSinkConfig config = 1; +} + +message ParquetSinkExecNode { + PhysicalPlanNode input = 1; + ParquetSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 956244ffdbc2..12e834d75adf 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5151,6 +5151,241 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CsvSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(CsvSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.CsvSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CsvSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(CsvSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CsvWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -5159,6 +5394,9 @@ impl serde::Serialize for CsvWriterOptions { { use serde::ser::SerializeStruct; let mut len = 0; + if self.compression != 0 { + len += 1; + } if !self.delimiter.is_empty() { len += 1; } @@ -5181,6 +5419,11 @@ impl serde::Serialize for CsvWriterOptions { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } if !self.delimiter.is_empty() { struct_ser.serialize_field("delimiter", &self.delimiter)?; } @@ -5212,6 +5455,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "compression", "delimiter", "has_header", "hasHeader", @@ -5229,6 +5473,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { #[allow(clippy::enum_variant_names)] enum GeneratedField { + Compression, Delimiter, HasHeader, DateFormat, @@ -5257,6 +5502,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { E: serde::de::Error, { match value { + "compression" => Ok(GeneratedField::Compression), "delimiter" => Ok(GeneratedField::Delimiter), "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), @@ -5283,6 +5529,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { where V: serde::de::MapAccess<'de>, { + let mut compression__ = None; let mut delimiter__ = None; let mut has_header__ = None; let mut date_format__ = None; @@ -5292,6 +5539,12 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { let mut null_value__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? as i32); + } GeneratedField::Delimiter => { if delimiter__.is_some() { return Err(serde::de::Error::duplicate_field("delimiter")); @@ -5337,6 +5590,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { } } Ok(CsvWriterOptions { + compression: compression__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), has_header: has_header__.unwrap_or_default(), date_format: date_format__.unwrap_or_default(), @@ -15398,6 +15652,241 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ParquetSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(ParquetSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(ParquetSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSinkExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ParquetWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -18484,6 +18973,12 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { struct_ser.serialize_field("placeholderRow", v)?; } + physical_plan_node::PhysicalPlanType::CsvSink(v) => { + struct_ser.serialize_field("csvSink", v)?; + } + physical_plan_node::PhysicalPlanType::ParquetSink(v) => { + struct_ser.serialize_field("parquetSink", v)?; + } } } struct_ser.end() @@ -18535,6 +19030,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "interleave", "placeholder_row", "placeholderRow", + "csv_sink", + "csvSink", + "parquet_sink", + "parquetSink", ]; #[allow(clippy::enum_variant_names)] @@ -18565,6 +19064,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SymmetricHashJoin, Interleave, PlaceholderRow, + CsvSink, + ParquetSink, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18612,6 +19113,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), "interleave" => Ok(GeneratedField::Interleave), "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), + "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), + "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18814,6 +19317,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("placeholderRow")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) +; + } + GeneratedField::CsvSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink) +; + } + GeneratedField::ParquetSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 32e892e663ef..4ee0b70325ca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1566,7 +1566,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub physical_plan_type: ::core::option::Option, } @@ -1629,6 +1629,10 @@ pub mod physical_plan_node { Interleave(super::InterleaveExecNode), #[prost(message, tag = "27")] PlaceholderRow(super::PlaceholderRowExecNode), + #[prost(message, tag = "28")] + CsvSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + ParquetSink(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1673,26 +1677,29 @@ pub struct ParquetWriterOptions { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvWriterOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, /// Optional column delimiter. Defaults to `b','` - #[prost(string, tag = "1")] + #[prost(string, tag = "2")] pub delimiter: ::prost::alloc::string::String, /// Whether to write column names as file headers. Defaults to `true` - #[prost(bool, tag = "2")] + #[prost(bool, tag = "3")] pub has_header: bool, /// Optional date format for date arrays - #[prost(string, tag = "3")] + #[prost(string, tag = "4")] pub date_format: ::prost::alloc::string::String, /// Optional datetime format for datetime arrays - #[prost(string, tag = "4")] + #[prost(string, tag = "5")] pub datetime_format: ::prost::alloc::string::String, /// Optional timestamp format for timestamp arrays - #[prost(string, tag = "5")] + #[prost(string, tag = "6")] pub timestamp_format: ::prost::alloc::string::String, /// Optional time format for time arrays - #[prost(string, tag = "6")] + #[prost(string, tag = "7")] pub time_format: ::prost::alloc::string::String, /// Optional value to represent null - #[prost(string, tag = "7")] + #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1753,6 +1760,42 @@ pub struct JsonSinkExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] pub node: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index dbed0252d051..5ee88c3d5328 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1648,28 +1648,10 @@ impl AsLogicalPlan for LogicalPlanNode { match opt.as_ref() { FileTypeWriterOptions::CSV(csv_opts) => { let csv_options = &csv_opts.writer_options; - let csv_writer_options = protobuf::CsvWriterOptions { - delimiter: (csv_options.delimiter() as char) - .to_string(), - has_header: csv_options.header(), - date_format: csv_options - .date_format() - .unwrap_or("") - .to_owned(), - datetime_format: csv_options - .datetime_format() - .unwrap_or("") - .to_owned(), - timestamp_format: csv_options - .timestamp_format() - .unwrap_or("") - .to_owned(), - time_format: csv_options - .time_format() - .unwrap_or("") - .to_owned(), - null_value: csv_options.null().to_owned(), - }; + let csv_writer_options = csv_writer_options_to_proto( + csv_options, + (&csv_opts.compression).into(), + ); let csv_options = file_type_writer_options::FileType::CsvOptions( csv_writer_options, @@ -1724,6 +1706,23 @@ impl AsLogicalPlan for LogicalPlanNode { } } +pub(crate) fn csv_writer_options_to_proto( + csv_options: &WriterBuilder, + compression: &CompressionTypeVariant, +) -> protobuf::CsvWriterOptions { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::CsvWriterOptions { + compression: compression.into(), + delimiter: (csv_options.delimiter() as char).to_string(), + has_header: csv_options.header(), + date_format: csv_options.date_format().unwrap_or("").to_owned(), + datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(), + timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), + time_format: csv_options.time_format().unwrap_or("").to_owned(), + null_value: csv_options.null().to_owned(), + } +} + pub(crate) fn csv_writer_options_from_proto( writer_options: &protobuf::CsvWriterOptions, ) -> Result { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 6f1e811510c6..8ad6d679df4d 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -22,7 +22,10 @@ use std::sync::Arc; use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; @@ -713,6 +716,23 @@ impl TryFrom<&protobuf::JsonSink> for JsonSink { } } +#[cfg(feature = "parquet")] +impl TryFrom<&protobuf::ParquetSink> for ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::ParquetSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::CsvSink> for CsvSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::CsvSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { type Error = DataFusionError; @@ -768,16 +788,16 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { .file_type .as_ref() .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; + match file_type { - protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( - Self::JSON(JsonWriterOptions::new(opts.compression().into())), - ), - protobuf::file_type_writer_options::FileType::CsvOptions(opt) => { - let write_options = csv_writer_options_from_proto(opt)?; - Ok(Self::CSV(CsvWriterOptions::new( - write_options, - CompressionTypeVariant::UNCOMPRESSED, - ))) + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => { + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::JSON(JsonWriterOptions::new(compression))) + } + protobuf::file_type_writer_options::FileType::CsvOptions(opts) => { + let write_options = csv_writer_options_from_proto(opts)?; + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::CSV(CsvWriterOptions::new(write_options, compression))) } protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { let props = opt.writer_properties.clone().unwrap_or_default(); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 24ede3fcaf62..95becb3fe4b3 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -21,9 +21,12 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::json::JsonSink; #[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; +#[cfg(feature = "parquet")] use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; @@ -921,6 +924,68 @@ impl AsExecutionPlan for PhysicalPlanNode { sort_order, ))) } + PhysicalPlanType::CsvSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: CsvSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } + PhysicalPlanType::ParquetSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: ParquetSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } } } @@ -1678,6 +1743,32 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( + protobuf::CsvSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( + protobuf::ParquetSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + // If unknown DataSink then let extension handle it } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e9cdb34cf1b9..f4e3f9e4dca7 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -28,7 +28,12 @@ use crate::protobuf::{ ScalarValue, }; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; + +use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; use datafusion::datasource::{ + file_format::csv::CsvSink, file_format::json::JsonSink, listing::{FileRange, PartitionedFile}, physical_plan::FileScanConfig, @@ -814,6 +819,27 @@ impl TryFrom<&JsonSink> for protobuf::JsonSink { } } +impl TryFrom<&CsvSink> for protobuf::CsvSink { + type Error = DataFusionError; + + fn try_from(value: &CsvSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +#[cfg(feature = "parquet")] +impl TryFrom<&ParquetSink> for protobuf::ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &ParquetSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { type Error = DataFusionError; @@ -870,13 +896,21 @@ impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions { fn try_from(opts: &FileTypeWriterOptions) -> Result { let file_type = match opts { #[cfg(feature = "parquet")] - FileTypeWriterOptions::Parquet(ParquetWriterOptions { - writer_options: _, - }) => return not_impl_err!("Parquet file sink protobuf serialization"), + FileTypeWriterOptions::Parquet(ParquetWriterOptions { writer_options }) => { + protobuf::file_type_writer_options::FileType::ParquetOptions( + protobuf::ParquetWriterOptions { + writer_properties: Some(writer_properties_to_proto( + writer_options, + )), + }, + ) + } FileTypeWriterOptions::CSV(CsvWriterOptions { - writer_options: _, - compression: _, - }) => return not_impl_err!("CSV file sink protobuf serialization"), + writer_options, + compression, + }) => protobuf::file_type_writer_options::FileType::CsvOptions( + csv_writer_options_to_proto(writer_options, compression), + ), FileTypeWriterOptions::JSON(JsonWriterOptions { compression }) => { let compression: protobuf::CompressionTypeVariant = compression.into(); protobuf::file_type_writer_options::FileType::JsonOptions( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 2eb04ab6cbab..27ac5d122f83 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; use std::ops::Deref; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ @@ -31,6 +34,7 @@ use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -62,7 +66,9 @@ use datafusion::physical_plan::{ }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{FileTypeWriterOptions, Result}; @@ -73,7 +79,23 @@ use datafusion_expr::{ use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. fn roundtrip_test(exec_plan: Arc) -> Result<()> { + let _ = roundtrip_test_and_return(exec_plan); + Ok(()) +} + +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test method returns the final plan after serde so that it can be inspected +/// farther in tests. +fn roundtrip_test_and_return( + exec_plan: Arc, +) -> Result> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let proto: protobuf::PhysicalPlanNode = @@ -84,9 +106,15 @@ fn roundtrip_test(exec_plan: Arc) -> Result<()> { .try_into_physical_plan(&ctx, runtime.deref(), &codec) .expect("from proto"); assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); - Ok(()) + Ok(result_exec_plan) } +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test function accepts a SessionContext, which is required when +/// performing serde on some plans. fn roundtrip_test_with_context( exec_plan: Arc, ctx: SessionContext, @@ -755,6 +783,101 @@ fn roundtrip_json_sink() -> Result<()> { ))) } +#[test] +fn roundtrip_csv_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::CSV(CsvWriterOptions::new( + WriterBuilder::default(), + CompressionTypeVariant::ZSTD, + )), + }; + let data_sink = Arc::new(CsvSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + let roundtrip_plan = roundtrip_test_and_return(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) + .unwrap(); + + let roundtrip_plan = roundtrip_plan + .as_any() + .downcast_ref::() + .unwrap(); + let csv_sink = roundtrip_plan + .sink() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + CompressionTypeVariant::ZSTD, + csv_sink + .config() + .file_type_writer_options + .try_into_csv() + .unwrap() + .compression + ); + + Ok(()) +} + +#[test] +fn roundtrip_parquet_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(WriterProperties::default()), + ), + }; + let data_sink = Arc::new(ParquetSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + roundtrip_test(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) +} + #[test] fn roundtrip_sym_hash_join() -> Result<()> { let field_a = Field::new("col", DataType::Int64, false); From 7f440e18f22ac9b6a6b72ca305fd04704de325fd Mon Sep 17 00:00:00 2001 From: Yang Jiang Date: Sat, 30 Dec 2023 08:33:32 +0800 Subject: [PATCH 49/63] [pruning] Add shortcut when all units have been pruned (#8675) --- datafusion/core/src/physical_optimizer/pruning.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 79e084d7b7f1..fecbffdbb041 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -258,6 +258,11 @@ impl PruningPredicate { builder.combine_array(&arrow::compute::not(&results)?) } } + // if all containers are pruned (has rows that DEFINITELY DO NOT pass the predicate) + // can return early without evaluating the rest of predicates. + if builder.check_all_pruned() { + return Ok(builder.build()); + } } } @@ -380,6 +385,11 @@ impl BoolVecBuilder { fn build(self) -> Vec { self.inner } + + /// Check all containers has rows that DEFINITELY DO NOT pass the predicate + fn check_all_pruned(&self) -> bool { + self.inner.iter().all(|&x| !x) + } } fn is_always_true(expr: &Arc) -> bool { From bb98dfed08d8c2b94ab668a064b206d8b84b51b0 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Sat, 30 Dec 2023 03:48:36 +0300 Subject: [PATCH 50/63] Change first/last implementation to prevent redundant comparisons when data is already sorted (#8678) * Change fist last implementation to prevent redundant computations * Remove redundant checks * Review --------- Co-authored-by: Mehmet Ozan Kabak --- .../physical-expr/src/aggregate/first_last.rs | 259 +++++++++++------- .../physical-plan/src/aggregates/mod.rs | 77 +++++- .../sqllogictest/test_files/groupby.slt | 14 +- 3 files changed, 234 insertions(+), 116 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index c7032e601cf8..4afa8d0dd5ec 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -36,13 +36,14 @@ use datafusion_common::{ use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FirstValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl FirstValue { @@ -54,12 +55,14 @@ impl FirstValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } @@ -87,6 +90,33 @@ impl FirstValue { pub fn ordering_req(&self) -> &LexOrdering { &self.ordering_req } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_last(self) -> LastValue { + let name = if self.name.starts_with("FIRST") { + format!("LAST{}", &self.name[5..]) + } else { + format!("LAST_VALUE({})", self.expr) + }; + let FirstValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + LastValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for FirstValue { @@ -100,11 +130,14 @@ impl AggregateExpr for FirstValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -130,11 +163,7 @@ impl AggregateExpr for FirstValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -142,26 +171,18 @@ impl AggregateExpr for FirstValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("FIRST") { - format!("LAST{}", &self.name[5..]) - } else { - format!("LAST_VALUE({})", self.expr) - }; - Some(Arc::new(LastValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - reverse_order_bys(&self.ordering_req), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_last())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -190,6 +211,8 @@ struct FirstValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl FirstValueAccumulator { @@ -203,42 +226,29 @@ impl FirstValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - ScalarValue::try_from(data_type).map(|value| Self { - first: value, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|first| Self { + first, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { - let [value, orderings @ ..] = row else { - return internal_err!("Empty row in FIRST_VALUE"); - }; - // Update when there is no entry in the state, or we have an "earlier" - // entry according to sort requirements. - if !self.is_set - || compare_rows( - &self.orderings, - orderings, - &get_sort_options(&self.ordering_req), - )? - .is_gt() - { - self.first = value.clone(); - self.orderings = orderings.to_vec(); - self.is_set = true; - } - Ok(()) + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.first = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; } fn get_first_idx(&self, values: &[ArrayRef]) -> Result> { let [value, ordering_values @ ..] = values else { return internal_err!("Empty row in FIRST_VALUE"); }; - if self.ordering_req.is_empty() { - // Get first entry according to receive order (0th index) + if self.requirement_satisfied { + // Get first entry according to the pre-existing ordering (0th index): return Ok((!value.is_empty()).then_some(0)); } let sort_columns = ordering_values @@ -252,6 +262,11 @@ impl FirstValueAccumulator { let indices = lexsort_to_indices(&sort_columns, Some(1))?; Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for FirstValueAccumulator { @@ -263,9 +278,25 @@ impl Accumulator for FirstValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; - self.update_with_new_row(&row)?; + if !self.is_set { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + self.update_with_new_row(&row); + } + } else if !self.requirement_satisfied { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + let orderings = &row[1..]; + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_gt() + { + self.update_with_new_row(&row); + } + } } Ok(()) } @@ -294,12 +325,12 @@ impl Accumulator for FirstValueAccumulator { let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set - || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt() + || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&first_row[0..is_set_idx])?; + self.update_with_new_row(&first_row[0..is_set_idx]); } } Ok(()) @@ -318,13 +349,14 @@ impl Accumulator for FirstValueAccumulator { } /// LAST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LastValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl LastValue { @@ -336,12 +368,14 @@ impl LastValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } @@ -369,6 +403,33 @@ impl LastValue { pub fn ordering_req(&self) -> &LexOrdering { &self.ordering_req } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_first(self) -> FirstValue { + let name = if self.name.starts_with("LAST") { + format!("FIRST{}", &self.name[4..]) + } else { + format!("FIRST_VALUE({})", self.expr) + }; + let LastValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + FirstValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for LastValue { @@ -382,11 +443,14 @@ impl AggregateExpr for LastValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -412,11 +476,7 @@ impl AggregateExpr for LastValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -424,26 +484,18 @@ impl AggregateExpr for LastValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("LAST") { - format!("FIRST{}", &self.name[4..]) - } else { - format!("FIRST_VALUE({})", self.expr) - }; - Some(Arc::new(FirstValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - reverse_order_bys(&self.ordering_req), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_first())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -471,6 +523,8 @@ struct LastValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl LastValueAccumulator { @@ -484,42 +538,28 @@ impl LastValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - Ok(Self { - last: ScalarValue::try_from(data_type)?, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|last| Self { + last, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { - let [value, orderings @ ..] = row else { - return internal_err!("Empty row in LAST_VALUE"); - }; - // Update when there is no entry in the state, or we have a "later" - // entry (either according to sort requirements or the order of execution). - if !self.is_set - || self.orderings.is_empty() - || compare_rows( - &self.orderings, - orderings, - &get_sort_options(&self.ordering_req), - )? - .is_lt() - { - self.last = value.clone(); - self.orderings = orderings.to_vec(); - self.is_set = true; - } - Ok(()) + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.last = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; } fn get_last_idx(&self, values: &[ArrayRef]) -> Result> { let [value, ordering_values @ ..] = values else { return internal_err!("Empty row in LAST_VALUE"); }; - if self.ordering_req.is_empty() { + if self.requirement_satisfied { // Get last entry according to the order of data: return Ok((!value.is_empty()).then_some(value.len() - 1)); } @@ -538,6 +578,11 @@ impl LastValueAccumulator { let indices = lexsort_to_indices(&sort_columns, Some(1))?; Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for LastValueAccumulator { @@ -549,10 +594,26 @@ impl Accumulator for LastValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(last_idx) = self.get_last_idx(values)? { + if !self.is_set || self.requirement_satisfied { + if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + self.update_with_new_row(&row); + } + } else if let Some(last_idx) = self.get_last_idx(values)? { let row = get_row_at_idx(values, last_idx)?; - self.update_with_new_row(&row)?; + let orderings = &row[1..]; + // Update when there is a more recent entry + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() + { + self.update_with_new_row(&row); + } } + Ok(()) } @@ -583,12 +644,12 @@ impl Accumulator for LastValueAccumulator { // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set - || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt() + || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&last_row[0..is_set_idx])?; + self.update_with_new_row(&last_row[0..is_set_idx]); } } Ok(()) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f5bb4fe59b5d..a38044de02e3 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -44,9 +44,9 @@ use datafusion_expr::Accumulator; use datafusion_physical_expr::{ aggregate::is_order_sensitive, equivalence::{collapse_lex_req, ProjectionMapping}, - expressions::{Column, Max, Min, UnKnownColumn}, - physical_exprs_contains, AggregateExpr, EquivalenceProperties, LexOrdering, - LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + expressions::{Column, FirstValue, LastValue, Max, Min, UnKnownColumn}, + physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, + LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; use itertools::Itertools; @@ -324,7 +324,7 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - aggr_expr: Vec>, + mut aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -347,7 +347,8 @@ impl AggregateExec { .collect::>(); let req = get_aggregate_exprs_requirement( - &aggr_expr, + &new_requirement, + &mut aggr_expr, &group_by, &input_eq_properties, &mode, @@ -896,6 +897,11 @@ fn finer_ordering( eq_properties.get_finer_ordering(existing_req, &aggr_req) } +/// Concatenates the given slices. +fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { + [lhs, rhs].concat() +} + /// Get the common requirement that satisfies all the aggregate expressions. /// /// # Parameters @@ -914,14 +920,64 @@ fn finer_ordering( /// A `LexRequirement` instance, which is the requirement that satisfies all the /// aggregate requirements. Returns an error in case of conflicting requirements. fn get_aggregate_exprs_requirement( - aggr_exprs: &[Arc], + prefix_requirement: &[PhysicalSortRequirement], + aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, ) -> Result { let mut requirement = vec![]; - for aggr_expr in aggr_exprs.iter() { - if let Some(finer_ordering) = + for aggr_expr in aggr_exprs.iter_mut() { + let aggr_req = aggr_expr.order_bys().unwrap_or(&[]); + let reverse_aggr_req = reverse_order_bys(aggr_req); + let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req); + let reverse_aggr_req = + PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req); + if let Some(first_value) = aggr_expr.as_any().downcast_ref::() { + let mut first_value = first_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to LAST_VALUE enables more efficient execution + // given the existing ordering: + let mut last_value = first_value.convert_to_last(); + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + first_value = first_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(first_value) as _; + } + } else if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { + let mut last_value = last_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to FIRST_VALUE enables more efficient execution + // given the existing ordering: + let mut first_value = last_value.convert_to_first(); + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + last_value = last_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(last_value) as _; + } + } else if let Some(finer_ordering) = finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) { requirement = finer_ordering; @@ -2071,7 +2127,7 @@ mod tests { options: options1, }, ]; - let aggr_exprs = order_by_exprs + let mut aggr_exprs = order_by_exprs .into_iter() .map(|order_by_expr| { Arc::new(OrderSensitiveArrayAgg::new( @@ -2086,7 +2142,8 @@ mod tests { .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); let res = get_aggregate_exprs_requirement( - &aggr_exprs, + &[], + &mut aggr_exprs, &group_by, &eq_properties, &AggregateMode::Partial, diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index bbf21e135fe4..b09ff79e88d5 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2508,7 +2508,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2539,7 +2539,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2571,7 +2571,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2636,7 +2636,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ------TableScan: sales_global projection=[country, ts, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] ----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort @@ -2988,7 +2988,7 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------------SortExec: expr=[amount@1 DESC] ----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3631,10 +3631,10 @@ Projection: FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_tab ----TableScan: multiple_ordered_table projection=[a, c, d] physical_plan ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] ---AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] +--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] ----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 ---------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] +--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true From cc3042a6343457036770267f921bb3b6e726956c Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 29 Dec 2023 22:47:46 -0800 Subject: [PATCH 51/63] minor: remove unused conversion (#8684) Fixes clippy error in main --- datafusion/proto/src/logical_plan/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 5ee88c3d5328..e8a38784481b 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1650,7 +1650,7 @@ impl AsLogicalPlan for LogicalPlanNode { let csv_options = &csv_opts.writer_options; let csv_writer_options = csv_writer_options_to_proto( csv_options, - (&csv_opts.compression).into(), + &csv_opts.compression, ); let csv_options = file_type_writer_options::FileType::CsvOptions( From 00a679a0533f1f878db43c2a9cdcaa2e92ab859e Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Sat, 30 Dec 2023 16:08:59 +0200 Subject: [PATCH 52/63] refactor: modified `JoinHashMap` build order for `HashJoinStream` (#8658) * maintaining fifo hashmap in hash join * extended HashJoinExec docstring on build phase * testcases for randomly ordered build side input * trigger ci --- .../physical-plan/src/joins/hash_join.rs | 316 ++++++++++++------ .../src/joins/symmetric_hash_join.rs | 2 + datafusion/physical-plan/src/joins/utils.rs | 78 ++++- 3 files changed, 300 insertions(+), 96 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 13ac06ee301c..374a0ad50700 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -29,7 +29,6 @@ use crate::joins::utils::{ need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; use crate::{ - coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, expressions::PhysicalSortExpr, @@ -52,10 +51,10 @@ use super::{ use arrow::array::{ Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array, - UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder, + UInt64Array, }; use arrow::compute::kernels::cmp::{eq, not_distinct}; -use arrow::compute::{and, take, FilterBuilder}; +use arrow::compute::{and, concat_batches, take, FilterBuilder}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; @@ -156,8 +155,48 @@ impl JoinLeftData { /// /// Execution proceeds in 2 stages: /// -/// 1. the **build phase** where a hash table is created from the tuples of the -/// build side. +/// 1. the **build phase** creates a hash table from the tuples of the build side, +/// and single concatenated batch containing data from all fetched record batches. +/// Resulting hash table stores hashed join-key fields for each row as a key, and +/// indices of corresponding rows in concatenated batch. +/// +/// Hash join uses LIFO data structure as a hash table, and in order to retain +/// original build-side input order while obtaining data during probe phase, hash +/// table is updated by iterating batch sequence in reverse order -- it allows to +/// keep rows with smaller indices "on the top" of hash table, and still maintain +/// correct indexing for concatenated build-side data batch. +/// +/// Example of build phase for 3 record batches: +/// +/// +/// ```text +/// +/// Original build-side data Inserting build-side values into hashmap Concatenated build-side batch +/// ┌───────────────────────────┐ +/// hasmap.insert(row-hash, row-idx + offset) │ idx │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 1 │ 1) update_hash for batch 3 with offset 0 │ │ Row 6 │ 0 │ +/// Batch 1 │ │ - hashmap.insert(Row 7, idx 1) │ Batch 3 │ │ │ +/// │ Row 2 │ - hashmap.insert(Row 6, idx 0) │ │ Row 7 │ 1 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 3 │ 2) update_hash for batch 2 with offset 2 │ │ Row 3 │ 2 │ +/// │ │ - hashmap.insert(Row 5, idx 4) │ │ │ │ +/// Batch 2 │ Row 4 │ - hashmap.insert(Row 4, idx 3) │ Batch 2 │ Row 4 │ 3 │ +/// │ │ - hashmap.insert(Row 3, idx 2) │ │ │ │ +/// │ Row 5 │ │ │ Row 5 │ 4 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 6 │ 3) update_hash for batch 1 with offset 5 │ │ Row 1 │ 5 │ +/// Batch 3 │ │ - hashmap.insert(Row 2, idx 5) │ Batch 1 │ │ │ +/// │ Row 7 │ - hashmap.insert(Row 1, idx 6) │ │ Row 2 │ 6 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// └───────────────────────────┘ +/// +/// ``` /// /// 2. the **probe phase** where the tuples of the probe side are streamed /// through, checking for matches of the join keys in the hash table. @@ -715,7 +754,10 @@ async fn collect_left_input( let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); let mut offset = 0; - for batch in batches.iter() { + + // Updating hashmap starting from the last batch + let batches_iter = batches.iter().rev(); + for batch in batches_iter.clone() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); update_hash( @@ -726,19 +768,25 @@ async fn collect_left_input( &random_state, &mut hashes_buffer, 0, + true, )?; offset += batch.num_rows(); } // Merge all batches into a single batch, so we // can directly index into the arrays - let single_batch = concat_batches(&schema, &batches, num_rows)?; + let single_batch = concat_batches(&schema, batches_iter)?; let data = JoinLeftData::new(hashmap, single_batch, reservation); Ok(data) } -/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, -/// assuming that the [RecordBatch] corresponds to the `index`th +/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` +/// using `offset` as a start value for `batch` row indices. +/// +/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, +/// which allows to keep either first (if set to true) or last (if set to false) row index +/// as a chain head for rows with equal hash values. +#[allow(clippy::too_many_arguments)] pub fn update_hash( on: &[Column], batch: &RecordBatch, @@ -747,6 +795,7 @@ pub fn update_hash( random_state: &RandomState, hashes_buffer: &mut Vec, deleted_offset: usize, + fifo_hashmap: bool, ) -> Result<()> where T: JoinHashMapType, @@ -763,28 +812,18 @@ where // For usual JoinHashmap, the implementation is void. hash_map.extend_zero(batch.num_rows()); - // insert hashes to key of the hashmap - let (mut_map, mut_list) = hash_map.get_mut(); - for (row, hash_value) in hash_values.iter().enumerate() { - let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, index)) = item { - // Already exists: add index to next array - let prev_index = *index; - // Store new value inside hashmap - *index = (row + offset + 1) as u64; - // Update chained Vec at row + offset with previous value - mut_list[row + offset - deleted_offset] = prev_index; - } else { - mut_map.insert( - *hash_value, - // store the value + 1 as 0 value reserved for end of list - (*hash_value, (row + offset + 1) as u64), - |(hash, _)| *hash, - ); - // chained list at (row + offset) is already initialized with 0 - // meaning end of list - } + // Updating JoinHashMap from hash values iterator + let hash_values_iter = hash_values + .iter() + .enumerate() + .map(|(i, val)| (i + offset, val)); + + if fifo_hashmap { + hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset); + } else { + hash_map.update_from_iter(hash_values_iter, deleted_offset); } + Ok(()) } @@ -987,6 +1026,7 @@ pub fn build_equal_condition_join_indices( filter: Option<&JoinFilter>, build_side: JoinSide, deleted_offset: Option, + fifo_hashmap: bool, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() @@ -1002,10 +1042,9 @@ pub fn build_equal_condition_join_indices( hashes_buffer.clear(); hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // Using a buffer builder to avoid slower normal builder - let mut build_indices = UInt64BufferBuilder::new(0); - let mut probe_indices = UInt32BufferBuilder::new(0); - // The chained list algorithm generates build indices for each probe row in a reversed sequence as such: + + // In case build-side input has not been inverted while JoinHashMap creation, the chained list algorithm + // will return build indices for each probe row in a reverse order as such: // Build Indices: [5, 4, 3] // Probe Indices: [1, 1, 1] // @@ -1034,44 +1073,17 @@ pub fn build_equal_condition_join_indices( // (5,1) // // With this approach, the lexicographic order on both the probe side and the build side is preserved. - let hash_map = build_hashmap.get_map(); - let next_chain = build_hashmap.get_list(); - for (row, hash_value) in hash_values.iter().enumerate().rev() { - // Get the hash and find it in the build index - - // For every item on the build and probe we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, index)) = - hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - let mut i = *index - 1; - loop { - let build_row_value = if let Some(offset) = deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - build_indices.append(build_row_value); - probe_indices.append(row as u32); - // Follow the chain to get the next index value - let next = next_chain[build_row_value as usize]; - if next == 0 { - // end of list - break; - } - i = next - 1; - } - } - } - // Reversing both sets of indices - build_indices.as_slice_mut().reverse(); - probe_indices.as_slice_mut().reverse(); + let (mut probe_indices, mut build_indices) = if fifo_hashmap { + build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset) + } else { + let (mut matched_probe, mut matched_build) = build_hashmap + .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); + + matched_probe.as_slice_mut().reverse(); + matched_build.as_slice_mut().reverse(); + + (matched_probe, matched_build) + }; let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); @@ -1279,6 +1291,7 @@ impl HashJoinStream { self.filter.as_ref(), JoinSide::Left, None, + true, ); let result = match left_right_indices { @@ -1393,7 +1406,9 @@ mod tests { use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, + }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; @@ -1558,7 +1573,9 @@ mod tests { "| 3 | 5 | 9 | 20 | 5 | 80 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1640,7 +1657,48 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_table( + ("a1", &vec![0, 3, 2, 1]), + ("b1", &vec![4, 5, 5, 4]), + ("c1", &vec![6, 9, 8, 7]), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1686,7 +1744,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1740,7 +1799,58 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_two_parts_left_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let batch1 = build_table_i32( + ("a1", &vec![0, 3]), + ("b1", &vec![4, 5]), + ("c1", &vec![6, 9]), + ); + let batch2 = build_table_i32( + ("a1", &vec![2, 1]), + ("b1", &vec![5, 4]), + ("c1", &vec![8, 7]), + ); + let schema = batch1.schema(); + + let left = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1789,7 +1899,9 @@ mod tests { "| 1 | 4 | 7 | 10 | 4 | 70 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); // second part let stream = join.execute(1, task_ctx.clone())?; @@ -1804,7 +1916,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2228,12 +2341,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2288,12 +2403,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let filter_expression = Arc::new(BinaryExpr::new( @@ -2314,11 +2431,13 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2471,12 +2590,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2529,14 +2650,16 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", + "| 10 | 10 | 100 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8 let column_indices = vec![ColumnIndex { @@ -2565,13 +2688,15 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 8 | 8 | 20 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", - "| 8 | 8 | 20 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2734,6 +2859,7 @@ mod tests { None, JoinSide::Left, None, + false, )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index f071a7f6015a..2d38c2bd16c3 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -771,6 +771,7 @@ pub(crate) fn join_with_probe_batch( filter, build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), + false, )?; if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( @@ -883,6 +884,7 @@ impl OneSideHashJoiner { random_state, &mut self.hashes_buffer, self.deleted_offset, + false, )?; Ok(()) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index ac805b50e6a5..1e3cf5abb477 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -30,7 +30,7 @@ use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, - UInt32Builder, UInt64Array, + UInt32BufferBuilder, UInt32Builder, UInt64Array, UInt64BufferBuilder, }; use arrow::compute; use arrow::datatypes::{Field, Schema, SchemaBuilder}; @@ -148,6 +148,82 @@ pub trait JoinHashMapType { fn get_map(&self) -> &RawTable<(u64, u64)>; /// Returns a reference to the next. fn get_list(&self) -> &Self::NextType; + + /// Updates hashmap from iterator of row indices & row hashes pairs. + fn update_from_iter<'a>( + &mut self, + iter: impl Iterator, + deleted_offset: usize, + ) { + let (mut_map, mut_list) = self.get_mut(); + for (row, hash_value) in iter { + let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); + if let Some((_, index)) = item { + // Already exists: add index to next array + let prev_index = *index; + // Store new value inside hashmap + *index = (row + 1) as u64; + // Update chained Vec at `row` with previous value + mut_list[row - deleted_offset] = prev_index; + } else { + mut_map.insert( + *hash_value, + // store the value + 1 as 0 value reserved for end of list + (*hash_value, (row + 1) as u64), + |(hash, _)| *hash, + ); + // chained list at `row` is already initialized with 0 + // meaning end of list + } + } + } + + /// Returns all pairs of row indices matched by hash. + /// + /// This method only compares hashes, so additional further check for actual values + /// equality may be required. + fn get_matched_indices<'a>( + &self, + iter: impl Iterator, + deleted_offset: Option, + ) -> (UInt32BufferBuilder, UInt64BufferBuilder) { + let mut input_indices = UInt32BufferBuilder::new(0); + let mut match_indices = UInt64BufferBuilder::new(0); + + let hash_map = self.get_map(); + let next_chain = self.get_list(); + for (row_idx, hash_value) in iter { + // Get the hash and find it in the index + if let Some((_, index)) = + hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + let mut i = *index - 1; + loop { + let match_row_idx = if let Some(offset) = deleted_offset { + // This arguments means that we prune the next index way before here. + if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + match_indices.append(match_row_idx); + input_indices.append(row_idx as u32); + // Follow the chain to get the next index value + let next = next_chain[match_row_idx as usize]; + if next == 0 { + // end of list + break; + } + i = next - 1; + } + } + } + + (input_indices, match_indices) + } } /// Implementation of `JoinHashMapType` for `JoinHashMap`. From 545275bff316507226c68cb9d5a0739a0d90f32e Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Sat, 30 Dec 2023 09:12:26 -0500 Subject: [PATCH 53/63] Start setting up tpch planning benchmarks (#8665) * Start setting up tpch planning benchmarks * Add remaining tpch queries * Fix bench function * Clippy --- datafusion/core/benches/sql_planner.rs | 156 +++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 7a41b6bec6f5..1754129a768f 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -60,6 +60,104 @@ pub fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc [(String, Schema); 8] { + let lineitem_schema = Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + Field::new("l_comment", DataType::Utf8, false), + ]); + + let orders_schema = Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + Field::new("o_orderdate", DataType::Date32, false), + Field::new("o_orderpriority", DataType::Utf8, false), + Field::new("o_clerk", DataType::Utf8, false), + Field::new("o_shippriority", DataType::Int32, false), + Field::new("o_comment", DataType::Utf8, false), + ]); + + let part_schema = Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Decimal128(15, 2), false), + Field::new("p_comment", DataType::Utf8, false), + ]); + + let supplier_schema = Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Decimal128(15, 2), false), + Field::new("s_comment", DataType::Utf8, false), + ]); + + let partsupp_schema = Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), + Field::new("ps_comment", DataType::Utf8, false), + ]); + + let customer_schema = Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::Int64, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]); + + let nation_schema = Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, false), + Field::new("n_regionkey", DataType::Int64, false), + Field::new("n_comment", DataType::Utf8, false), + ]); + + let region_schema = Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, false), + Field::new("r_comment", DataType::Utf8, false), + ]); + + [ + ("lineitem".to_string(), lineitem_schema), + ("orders".to_string(), orders_schema), + ("part".to_string(), part_schema), + ("supplier".to_string(), supplier_schema), + ("partsupp".to_string(), partsupp_schema), + ("customer".to_string(), customer_schema), + ("nation".to_string(), nation_schema), + ("region".to_string(), region_schema), + ] +} + fn create_context() -> SessionContext { let ctx = SessionContext::new(); ctx.register_table("t1", create_table_provider("a", 200)) @@ -68,6 +166,16 @@ fn create_context() -> SessionContext { .unwrap(); ctx.register_table("t700", create_table_provider("c", 700)) .unwrap(); + + let tpch_schemas = create_tpch_schemas(); + tpch_schemas.iter().for_each(|(name, schema)| { + ctx.register_table( + name, + Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![]).unwrap()), + ) + .unwrap(); + }); + ctx } @@ -115,6 +223,54 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + let q1_sql = std::fs::read_to_string("../../benchmarks/queries/q1.sql").unwrap(); + let q2_sql = std::fs::read_to_string("../../benchmarks/queries/q2.sql").unwrap(); + let q3_sql = std::fs::read_to_string("../../benchmarks/queries/q3.sql").unwrap(); + let q4_sql = std::fs::read_to_string("../../benchmarks/queries/q4.sql").unwrap(); + let q5_sql = std::fs::read_to_string("../../benchmarks/queries/q5.sql").unwrap(); + let q6_sql = std::fs::read_to_string("../../benchmarks/queries/q6.sql").unwrap(); + let q7_sql = std::fs::read_to_string("../../benchmarks/queries/q7.sql").unwrap(); + let q8_sql = std::fs::read_to_string("../../benchmarks/queries/q8.sql").unwrap(); + let q9_sql = std::fs::read_to_string("../../benchmarks/queries/q9.sql").unwrap(); + let q10_sql = std::fs::read_to_string("../../benchmarks/queries/q10.sql").unwrap(); + let q11_sql = std::fs::read_to_string("../../benchmarks/queries/q11.sql").unwrap(); + let q12_sql = std::fs::read_to_string("../../benchmarks/queries/q12.sql").unwrap(); + let q13_sql = std::fs::read_to_string("../../benchmarks/queries/q13.sql").unwrap(); + let q14_sql = std::fs::read_to_string("../../benchmarks/queries/q14.sql").unwrap(); + // let q15_sql = std::fs::read_to_string("../../benchmarks/queries/q15.sql").unwrap(); + let q16_sql = std::fs::read_to_string("../../benchmarks/queries/q16.sql").unwrap(); + let q17_sql = std::fs::read_to_string("../../benchmarks/queries/q17.sql").unwrap(); + let q18_sql = std::fs::read_to_string("../../benchmarks/queries/q18.sql").unwrap(); + let q19_sql = std::fs::read_to_string("../../benchmarks/queries/q19.sql").unwrap(); + let q20_sql = std::fs::read_to_string("../../benchmarks/queries/q20.sql").unwrap(); + let q21_sql = std::fs::read_to_string("../../benchmarks/queries/q21.sql").unwrap(); + let q22_sql = std::fs::read_to_string("../../benchmarks/queries/q22.sql").unwrap(); + + c.bench_function("physical_plan_tpch", |b| { + b.iter(|| physical_plan(&ctx, &q1_sql)); + b.iter(|| physical_plan(&ctx, &q2_sql)); + b.iter(|| physical_plan(&ctx, &q3_sql)); + b.iter(|| physical_plan(&ctx, &q4_sql)); + b.iter(|| physical_plan(&ctx, &q5_sql)); + b.iter(|| physical_plan(&ctx, &q6_sql)); + b.iter(|| physical_plan(&ctx, &q7_sql)); + b.iter(|| physical_plan(&ctx, &q8_sql)); + b.iter(|| physical_plan(&ctx, &q9_sql)); + b.iter(|| physical_plan(&ctx, &q10_sql)); + b.iter(|| physical_plan(&ctx, &q11_sql)); + b.iter(|| physical_plan(&ctx, &q12_sql)); + b.iter(|| physical_plan(&ctx, &q13_sql)); + b.iter(|| physical_plan(&ctx, &q14_sql)); + // b.iter(|| physical_plan(&ctx, &q15_sql)); + b.iter(|| physical_plan(&ctx, &q16_sql)); + b.iter(|| physical_plan(&ctx, &q17_sql)); + b.iter(|| physical_plan(&ctx, &q18_sql)); + b.iter(|| physical_plan(&ctx, &q19_sql)); + b.iter(|| physical_plan(&ctx, &q20_sql)); + b.iter(|| physical_plan(&ctx, &q21_sql)); + b.iter(|| physical_plan(&ctx, &q22_sql)); + }); } criterion_group!(benches, criterion_benchmark); From 848f6c395afef790880112f809b1443949d4bb0b Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Sun, 31 Dec 2023 07:34:54 -0500 Subject: [PATCH 54/63] update doc (#8686) --- datafusion/core/src/datasource/provider.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 275523405a09..c1cee849fe5c 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -141,7 +141,11 @@ pub trait TableProvider: Sync + Send { /// (though it may return more). Like Projection Pushdown and Filter /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as /// possible, called "Limit Pushdown" as some sources can use this - /// information to improve their performance. + /// information to improve their performance. Note that if there are any + /// Inexact filters pushed down, the LIMIT cannot be pushed down. This is + /// because inexact filters do not guarentee that every filtered row is + /// removed, so applying the limit could lead to too few rows being available + /// to return as a final result. async fn scan( &self, state: &SessionState, From 03bd9b462e9068476e704f0056a3761bd9dce3f0 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:52:04 +0100 Subject: [PATCH 55/63] Closes #8502: Parallel NDJSON file reading (#8659) * added basic test * added `fn repartitioned` * added basic version of FileOpener * refactor: extract calculate_range * refactor: handle GetResultPayload::Stream * refactor: extract common functions to mod.rs * refactor: use common functions * added docs * added test * clippy * fix: test_chunked_json * fix: sqllogictest * delete imports * update docs --- .../core/src/datasource/file_format/json.rs | 106 ++++++++++++++++- .../core/src/datasource/physical_plan/csv.rs | 98 +++------------- .../core/src/datasource/physical_plan/json.rs | 105 +++++++++++++---- .../core/src/datasource/physical_plan/mod.rs | 107 +++++++++++++++++- datafusion/core/tests/data/empty.json | 0 .../test_files/repartition_scan.slt | 8 +- 6 files changed, 305 insertions(+), 119 deletions(-) create mode 100644 datafusion/core/tests/data/empty.json diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 3d437bc5fe68..8c02955ad363 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -294,16 +294,20 @@ impl DataSink for JsonSink { #[cfg(test)] mod tests { use super::super::test_util::scan_format; - use super::*; - use crate::physical_plan::collect; - use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::object_store::local_unpartitioned_file; - + use arrow::util::pretty; use datafusion_common::cast::as_int64_array; use datafusion_common::stats::Precision; - + use datafusion_common::{assert_batches_eq, internal_err}; use futures::StreamExt; use object_store::local::LocalFileSystem; + use regex::Regex; + use rstest::rstest; + + use super::*; + use crate::execution::options::NdJsonReadOptions; + use crate::physical_plan::collect; + use crate::prelude::{SessionConfig, SessionContext}; + use crate::test::object_store::local_unpartitioned_file; #[tokio::test] async fn read_small_batches() -> Result<()> { @@ -424,4 +428,94 @@ mod tests { .collect::>(); assert_eq!(vec!["a: Int64", "b: Float64", "c: Boolean"], fields); } + + async fn count_num_partitions(ctx: &SessionContext, query: &str) -> Result { + let result = ctx + .sql(&format!("EXPLAIN {query}")) + .await? + .collect() + .await?; + + let plan = format!("{}", &pretty::pretty_format_batches(&result)?); + + let re = Regex::new(r"file_groups=\{(\d+) group").unwrap(); + + if let Some(captures) = re.captures(&plan) { + if let Some(match_) = captures.get(1) { + let count = match_.as_str().parse::().unwrap(); + return Ok(count); + } + } + + internal_err!("Query contains no Exec: file_groups") + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/1.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel", table_path, options) + .await?; + + let query = "SELECT SUM(a) FROM json_parallel;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "+----------------------+", + "| SUM(json_parallel.a) |", + "+----------------------+", + "| -7 |", + "+----------------------+" + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(n_partitions, actual_partitions); + + Ok(()) + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_empty_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/empty.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel_empty", table_path, options) + .await?; + + let query = "SELECT * FROM json_parallel_empty WHERE random() > 0.5;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "++", + "++", + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(1, actual_partitions); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 0c34d22e9fa9..b28bc7d56688 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -19,11 +19,10 @@ use std::any::Any; use std::io::{Read, Seek, SeekFrom}; -use std::ops::Range; use std::sync::Arc; use std::task::Poll; -use super::{FileGroupPartitioner, FileScanConfig}; +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::{FileRange, ListingTableUrl}; use crate::datasource::physical_plan::file_stream::{ @@ -318,47 +317,6 @@ impl CsvOpener { } } -/// Returns the offset of the first newline in the object store range [start, end), or the end offset if no newline is found. -async fn find_first_newline( - object_store: &Arc, - location: &object_store::path::Path, - start_byte: usize, - end_byte: usize, -) -> Result { - let options = GetOptions { - range: Some(Range { - start: start_byte, - end: end_byte, - }), - ..Default::default() - }; - - let r = object_store.get_opts(location, options).await?; - let mut input = r.into_stream(); - - let mut buffered = Bytes::new(); - let mut index = 0; - - loop { - if buffered.is_empty() { - match input.next().await { - Some(Ok(b)) => buffered = b, - Some(Err(e)) => return Err(e.into()), - None => return Ok(index), - }; - } - - for byte in &buffered { - if *byte == b'\n' { - return Ok(index); - } - index += 1; - } - - buffered.advance(buffered.len()); - } -} - impl FileOpener for CsvOpener { /// Open a partitioned CSV file. /// @@ -408,44 +366,20 @@ impl FileOpener for CsvOpener { ); } + let store = self.config.object_store.clone(); + Ok(Box::pin(async move { - let file_size = file_meta.object_meta.size; // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let range = match file_meta.range { - None => None, - Some(FileRange { start, end }) => { - let (start, end) = (start as usize, end as usize); - // Partition byte range is [start, end), the boundary might be in the middle of - // some line. Need to find out the exact line boundaries. - let start_delta = if start != 0 { - find_first_newline( - &config.object_store, - file_meta.location(), - start - 1, - file_size, - ) - .await? - } else { - 0 - }; - let end_delta = if end != file_size { - find_first_newline( - &config.object_store, - file_meta.location(), - end - 1, - file_size, - ) - .await? - } else { - 0 - }; - let range = start + start_delta..end + end_delta; - if range.start == range.end { - return Ok( - futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() - ); - } - Some(range) + + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) } }; @@ -453,10 +387,8 @@ impl FileOpener for CsvOpener { range, ..Default::default() }; - let result = config - .object_store - .get_opts(file_meta.location(), options) - .await?; + + let result = store.get_opts(file_meta.location(), options).await?; match result.payload { GetResultPayload::File(mut file, _) => { diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index c74fd13e77aa..529632dab85a 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -18,11 +18,11 @@ //! Execution plan for reading line-delimited JSON files use std::any::Any; -use std::io::BufReader; +use std::io::{BufReader, Read, Seek, SeekFrom}; use std::sync::Arc; use std::task::Poll; -use super::FileScanConfig; +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::file_stream::{ @@ -43,8 +43,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; -use futures::{ready, stream, StreamExt, TryStreamExt}; -use object_store; +use futures::{ready, StreamExt, TryStreamExt}; +use object_store::{self, GetOptions}; use object_store::{GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -134,6 +134,30 @@ impl ExecutionPlan for NdJsonExec { Ok(self) } + fn repartitioned( + &self, + target_partitions: usize, + config: &datafusion_common::config::ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + let preserve_order_within_groups = self.output_ordering().is_some(); + let file_groups = &self.base_config.file_groups; + + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(preserve_order_within_groups) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(file_groups); + + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + let mut new_plan = self.clone(); + new_plan.base_config.file_groups = repartitioned_file_groups; + return Ok(Some(Arc::new(new_plan))); + } + + Ok(None) + } + fn execute( &self, partition: usize, @@ -193,54 +217,89 @@ impl JsonOpener { } impl FileOpener for JsonOpener { + /// Open a partitioned NDJSON file. + /// + /// If `file_meta.range` is `None`, the entire file is opened. + /// Else `file_meta.range` is `Some(FileRange{start, end})`, which corresponds to the byte range [start, end) within the file. + /// + /// Note: `start` or `end` might be in the middle of some lines. In such cases, the following rules + /// are applied to determine which lines to read: + /// 1. The first line of the partition is the line in which the index of the first character >= `start`. + /// 2. The last line of the partition is the line in which the byte at position `end - 1` resides. + /// + /// See [`CsvOpener`](super::CsvOpener) for an example. fn open(&self, file_meta: FileMeta) -> Result { let store = self.object_store.clone(); let schema = self.projected_schema.clone(); let batch_size = self.batch_size; - let file_compression_type = self.file_compression_type.to_owned(); + Ok(Box::pin(async move { - let r = store.get(file_meta.location()).await?; - match r.payload { - GetResultPayload::File(file, _) => { - let bytes = file_compression_type.convert_read(file)?; + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) + } + }; + + let options = GetOptions { + range, + ..Default::default() + }; + + let result = store.get_opts(file_meta.location(), options).await?; + + match result.payload { + GetResultPayload::File(mut file, _) => { + let bytes = match file_meta.range { + None => file_compression_type.convert_read(file)?, + Some(_) => { + file.seek(SeekFrom::Start(result.range.start as _))?; + let limit = result.range.end - result.range.start; + file_compression_type.convert_read(file.take(limit as u64))? + } + }; + let reader = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build(BufReader::new(bytes))?; + Ok(futures::stream::iter(reader).boxed()) } GetResultPayload::Stream(s) => { + let s = s.map_err(DataFusionError::from); + let mut decoder = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build_decoder()?; - - let s = s.map_err(DataFusionError::from); let mut input = file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffered = Bytes::new(); + let mut buffer = Bytes::new(); - let s = stream::poll_fn(move |cx| { + let s = futures::stream::poll_fn(move |cx| { loop { - if buffered.is_empty() { - buffered = match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => b, + if buffer.is_empty() { + match ready!(input.poll_next_unpin(cx)) { + Some(Ok(b)) => buffer = b, Some(Err(e)) => { return Poll::Ready(Some(Err(e.into()))) } - None => break, + None => {} }; } - let read = buffered.len(); - let decoded = match decoder.decode(buffered.as_ref()) { + let decoded = match decoder.decode(buffer.as_ref()) { + Ok(0) => break, Ok(decoded) => decoded, Err(e) => return Poll::Ready(Some(Err(e))), }; - buffered.advance(decoded); - if decoded != read { - break; - } + buffer.advance(decoded); } Poll::Ready(decoder.flush().transpose()) diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 5583991355c6..d7be017a1868 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -27,6 +27,7 @@ mod json; #[cfg(feature = "parquet")] pub mod parquet; pub use file_groups::FileGroupPartitioner; +use futures::StreamExt; pub(crate) use self::csv::plan_to_csv; pub use self::csv::{CsvConfig, CsvExec, CsvOpener}; @@ -45,6 +46,7 @@ pub use json::{JsonOpener, NdJsonExec}; use std::{ fmt::{Debug, Formatter, Result as FmtResult}, + ops::Range, sync::Arc, vec, }; @@ -72,8 +74,8 @@ use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlan; use log::debug; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::{path::Path, GetOptions, ObjectStore}; /// The base configurations to provide when creating a physical plan for /// writing to any given file format. @@ -522,6 +524,109 @@ pub fn is_plan_streaming(plan: &Arc) -> Result { } } +/// Represents the possible outcomes of a range calculation. +/// +/// This enum is used to encapsulate the result of calculating the range of +/// bytes to read from an object (like a file) in an object store. +/// +/// Variants: +/// - `Range(Option>)`: +/// Represents a range of bytes to be read. It contains an `Option` wrapping a +/// `Range`. `None` signifies that the entire object should be read, +/// while `Some(range)` specifies the exact byte range to read. +/// - `TerminateEarly`: +/// Indicates that the range calculation determined no further action is +/// necessary, possibly because the calculated range is empty or invalid. +enum RangeCalculation { + Range(Option>), + TerminateEarly, +} + +/// Calculates an appropriate byte range for reading from an object based on the +/// provided metadata. +/// +/// This asynchronous function examines the `FileMeta` of an object in an object store +/// and determines the range of bytes to be read. The range calculation may adjust +/// the start and end points to align with meaningful data boundaries (like newlines). +/// +/// Returns a `Result` wrapping a `RangeCalculation`, which is either a calculated byte range or an indication to terminate early. +/// +/// Returns an `Error` if any part of the range calculation fails, such as issues in reading from the object store or invalid range boundaries. +async fn calculate_range( + file_meta: &FileMeta, + store: &Arc, +) -> Result { + let location = file_meta.location(); + let file_size = file_meta.object_meta.size; + + match file_meta.range { + None => Ok(RangeCalculation::Range(None)), + Some(FileRange { start, end }) => { + let (start, end) = (start as usize, end as usize); + + let start_delta = if start != 0 { + find_first_newline(store, location, start - 1, file_size).await? + } else { + 0 + }; + + let end_delta = if end != file_size { + find_first_newline(store, location, end - 1, file_size).await? + } else { + 0 + }; + + let range = start + start_delta..end + end_delta; + + if range.start == range.end { + return Ok(RangeCalculation::TerminateEarly); + } + + Ok(RangeCalculation::Range(Some(range))) + } + } +} + +/// Asynchronously finds the position of the first newline character in a specified byte range +/// within an object, such as a file, in an object store. +/// +/// This function scans the contents of the object starting from the specified `start` position +/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// It returns the position of the first newline relative to the start of the range. +/// +/// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. +/// +/// The function returns an `Error` if any issues arise while reading from the object store or processing the data stream. +/// +async fn find_first_newline( + object_store: &Arc, + location: &Path, + start: usize, + end: usize, +) -> Result { + let range = Some(Range { start, end }); + + let options = GetOptions { + range, + ..Default::default() + }; + + let result = object_store.get_opts(location, options).await?; + let mut result_stream = result.into_stream(); + + let mut index = 0; + + while let Some(chunk) = result_stream.next().await.transpose()? { + if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + return Ok(index + position); + } + + index += chunk.len(); + } + + Ok(index) +} + #[cfg(test)] mod tests { use arrow_array::cast::AsArray; diff --git a/datafusion/core/tests/data/empty.json b/datafusion/core/tests/data/empty.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 5dcdbb504e76..3cb42c2206ad 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -198,9 +198,7 @@ select * from json_table; 4 5 -## In the future it would be cool to see the file read as "4" groups with even sizes (offsets) -## but for now it is just one group -## https://github.com/apache/arrow-datafusion/issues/8502 +## Expect to see the scan read the file as "4" groups with even sizes (offsets) query TT EXPLAIN SELECT column1 FROM json_table WHERE column1 <> 42; ---- @@ -210,9 +208,7 @@ Filter: json_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json]]}, projection=[column1] - +----JsonExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:0..18], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:18..36], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:36..54], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:54..70]]}, projection=[column1] # Cleanup statement ok From f0af5eb949e2c5fa9f66eb6f6a9fcdf8f7389c9d Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 31 Dec 2023 21:50:52 +0800 Subject: [PATCH 56/63] init draft (#8625) Signed-off-by: jayzhan211 --- datafusion/expr/src/built_in_function.rs | 5 +- datafusion/expr/src/signature.rs | 7 ++ .../expr/src/type_coercion/functions.rs | 89 +++++++++++-------- datafusion/sqllogictest/test_files/array.slt | 62 ++++++++++--- 4 files changed, 115 insertions(+), 48 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index c454a9781eda..e642dae06e4f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -960,7 +960,10 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature { + type_signature: ElementAndArray, + volatility: self.volatility(), + }, BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 3f07c300e196..729131bd95e1 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -122,6 +122,10 @@ pub enum TypeSignature { /// List dimension of the List/LargeList is equivalent to the number of List. /// List dimension of the non-list is 0. ArrayAndElement, + /// Specialized Signature for ArrayPrepend and similar functions + /// The first argument should be non-list or list, and the second argument should be List/LargeList. + /// The first argument's list dimension should be one dimension less than the second argument's list dimension. + ElementAndArray, } impl TypeSignature { @@ -155,6 +159,9 @@ impl TypeSignature { TypeSignature::ArrayAndElement => { vec!["ArrayAndElement(List, T)".to_string()] } + TypeSignature::ElementAndArray => { + vec!["ElementAndArray(T, List)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index f95a30e025b4..fa47c92762bf 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -79,6 +79,55 @@ fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result>> { + fn array_append_or_prepend_valid_types( + current_types: &[DataType], + is_append: bool, + ) -> Result>> { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let (array_type, elem_type) = if is_append { + (¤t_types[0], ¤t_types[1]) + } else { + (¤t_types[1], ¤t_types[0]) + }; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. + if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + if let DataType::List(ref field) = array_type { + let elem_type = field.data_type(); + if is_append { + Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + } else { + Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + } + } else { + Ok(vec![vec![]]) + } + } + let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -112,42 +161,10 @@ fn get_valid_types( TypeSignature::Exact(valid_types) => vec![valid_types.clone()], TypeSignature::ArrayAndElement => { - if current_types.len() != 2 { - return Ok(vec![vec![]]); - } - - let array_type = ¤t_types[0]; - let elem_type = ¤t_types[1]; - - // We follow Postgres on `array_append(Null, T)`, which is not valid. - if array_type.eq(&DataType::Null) { - return Ok(vec![vec![]]); - } - - // We need to find the coerced base type, mainly for cases like: - // `array_append(List(null), i64)` -> `List(i64)` - let array_base_type = datafusion_common::utils::base_type(array_type); - let elem_base_type = datafusion_common::utils::base_type(elem_type); - let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); - - if new_base_type.is_none() { - return internal_err!( - "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." - ); - } - let new_base_type = new_base_type.unwrap(); - - let array_type = datafusion_common::utils::coerced_type_with_base_type_only( - array_type, - &new_base_type, - ); - - if let DataType::List(ref field) = array_type { - let elem_type = field.data_type(); - return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); - } else { - return Ok(vec![vec![]]); - } + return array_append_or_prepend_valid_types(current_types, true) + } + TypeSignature::ElementAndArray => { + return array_append_or_prepend_valid_types(current_types, false) } TypeSignature::Any(number) => { if current_types.len() != *number { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b8d89edb49b1..6dab3b3084a9 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1618,18 +1618,58 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) -# TODO: array_prepend with NULLs -# array_prepend scalar function #1 -# query ? -# select array_prepend(4, make_array()); -# ---- -# [4] +# array_prepend with NULLs + +# DuckDB: [4] +# ClickHouse: Null +# Since they dont have the same result, we just follow Postgres, return error +query error +select array_prepend(4, NULL); + +query ? +select array_prepend(4, []); +---- +[4] + +query ? +select array_prepend(4, [null]); +---- +[4, ] + +# DuckDB: [null] +# ClickHouse: [null] +query ? +select array_prepend(null, []); +---- +[] + +query ? +select array_prepend(null, [1]); +---- +[, 1] + +query ? +select array_prepend(null, [[1,2,3]]); +---- +[, [1, 2, 3]] + +# DuckDB: [[]] +# ClickHouse: [[]] +# TODO: We may also return [[]] +query error +select array_prepend([], []); + +# DuckDB: [null] +# ClickHouse: [null] +# TODO: We may also return [null] +query error +select array_prepend(null, null); + +query ? +select array_append([], null); +---- +[] -# array_prepend scalar function #2 -# query ?? -# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); -# ---- -# [[]] [[4]] # array_prepend scalar function #3 query ??? From bf3bd9259aa0e93ccc2c79a606207add30d004a4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 1 Jan 2024 00:22:18 -0800 Subject: [PATCH 57/63] Cleanup TreeNode implementations (#8672) * Refactor TreeNode and cleanup some implementations * More * More * Fix clippy * avoid cloning in `TreeNode.children_nodes()` implementations where possible using `Cow` * Remove more unnecessary apply_children * Fix clippy * Remove --------- Co-authored-by: Peter Toth --- datafusion/common/src/tree_node.rs | 33 ++++--- .../enforce_distribution.rs | 32 ++----- .../src/physical_optimizer/enforce_sorting.rs | 33 ++----- .../physical_optimizer/pipeline_checker.rs | 18 +--- .../replace_with_order_preserving_variants.rs | 17 +--- .../src/physical_optimizer/sort_pushdown.rs | 19 +--- datafusion/expr/src/tree_node/expr.rs | 93 ++++++++----------- datafusion/expr/src/tree_node/plan.rs | 20 +--- .../physical-expr/src/sort_properties.rs | 19 +--- datafusion/physical-expr/src/utils/mod.rs | 17 +--- 10 files changed, 97 insertions(+), 204 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5da9636ffe18..5f11c8cc1d11 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -18,6 +18,7 @@ //! This module provides common traits for visiting or rewriting tree //! data structures easily. +use std::borrow::Cow; use std::sync::Arc; use crate::Result; @@ -32,7 +33,10 @@ use crate::Result; /// [`PhysicalExpr`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.PhysicalExpr.html /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html -pub trait TreeNode: Sized { +pub trait TreeNode: Sized + Clone { + /// Returns all children of the TreeNode + fn children_nodes(&self) -> Vec>; + /// Use preorder to iterate the node on the tree so that we can /// stop fast for some cases. /// @@ -211,7 +215,17 @@ pub trait TreeNode: Sized { /// Apply the closure `F` to the node's children fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result, + { + for child in self.children_nodes() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result @@ -342,19 +356,8 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.arc_children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.arc_children().into_iter().map(Cow::Owned).collect() } fn map_children(self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index d5a086227323..bf5aa7d02272 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -21,6 +21,7 @@ //! according to the configuration), this rule increases partition counts in //! the physical plan. +use std::borrow::Cow; use std::fmt; use std::fmt::Formatter; use std::sync::Arc; @@ -47,7 +48,7 @@ use crate::physical_plan::{ }; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -1409,18 +1410,8 @@ impl DistributionContext { } impl TreeNode for DistributionContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result @@ -1483,19 +1474,8 @@ impl PlanWithKeyRequirements { } impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 77d04a61c59e..f609ddea66cf 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -34,6 +34,7 @@ //! in the physical plan. The first sort is unnecessary since its result is overwritten //! by another [`SortExec`]. Therefore, this rule removes it from the physical plan. +use std::borrow::Cow; use std::sync::Arc; use crate::config::ConfigOptions; @@ -57,7 +58,7 @@ use crate::physical_plan::{ with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, }; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -145,19 +146,8 @@ impl PlanWithCorrespondingSort { } impl TreeNode for PlanWithCorrespondingSort { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result @@ -237,19 +227,8 @@ impl PlanWithCorrespondingCoalescePartitions { } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 9e9f647d073f..e281d0e7c23e 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -19,6 +19,7 @@ //! infinite sources, if there are any. It will reject non-runnable query plans //! that use pipeline-breaking operators on infinite input(s). +use std::borrow::Cow; use std::sync::Arc; use crate::config::ConfigOptions; @@ -27,7 +28,7 @@ use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; @@ -91,19 +92,8 @@ impl PipelineStatePropagator { } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 91f3d2abc6ff..e49b358608aa 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -19,6 +19,7 @@ //! order-preserving variants when it is helpful; either in terms of //! performance or to accommodate unbounded streams by fixing the pipeline. +use std::borrow::Cow; use std::sync::Arc; use super::utils::is_repartition; @@ -29,7 +30,7 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -104,18 +105,8 @@ impl OrderPreservationContext { } impl TreeNode for OrderPreservationContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b0013863010a..97ca47baf05f 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::sync::Arc; use crate::physical_optimizer::utils::{ @@ -28,7 +29,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -71,20 +72,10 @@ impl SortPushDown { } impl TreeNode for SortPushDown { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 1098842716b9..56388be58b8a 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -23,17 +23,15 @@ use crate::expr::{ ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; +use std::borrow::Cow; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = match self { - Expr::Alias(Alias{expr,..}) + fn children_nodes(&self) -> Vec> { + match self { + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -47,28 +45,26 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], + | Expr::InSubquery(InSubquery { expr, .. }) => vec![Cow::Borrowed(expr)], Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = expr.as_ref().clone(); + let expr = Cow::Borrowed(expr.as_ref()); match field { - GetFieldAccess::ListIndex {key} => { - vec![key.as_ref().clone(), expr] - }, - GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref().clone(), stop.as_ref().clone(), expr] + GetFieldAccess::ListIndex { key } => { + vec![Cow::Borrowed(key.as_ref()), expr] } - GetFieldAccess::NamedStructField {name: _name} => { + GetFieldAccess::ListRange { start, stop } => { + vec![Cow::Borrowed(start), Cow::Borrowed(stop), expr] + } + GetFieldAccess::NamedStructField { name: _name } => { vec![expr] } } } Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.clone() - } + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().map(Cow::Borrowed).collect(), + Expr::ScalarFunction(ScalarFunction { args, .. }) => args.iter().map(Cow::Borrowed).collect(), Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.clone().into_iter().flatten().collect() + lists_of_exprs.iter().flatten().map(Cow::Borrowed).collect() } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression @@ -77,45 +73,49 @@ impl TreeNode for Expr { | Expr::Literal(_) | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => vec![], Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref().clone(), right.as_ref().clone()] + vec![Cow::Borrowed(left), Cow::Borrowed(right)] } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref().clone(), pattern.as_ref().clone()] + vec![Cow::Borrowed(expr), Cow::Borrowed(pattern)] } Expr::Between(Between { expr, low, high, .. }) => vec![ - expr.as_ref().clone(), - low.as_ref().clone(), - high.as_ref().clone(), + Cow::Borrowed(expr), + Cow::Borrowed(low), + Cow::Borrowed(high), ], Expr::Case(case) => { let mut expr_vec = vec![]; if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(expr.as_ref())); }; for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref().clone()); - expr_vec.push(then.as_ref().clone()); + expr_vec.push(Cow::Borrowed(when)); + expr_vec.push(Cow::Borrowed(then)); } if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(else_expr)); } expr_vec } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - => { - let mut expr_vec = args.clone(); + Expr::AggregateFunction(AggregateFunction { + args, + filter, + order_by, + .. + }) => { + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); if let Some(f) = filter { - expr_vec.push(f.as_ref().clone()); + expr_vec.push(Cow::Borrowed(f)); } if let Some(o) = order_by { - expr_vec.extend(o.clone()); + expr_vec.extend(o.iter().map(Cow::Borrowed).collect::>()); } expr_vec @@ -126,28 +126,17 @@ impl TreeNode for Expr { order_by, .. }) => { - let mut expr_vec = args.clone(); - expr_vec.extend(partition_by.clone()); - expr_vec.extend(order_by.clone()); + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); + expr_vec.extend(partition_by.iter().map(Cow::Borrowed).collect::>()); + expr_vec.extend(order_by.iter().map(Cow::Borrowed).collect::>()); expr_vec } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![]; - expr_vec.push(expr.as_ref().clone()); - expr_vec.extend(list.clone()); + let mut expr_vec = vec![Cow::Borrowed(expr.as_ref())]; + expr_vec.extend(list.iter().map(Cow::Borrowed).collect::>()); expr_vec } - }; - - for child in children.iter() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } } - - Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c7621bc17833..217116530d4a 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,8 +20,13 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; use datafusion_common::{tree_node::TreeNode, Result}; +use std::borrow::Cow; impl TreeNode for LogicalPlan { + fn children_nodes(&self) -> Vec> { + self.inputs().into_iter().map(Cow::Borrowed).collect() + } + fn apply(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, @@ -91,21 +96,6 @@ impl TreeNode for LogicalPlan { visitor.post_visit(self) } - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.inputs() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index 91238e5b04b4..0205f85dced4 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::{ops::Neg, sync::Arc}; use arrow_schema::SortOptions; use crate::PhysicalExpr; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient @@ -147,7 +148,7 @@ impl Neg for SortProperties { /// It encapsulates the orderings (`state`) associated with the expression (`expr`), and /// orderings of the children expressions (`children_states`). The [`ExprOrdering`] of a parent /// expression is determined based on the [`ExprOrdering`] states of its children expressions. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ExprOrdering { pub expr: Arc, pub state: SortProperties, @@ -173,18 +174,8 @@ impl ExprOrdering { } impl TreeNode for ExprOrdering { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 87ef36558b96..64a62dc7820d 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -18,7 +18,7 @@ mod guarantee; pub use guarantee::{Guarantee, LiteralGuarantee}; -use std::borrow::Borrow; +use std::borrow::{Borrow, Cow}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -154,19 +154,8 @@ impl ExprTreeNode { } impl TreeNode for ExprTreeNode { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children().iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result From 8ae7ddc7f9008db39ad86fe0983026a2ac210a5b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Jan 2024 07:13:35 -0500 Subject: [PATCH 58/63] Update sqlparser requirement from 0.40.0 to 0.41.0 (#8647) * Update sqlparser requirement from 0.40.0 to 0.41.0 Updates the requirements on [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) to permit the latest version. - [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.40.0...v0.40.0) --- updated-dependencies: - dependency-name: sqlparser dependency-type: direct:production ... Signed-off-by: dependabot[bot] * error on unsupported syntax * Update datafusion-cli dependencies * fix test --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 82 +++++++++---------- datafusion/sql/src/statement.rs | 6 ++ .../test_files/repartition_scan.slt | 6 +- 4 files changed, 51 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4ee29ea6298c..a87923b6a1a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ parquet = { version = "49.0.0", default-features = false, features = ["arrow", " rand = "0.8" rstest = "0.18.0" serde_json = "1" -sqlparser = { version = "0.40.0", features = ["visitor"] } +sqlparser = { version = "0.41.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" url = "2.2" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8e9bbd8a0dfd..e85e8b1a9edb 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -385,7 +385,7 @@ checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -1075,7 +1075,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -1525,9 +1525,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", @@ -1540,9 +1540,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -1550,15 +1550,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -1567,32 +1567,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] name = "futures-sink" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-timer" @@ -1602,9 +1602,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -2286,9 +2286,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.1" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -2499,7 +2499,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3023,7 +3023,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3133,9 +3133,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.40.0" +version = "0.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c80afe31cdb649e56c0d9bb5503be9166600d68a852c38dd445636d126858e5" +checksum = "5cc2c25a6c66789625ef164b4c7d2e548d627902280c13710d33da8222169964" dependencies = [ "log", "sqlparser_derive", @@ -3189,7 +3189,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3211,9 +3211,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.42" +version = "2.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b7d0a2c048d661a1a59fcd7355baa232f7ed34e0ee4df2eef3c1c1c0d3852d8" +checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" dependencies = [ "proc-macro2", "quote", @@ -3277,22 +3277,22 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f11c217e1416d6f036b870f14e0413d480dbf28edbee1f877abaf0206af43bb7" +checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" +checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3384,7 +3384,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3481,7 +3481,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3526,7 +3526,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3680,7 +3680,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", "wasm-bindgen-shared", ] @@ -3714,7 +3714,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3978,7 +3978,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 12083554f093..a365d23f435c 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -513,7 +513,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::StartTransaction { modes, begin: false, + modifier, } => { + if let Some(modifier) = modifier { + return not_impl_err!( + "Transaction modifier not supported: {modifier}" + ); + } let isolation_level: ast::TransactionIsolationLevel = modes .iter() .filter_map(|m: &ast::TransactionMode| match m { diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 3cb42c2206ad..02eccd7c5d06 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -185,12 +185,12 @@ COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/j (FORMAT json, SINGLE_FILE_OUTPUT true); statement ok -CREATE EXTERNAL TABLE json_table(column1 int) +CREATE EXTERNAL TABLE json_table (column1 int) STORED AS json LOCATION 'test_files/scratch/repartition_scan/json_table/'; query I -select * from json_table; +select * from "json_table"; ---- 1 2 @@ -200,7 +200,7 @@ select * from json_table; ## Expect to see the scan read the file as "4" groups with even sizes (offsets) query TT -EXPLAIN SELECT column1 FROM json_table WHERE column1 <> 42; +EXPLAIN SELECT column1 FROM "json_table" WHERE column1 <> 42; ---- logical_plan Filter: json_table.column1 != Int32(42) From 4dcfd7dd81153cfc70e5772f70519b7257e31932 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:25:37 +1100 Subject: [PATCH 59/63] Update scalar functions doc for extract/datepart (#8682) --- docs/source/user-guide/sql/scalar_functions.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ad4c6ed083bf..629a5f6ecb88 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1410,6 +1410,7 @@ date_part(part, expression) The following date parts are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1421,6 +1422,7 @@ date_part(part, expression) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **expression**: Time expression to operate on. Can be a constant, column, or function. @@ -1448,6 +1450,7 @@ extract(field FROM source) The following date fields are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1459,6 +1462,7 @@ extract(field FROM source) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **source**: Source time expression to operate on. Can be a constant, column, or function. From 77c2180cf6cb83a3e0aa6356b7017a2ed663d4f1 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 2 Jan 2024 04:30:20 +1100 Subject: [PATCH 60/63] Remove DescribeTableStmt in parser in favour of existing functionality from sqlparser-rs (#8703) --- datafusion/core/src/execution/context/mod.rs | 3 --- datafusion/sql/src/parser.rs | 22 -------------------- datafusion/sql/src/statement.rs | 15 +++++++------ 3 files changed, 7 insertions(+), 33 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 8916fa814a4a..c51f2d132aad 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1621,9 +1621,6 @@ impl SessionState { .0 .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - DFStatement::DescribeTableStmt(table) => { - visitor.insert(&table.table_name) - } DFStatement::CopyTo(CopyToStatement { source, target: _, diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 9c104ff18a9b..dbd72ec5eb7a 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -213,13 +213,6 @@ impl fmt::Display for CreateExternalTable { } } -/// DataFusion extension DDL for `DESCRIBE TABLE` -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DescribeTableStmt { - /// Table name - pub table_name: ObjectName, -} - /// DataFusion SQL Statement. /// /// This can either be a [`Statement`] from [`sqlparser`] from a @@ -233,8 +226,6 @@ pub enum Statement { Statement(Box), /// Extension: `CREATE EXTERNAL TABLE` CreateExternalTable(CreateExternalTable), - /// Extension: `DESCRIBE TABLE` - DescribeTableStmt(DescribeTableStmt), /// Extension: `COPY TO` CopyTo(CopyToStatement), /// EXPLAIN for extensions @@ -246,7 +237,6 @@ impl fmt::Display for Statement { match self { Statement::Statement(stmt) => write!(f, "{stmt}"), Statement::CreateExternalTable(stmt) => write!(f, "{stmt}"), - Statement::DescribeTableStmt(_) => write!(f, "DESCRIBE TABLE ..."), Statement::CopyTo(stmt) => write!(f, "{stmt}"), Statement::Explain(stmt) => write!(f, "{stmt}"), } @@ -345,10 +335,6 @@ impl<'a> DFParser<'a> { self.parser.next_token(); // COPY self.parse_copy() } - Keyword::DESCRIBE => { - self.parser.next_token(); // DESCRIBE - self.parse_describe() - } Keyword::EXPLAIN => { // (TODO parse all supported statements) self.parser.next_token(); // EXPLAIN @@ -371,14 +357,6 @@ impl<'a> DFParser<'a> { } } - /// Parse a SQL `DESCRIBE` statement - pub fn parse_describe(&mut self) -> Result { - let table_name = self.parser.parse_object_name()?; - Ok(Statement::DescribeTableStmt(DescribeTableStmt { - table_name, - })) - } - /// Parse a SQL `COPY TO` statement pub fn parse_copy(&mut self) -> Result { // parse as a query diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index a365d23f435c..b96553ffbf86 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -19,8 +19,8 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use crate::parser::{ - CopyToSource, CopyToStatement, CreateExternalTable, DFParser, DescribeTableStmt, - ExplainStatement, LexOrdering, Statement as DFStatement, + CopyToSource, CopyToStatement, CreateExternalTable, DFParser, ExplainStatement, + LexOrdering, Statement as DFStatement, }; use crate::planner::{ object_name_to_qualifier, ContextProvider, PlannerContext, SqlToRel, @@ -136,7 +136,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match statement { DFStatement::CreateExternalTable(s) => self.external_table_to_plan(s), DFStatement::Statement(s) => self.sql_statement_to_plan(*s), - DFStatement::DescribeTableStmt(s) => self.describe_table_to_plan(s), DFStatement::CopyTo(s) => self.copy_to_plan(s), DFStatement::Explain(ExplainStatement { verbose, @@ -170,6 +169,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let sql = Some(statement.to_string()); match statement { + Statement::ExplainTable { + describe_alias: true, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + table_name, + } => self.describe_table_to_plan(table_name), Statement::Explain { verbose, statement, @@ -635,11 +638,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn describe_table_to_plan( - &self, - statement: DescribeTableStmt, - ) -> Result { - let DescribeTableStmt { table_name } = statement; + fn describe_table_to_plan(&self, table_name: ObjectName) -> Result { let table_ref = self.object_name_to_table_reference(table_name)?; let table_source = self.context_provider.get_table_source(table_ref)?; From e82707ec5a912dc5f23e9fe89bea5f49ec64688f Mon Sep 17 00:00:00 2001 From: Ashim Sedhain <38435962+asimsedhain@users.noreply.github.com> Date: Mon, 1 Jan 2024 11:44:27 -0600 Subject: [PATCH 61/63] feat: simplify null in list (#8691) GH-8688 --- .../simplify_expressions/expr_simplifier.rs | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5a300e2ff246..7d09aec7e748 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -481,6 +481,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { lit(negated) } + // null in (x, y, z) --> null + // null not in (x, y, z) --> null + Expr::InList(InList { + expr, + list: _, + negated: _, + }) if is_null(&expr) => lit_bool_null(), + // expr IN ((subquery)) -> expr IN (subquery), see ##5529 Expr::InList(InList { expr, @@ -3096,6 +3104,18 @@ mod tests { assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false)); assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true)); + // null in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)), + lit_bool_null() + ); + + // null not in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)), + lit_bool_null() + ); + assert_eq!( simplify(in_list(col("c1"), vec![lit(1)], false)), col("c1").eq(lit(1)) From d2b3d1c7538b9fb7ab9cfc0c4c6a238b0dcd91e6 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 1 Jan 2024 14:09:41 -0500 Subject: [PATCH 62/63] Rename `expr::window_function::WindowFunction` to `WindowFunctionDefinition`, make structure consistent with ScalarFunction (#8382) * Refactoring WindowFunction into coherent structure with AggregateFunction * One more cargo fmt --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/dataframe/mod.rs | 6 +- .../core/src/physical_optimizer/test_utils.rs | 4 +- datafusion/core/tests/dataframe/mod.rs | 4 +- .../core/tests/fuzz_cases/window_fuzz.rs | 46 +- .../expr/src/built_in_window_function.rs | 207 ++++++++ datafusion/expr/src/expr.rs | 291 ++++++++++- datafusion/expr/src/lib.rs | 6 +- datafusion/expr/src/udwf.rs | 2 +- datafusion/expr/src/utils.rs | 22 +- datafusion/expr/src/window_function.rs | 483 ------------------ .../src/analyzer/count_wildcard_rule.rs | 10 +- .../optimizer/src/analyzer/type_coercion.rs | 8 +- .../optimizer/src/push_down_projection.rs | 6 +- datafusion/physical-plan/src/windows/mod.rs | 28 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 10 +- .../proto/src/physical_plan/from_proto.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 20 +- datafusion/sql/src/expr/function.rs | 19 +- .../substrait/src/logical_plan/consumer.rs | 4 +- 20 files changed, 613 insertions(+), 581 deletions(-) create mode 100644 datafusion/expr/src/built_in_window_function.rs delete mode 100644 datafusion/expr/src/window_function.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3c3bcd497b7f..5a8c706e32cd 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1360,7 +1360,7 @@ mod tests { use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::get_plan_string; @@ -1525,7 +1525,9 @@ mod tests { // build plan using Table API let t = test_table().await?; let first_row = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![col("aggregate_test_100.c1")], vec![col("aggregate_test_100.c2")], vec![], diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 6e14cca21fed..debafefe39ab 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -41,7 +41,7 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; +use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -234,7 +234,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index ba661aa2445c..cca23ac6847c 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -45,7 +45,7 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::{ array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -170,7 +170,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 44ff71d02392..3037b4857a3b 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -33,7 +33,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -143,7 +143,7 @@ fn get_random_function( schema: &SchemaRef, rng: &mut StdRng, is_linear: bool, -) -> (WindowFunction, Vec>, String) { +) -> (WindowFunctionDefinition, Vec>, String) { let mut args = if is_linear { // In linear test for the test version with WindowAggExec we use insert SortExecs to the plan to be able to generate // same result with BoundedWindowAggExec which doesn't use any SortExec. To make result @@ -159,28 +159,28 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![], ), ); window_fn_map.insert( "count", ( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![], ), ); window_fn_map.insert( "min", ( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![], ), ); window_fn_map.insert( "max", ( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![], ), ); @@ -191,28 +191,36 @@ fn get_random_function( window_fn_map.insert( "row_number", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), vec![], ), ); window_fn_map.insert( "rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Rank, + ), vec![], ), ); window_fn_map.insert( "dense_rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::DenseRank, + ), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lead, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -222,7 +230,9 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lag, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -233,21 +243,27 @@ fn get_random_function( window_fn_map.insert( "first_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![], ), ); window_fn_map.insert( "last_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::LastValue, + ), vec![], ), ); window_fn_map.insert( "nth_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::NthValue, + ), vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))], ), ); @@ -255,7 +271,7 @@ fn get_random_function( let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; - if let WindowFunction::AggregateFunction(f) = window_fn { + if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { let a = args[0].clone(); let dt = a.data_type(schema.as_ref()).unwrap(); let sig = f.signature(); diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs new file mode 100644 index 000000000000..a03e3d2d24a9 --- /dev/null +++ b/datafusion/expr/src/built_in_window_function.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built-in functions module contains all the built-in functions definitions. + +use std::fmt; +use std::str::FromStr; + +use crate::type_coercion::functions::data_types; +use crate::utils; +use crate::{Signature, TypeSignature, Volatility}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; + +use arrow::datatypes::DataType; + +use strum_macros::EnumIter; + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// A [window function] built in to DataFusion +/// +/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +pub enum BuiltInWindowFunction { + /// number of the current row within its partition, counting from 1 + RowNumber, + /// rank of the current row with gaps; same as row_number of its first peer + Rank, + /// rank of the current row without gaps; this function counts peer groups + DenseRank, + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lag, + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lead, + /// returns value evaluated at the row that is the first row of the window frame + FirstValue, + /// returns value evaluated at the row that is the last row of the window frame + LastValue, + /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + NthValue, +} + +impl BuiltInWindowFunction { + fn name(&self) -> &str { + use BuiltInWindowFunction::*; + match self { + RowNumber => "ROW_NUMBER", + Rank => "RANK", + DenseRank => "DENSE_RANK", + PercentRank => "PERCENT_RANK", + CumeDist => "CUME_DIST", + Ntile => "NTILE", + Lag => "LAG", + Lead => "LEAD", + FirstValue => "FIRST_VALUE", + LastValue => "LAST_VALUE", + NthValue => "NTH_VALUE", + } + } +} + +impl FromStr for BuiltInWindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name.to_uppercase().as_str() { + "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, + "RANK" => BuiltInWindowFunction::Rank, + "DENSE_RANK" => BuiltInWindowFunction::DenseRank, + "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, + "CUME_DIST" => BuiltInWindowFunction::CumeDist, + "NTILE" => BuiltInWindowFunction::Ntile, + "LAG" => BuiltInWindowFunction::Lag, + "LEAD" => BuiltInWindowFunction::Lead, + "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, + "LAST_VALUE" => BuiltInWindowFunction::LastValue, + "NTH_VALUE" => BuiltInWindowFunction::NthValue, + _ => return plan_err!("There is no built-in window function named {name}"), + }) + } +} + +/// Returns the datatype of the built-in window function +impl BuiltInWindowFunction { + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()) + // original errors are all related to wrong function signature + // aggregate them for better error message + .map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) + })?; + + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), + } + } + + /// the signatures supported by the built-in window function `fun`. + pub fn signature(&self) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), + BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ) + } + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { + Signature::any(1, Volatility::Immutable) + } + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + #[test] + // Test for BuiltInWindowFunction's Display and from_str() implementations. + // For each variant in BuiltInWindowFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in BuiltInWindowFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, func_original); + } + } +} diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0ec19bcadbf6..ebf4d3143c12 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -19,13 +19,13 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; -use crate::udaf; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; -use crate::window_function; + use crate::Operator; use crate::{aggregate_function, ExprSchemable}; use crate::{built_in_function, BuiltinScalarFunction}; +use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; @@ -34,8 +34,11 @@ use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; +use std::str::FromStr; use std::sync::Arc; +use crate::Signature; + /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. @@ -566,11 +569,64 @@ impl AggregateFunction { } } +/// WindowFunction +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum WindowFunctionDefinition { + /// A built in aggregate function that leverages an aggregate function + AggregateFunction(aggregate_function::AggregateFunction), + /// A a built-in window function + BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + /// A user defined aggregate function + AggregateUDF(Arc), + /// A user defined aggregate function + WindowUDF(Arc), +} + +impl WindowFunctionDefinition { + /// Returns the datatype of the window function + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::AggregateUDF(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types), + } + } + + /// the signatures supported by the function `fun`. + pub fn signature(&self) -> Signature { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), + WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), + WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), + } + } +} + +impl fmt::Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f), + } + } +} + /// Window function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function - pub fun: window_function::WindowFunction, + pub fun: WindowFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// List of partition by expressions @@ -584,7 +640,7 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression pub fn new( - fun: window_function::WindowFunction, + fun: WindowFunctionDefinition, args: Vec, partition_by: Vec, order_by: Vec, @@ -600,6 +656,50 @@ impl WindowFunction { } } +/// Find DataFusion's built-in window function by name. +pub fn find_df_window_func(name: &str) -> Option { + let name = name.to_lowercase(); + // Code paths for window functions leveraging ordinary aggregators and + // built-in window functions are quite different, and the same function + // may have different implementations for these cases. If the sought + // function is not found among built-in window functions, we search for + // it among aggregate functions. + if let Ok(built_in_function) = + built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_function, + )) + } else if let Ok(aggregate) = + aggregate_function::AggregateFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::AggregateFunction(aggregate)) + } else { + None + } +} + +/// Returns the datatype of the window function +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::return_type` instead" +)] +pub fn return_type( + fun: &WindowFunctionDefinition, + input_expr_types: &[DataType], +) -> Result { + fun.return_type(input_expr_types) +} + +/// the signatures supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::signature` instead" +)] +pub fn signature(fun: &WindowFunctionDefinition) -> Signature { + fun.signature() +} + // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -1890,4 +1990,187 @@ mod test { .is_volatile() .expect_err("Shouldn't determine volatility of unresolved function"); } + + use super::*; + + #[test] + fn test_count_return_type() -> Result<()> { + let fun = find_df_window_func("count").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Int64, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::Int64, observed); + + Ok(()) + } + + #[test] + fn test_first_value_return_type() -> Result<()> { + let fun = find_df_window_func("first_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_last_value_return_type() -> Result<()> { + let fun = find_df_window_func("last_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lead_return_type() -> Result<()> { + let fun = find_df_window_func("lead").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lag_return_type() -> Result<()> { + let fun = find_df_window_func("lag").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_nth_value_return_type() -> Result<()> { + let fun = find_df_window_func("nth_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_percent_rank_return_type() -> Result<()> { + let fun = find_df_window_func("percent_rank").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_cume_dist_return_type() -> Result<()> { + let fun = find_df_window_func("cume_dist").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_window_function_case_insensitive() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = find_df_window_func(name).unwrap(); + let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); + assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); + } + Ok(()) + } + + #[test] + fn test_find_df_window_function() { + assert_eq!( + find_df_window_func("max"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Max + )) + ); + assert_eq!( + find_df_window_func("min"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Min + )) + ); + assert_eq!( + find_df_window_func("avg"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Avg + )) + ); + assert_eq!( + find_df_window_func("cume_dist"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::CumeDist + )) + ); + assert_eq!( + find_df_window_func("first_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::FirstValue + )) + ); + assert_eq!( + find_df_window_func("LAST_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::LastValue + )) + ); + assert_eq!( + find_df_window_func("LAG"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lag + )) + ); + assert_eq!( + find_df_window_func("LEAD"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lead + )) + ); + assert_eq!(find_df_window_func("not_exist"), None) + } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index bf8e9e2954f4..ab213a19a352 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -27,6 +27,7 @@ mod accumulator; mod built_in_function; +mod built_in_window_function; mod columnar_value; mod literal; mod nullif; @@ -53,16 +54,16 @@ pub mod tree_node; pub mod type_coercion; pub mod utils; pub mod window_frame; -pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; +pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, ScalarFunctionDefinition, TryCast, + Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; @@ -83,7 +84,6 @@ pub use udaf::AggregateUDF; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -pub use window_function::{BuiltInWindowFunction, WindowFunction}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c233ee84b32d..a97a68341f5c 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -107,7 +107,7 @@ impl WindowUDF { order_by: Vec, window_frame: WindowFrame, ) -> Expr { - let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); + let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); Expr::WindowFunction(crate::expr::WindowFunction { fun, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 09f4842c9e64..e3ecdf154e61 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1234,7 +1234,7 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1248,28 +1248,28 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![], @@ -1291,28 +1291,28 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], @@ -1343,7 +1343,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![ @@ -1353,7 +1353,7 @@ mod tests { WindowFrame::new(true), )), Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![ diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index 610f1ecaeae9..000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,483 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window functions provide the ability to perform calculations across -//! sets of rows that are related to the current query row. -//! -//! see also - -use crate::aggregate_function::AggregateFunction; -use crate::type_coercion::functions::data_types; -use crate::utils; -use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; -use arrow::datatypes::DataType; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; -use strum_macros::EnumIter; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// A built in aggregate function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// A a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), - /// A user defined aggregate function - AggregateUDF(Arc), - /// A user defined aggregate function - WindowUDF(Arc), -} - -/// Find DataFusion's built-in window function by name. -pub fn find_df_window_func(name: &str) -> Option { - let name = name.to_lowercase(); - // Code paths for window functions leveraging ordinary aggregators and - // built-in window functions are quite different, and the same function - // may have different implementations for these cases. If the sought - // function is not found among built-in window functions, we search for - // it among aggregate functions. - if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { - Some(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Some(WindowFunction::AggregateFunction(aggregate)) - } else { - None - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunction::WindowUDF(fun) => fun.fmt(f), - } - } -} - -/// A [window function] built in to DataFusion -/// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl BuiltInWindowFunction { - fn name(&self) -> &str { - use BuiltInWindowFunction::*; - match self { - RowNumber => "ROW_NUMBER", - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", - NthValue => "NTH_VALUE", - } - } -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => return plan_err!("There is no built-in window function named {name}"), - }) - } -} - -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -impl WindowFunction { - /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), - WindowFunction::BuiltInWindowFunction(fun) => { - fun.return_type(input_expr_types) - } - WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types), - WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types), - } - } -} - -/// Returns the datatype of the built-in window function -impl BuiltInWindowFunction { - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), - } - } -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunction) -> Signature { - fun.signature() -} - -impl WindowFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - match self { - WindowFunction::AggregateFunction(fun) => fun.signature(), - WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), - WindowFunction::AggregateUDF(fun) => fun.signature().clone(), - WindowFunction::WindowUDF(fun) => fun.signature().clone(), - } - } -} - -/// the signatures supported by the built-in window function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltInWindowFunction::signature` instead" -)] -pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { - fun.signature() -} - -impl BuiltInWindowFunction { - /// the signatures supported by the built-in window function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) - } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use strum::IntoEnumIterator; - - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - - #[test] - fn test_first_value_return_type() -> Result<()> { - let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_last_value_return_type() -> Result<()> { - let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_nth_value_return_type() -> Result<()> { - let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = find_df_window_func(name).unwrap(); - let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Max)) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Min)) - ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Avg)) - ); - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::CumeDist - )) - ); - assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue - )) - ); - assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lead - )) - ); - assert_eq!(find_df_window_func("not_exist"), None) - } - - #[test] - // Test for BuiltInWindowFunction's Display and from_str() implementations. - // For each variant in BuiltInWindowFunction, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in BuiltInWindowFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fd84bb80160b..953716713e41 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -24,7 +24,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -121,7 +121,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { let new_expr = match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: - window_function::WindowFunction::AggregateFunction( + expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, @@ -131,7 +131,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( + fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args: vec![lit(COUNT_STAR_EXPANSION)], @@ -229,7 +229,7 @@ mod tests { use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -342,7 +342,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b6298f5b552f..4d54dad99670 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -45,9 +45,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, - type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, - Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -390,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { coerce_window_frame(window_frame, &self.schema, &order_by)?; let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { + expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, &args, diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 10cc1879aeeb..4ee4f7e417a6 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -37,7 +37,7 @@ mod tests { }; use datafusion_expr::{ col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -582,7 +582,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], vec![col("test.b")], vec![], @@ -590,7 +590,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], vec![], vec![], diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 3187e6b0fbd3..fec168fabf48 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -34,8 +34,8 @@ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - window_function::{BuiltInWindowFunction, WindowFunction}, - PartitionEvaluator, WindowFrame, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -56,7 +56,7 @@ pub use datafusion_physical_expr::window::{ /// Create a physical expression for window function pub fn create_window_expr( - fun: &WindowFunction, + fun: &WindowFunctionDefinition, name: String, args: &[Arc], partition_by: &[Arc], @@ -65,7 +65,7 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { let aggregate = aggregates::create_aggregate_expr( fun, false, @@ -81,13 +81,15 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - WindowFunction::AggregateUDF(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + Arc::new(BuiltInWindowExpr::new( + create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )) + } + WindowFunctionDefinition::AggregateUDF(fun) => { let aggregate = udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; window_expr_from_aggregate_expr( @@ -97,7 +99,7 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, order_by, @@ -647,7 +649,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index c582e92dc11c..36c5b44f00b9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1112,7 +1112,7 @@ pub fn parse_expr( let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateFunction( + datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], @@ -1131,7 +1131,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction( + datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, @@ -1146,7 +1146,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateUDF( + datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, ), args, @@ -1161,7 +1161,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::WindowUDF( + datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, ), args, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b9987ff6c727..a162b2389cd1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -51,7 +51,7 @@ use datafusion_expr::expr::{ use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint, JoinType, - TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; #[derive(Debug)] @@ -605,22 +605,22 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref window_frame, }) => { let window_function = match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { protobuf::window_expr_node::WindowFunction::AggrFunction( protobuf::AggregateFunction::from(fun).into(), ) } - WindowFunction::BuiltInWindowFunction(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), ) } - WindowFunction::AggregateUDF(aggr_udf) => { + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { protobuf::window_expr_node::WindowFunction::Udaf( aggr_udf.name().to_string(), ) } - WindowFunction::WindowUDF(window_udf) => { + WindowFunctionDefinition::WindowUDF(window_udf) => { protobuf::window_expr_node::WindowFunction::Udwf( window_udf.name().to_string(), ) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 8ad6d679df4d..23ab813ca739 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -31,7 +31,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::window_function::WindowFunction; +use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, @@ -414,7 +414,9 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction { +impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> + for WindowFunctionDefinition +{ type Error = DataFusionError; fn try_from( @@ -428,7 +430,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::AggregateFunction(f.into())) + Ok(WindowFunctionDefinition::AggregateFunction(f.into())) } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { @@ -437,7 +439,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::BuiltInWindowFunction(f.into())) + Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d7d85abda96..dea99f91e392 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -53,7 +53,7 @@ use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1663,8 +1663,8 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1674,8 +1674,8 @@ fn roundtrip_window() { // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1691,8 +1691,8 @@ fn roundtrip_window() { }; let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1708,7 +1708,7 @@ fn roundtrip_window() { }; let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1759,7 +1759,7 @@ fn roundtrip_window() { ); let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1808,7 +1808,7 @@ fn roundtrip_window() { ); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3934d6701c63..395f10b6f783 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -23,8 +23,8 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFunction, + expr, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, + WindowFunctionDefinition, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, @@ -121,12 +121,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { - WindowFunction::AggregateFunction(aggregate_fun) => { + WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { let args = self.function_args_to_expr(args, schema, planner_context)?; Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), + WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, partition_by, order_by, @@ -191,19 +191,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } - pub(super) fn find_window_func(&self, name: &str) -> Result { - window_function::find_df_window_func(name) + pub(super) fn find_window_func( + &self, + name: &str, + ) -> Result { + expr::find_df_window_func(name) // next check user defined aggregates .or_else(|| { self.context_provider .get_aggregate_meta(name) - .map(WindowFunction::AggregateUDF) + .map(WindowFunctionDefinition::AggregateUDF) }) // next check user defined window functions .or_else(|| { self.context_provider .get_window_meta(name) - .map(WindowFunction::WindowUDF) + .map(WindowFunctionDefinition::WindowUDF) }) .ok_or_else(|| { plan_datafusion_err!("There is no window function named {name}") diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9931dd15aec8..a4ec3e7722a2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -23,8 +23,8 @@ use datafusion::common::{ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, window_function::find_df_window_func, BinaryExpr, - BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, + aggregate_function, expr::find_df_window_func, BinaryExpr, BuiltinScalarFunction, + Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, From bf0a39a791e7cd0e965abb8c87950cc4101149f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 2 Jan 2024 00:28:36 -0800 Subject: [PATCH 63/63] Deprecate duplicate function `LogicalPlan::with_new_inputs` (#8707) * Remove duplicate function with_new_inputs * Make it as deprecated function --- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 47 ++----------------- datafusion/expr/src/tree_node/plan.rs | 2 +- .../optimizer/src/eliminate_outer_join.rs | 3 +- .../optimizer/src/optimize_projections.rs | 3 +- datafusion/optimizer/src/optimizer.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 28 +++++++---- datafusion/optimizer/src/push_down_limit.rs | 23 +++++---- datafusion/optimizer/src/utils.rs | 2 +- 9 files changed, 45 insertions(+), 67 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 549c25f89bae..cfc052cfc14c 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -445,7 +445,7 @@ impl LogicalPlanBuilder { ) }) .collect::>>()?; - curr_plan.with_new_inputs(&new_inputs) + curr_plan.with_new_exprs(curr_plan.expressions(), &new_inputs) } } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9b0f441ef902..c0c520c4e211 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -541,35 +541,9 @@ impl LogicalPlan { } /// Returns a copy of this `LogicalPlan` with the new inputs + #[deprecated(since = "35.0.0", note = "please use `with_new_exprs` instead")] pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - // with_new_inputs use original expression, - // so we don't need to recompute Schema. - match &self { - LogicalPlan::Projection(projection) => { - // Schema of the projection may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. - Projection::try_new(projection.expr.to_vec(), Arc::new(inputs[0].clone())) - .map(LogicalPlan::Projection) - } - LogicalPlan::Window(Window { window_expr, .. }) => Ok(LogicalPlan::Window( - Window::try_new(window_expr.to_vec(), Arc::new(inputs[0].clone()))?, - )), - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - .. - }) => Aggregate::try_new( - // Schema of the aggregate may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. - Arc::new(inputs[0].clone()), - group_expr.to_vec(), - aggr_expr.to_vec(), - ) - .map(LogicalPlan::Aggregate), - _ => self.with_new_exprs(self.expressions(), inputs), - } + self.with_new_exprs(self.expressions(), inputs) } /// Returns a new `LogicalPlan` based on `self` with inputs and @@ -591,10 +565,6 @@ impl LogicalPlan { /// // create new plan using rewritten_exprs in same position /// let new_plan = plan.new_with_exprs(rewritten_exprs, new_inputs); /// ``` - /// - /// Note: sometimes [`Self::with_new_exprs`] will use schema of - /// original plan, it will not change the scheam. Such as - /// `Projection/Aggregate/Window` pub fn with_new_exprs( &self, mut expr: Vec, @@ -706,17 +676,10 @@ impl LogicalPlan { })) } }, - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => { + LogicalPlan::Window(Window { window_expr, .. }) => { assert_eq!(window_expr.len(), expr.len()); - Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: expr, - schema: schema.clone(), - })) + Window::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Window) } LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { // group exprs are the first expressions diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 217116530d4a..208a8b57d7b0 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -113,7 +113,7 @@ impl TreeNode for LogicalPlan { .zip(new_children.iter()) .any(|(c1, c2)| c1 != &c2) { - self.with_new_inputs(new_children.as_slice()) + self.with_new_exprs(self.expressions(), new_children.as_slice()) } else { Ok(self) } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index e4d57f0209a4..53c4b3702b1e 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -106,7 +106,8 @@ impl OptimizerRule for EliminateOuterJoin { schema: join.schema.clone(), null_equals_null: join.null_equals_null, }); - let new_plan = plan.with_new_inputs(&[new_join])?; + let new_plan = + plan.with_new_exprs(plan.expressions(), &[new_join])?; Ok(Some(new_plan)) } _ => Ok(None), diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 7ae9f7edf5e5..891a909a3378 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -373,7 +373,8 @@ fn optimize_projections( // `old_child` during construction: .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) .collect::>(); - plan.with_new_inputs(&new_inputs).map(Some) + plan.with_new_exprs(plan.expressions(), &new_inputs) + .map(Some) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 0dc34cb809eb..2cb59d511ccf 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -382,7 +382,7 @@ impl Optimizer { }) .collect::>(); - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } /// Use a rule to optimize the whole plan. diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 9d277d18d2f7..4eb925ac0629 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -691,9 +691,11 @@ impl OptimizerRule for PushDownFilter { | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) => { // commutable - let new_filter = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - child_plan.with_new_inputs(&[new_filter])? + let new_filter = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); @@ -716,7 +718,7 @@ impl OptimizerRule for PushDownFilter { new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_inputs(&[new_filter])? + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::Projection(projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile @@ -760,10 +762,15 @@ impl OptimizerRule for PushDownFilter { )?); match conjunction(keep_predicates) { - None => child_plan.with_new_inputs(&[new_filter])?, + None => child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?, Some(keep_predicate) => { - let child_plan = - child_plan.with_new_inputs(&[new_filter])?; + let child_plan = child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?; LogicalPlan::Filter(Filter::try_new( keep_predicate, Arc::new(child_plan), @@ -837,7 +844,9 @@ impl OptimizerRule for PushDownFilter { )?), None => (*agg.input).clone(), }; - let new_agg = filter.input.with_new_inputs(&vec![child])?; + let new_agg = filter + .input + .with_new_exprs(filter.input.expressions(), &vec![child])?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, @@ -942,7 +951,8 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. - let new_extension = child_plan.with_new_inputs(&new_children)?; + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), &new_children)?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6703a1d787a7..c2f35a790616 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -126,7 +126,7 @@ impl OptimizerRule for PushDownLimit { fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)), projected_schema: scan.projected_schema.clone(), }); - Some(plan.with_new_inputs(&[new_input])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_input])?) } } LogicalPlan::Union(union) => { @@ -145,7 +145,7 @@ impl OptimizerRule for PushDownLimit { inputs: new_inputs, schema: union.schema.clone(), }); - Some(plan.with_new_inputs(&[union])?) + Some(plan.with_new_exprs(plan.expressions(), &[union])?) } LogicalPlan::CrossJoin(cross_join) => { @@ -166,15 +166,16 @@ impl OptimizerRule for PushDownLimit { right: Arc::new(new_right), schema: plan.schema().clone(), }); - Some(plan.with_new_inputs(&[new_cross_join])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_cross_join])?) } LogicalPlan::Join(join) => { let new_join = push_down_join(join, fetch + skip); match new_join { - Some(new_join) => { - Some(plan.with_new_inputs(&[LogicalPlan::Join(new_join)])?) - } + Some(new_join) => Some(plan.with_new_exprs( + plan.expressions(), + &[LogicalPlan::Join(new_join)], + )?), None => None, } } @@ -192,14 +193,16 @@ impl OptimizerRule for PushDownLimit { input: Arc::new((*sort.input).clone()), fetch: new_fetch, }); - Some(plan.with_new_inputs(&[new_sort])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_sort])?) } } LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_) => { // commute - let new_limit = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - Some(child_plan.with_new_inputs(&[new_limit])?) + let new_limit = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + Some(child_plan.with_new_exprs(child_plan.expressions(), &[new_limit])?) } _ => None, }; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 48f72ee7a0f8..44f2404afade 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -46,7 +46,7 @@ pub fn optimize_children( new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } else { Ok(None) }