From 668bff97ff96b69b53920476a3302712698fd916 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 3 Jul 2023 12:31:27 -0700
Subject: [PATCH 1/7] Fix cargo build warning (#6831)

---
 Cargo.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Cargo.toml b/Cargo.toml
index fac3f990fd2c..1e493f864c03 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -33,6 +33,7 @@ members = [
     "test-utils",
     "benchmarks",
 ]
+resolver = "2"
 
 [workspace.package]
 version = "27.0.0"

From bee22654c26633500ea93c425f5d085a8a66b86a Mon Sep 17 00:00:00 2001
From: Jonah Gao
Date: Tue, 4 Jul 2023 03:34:07 +0800
Subject: [PATCH 2/7] Simplify IsUnknown and IsNotUnknown expressions (#6830)

---
 .../simplify_expressions/expr_simplifier.rs   | 33 ++++++++++++++++---
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 96a31212c6ed..d88459bcf48a 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1187,11 +1187,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 lit(!negated)
             }
 
-            // a IS NOT NULL --> true, if a is not nullable
-            Expr::IsNotNull(expr) if !info.nullable(&expr)? => lit(true),
+            // a is not null/unknown --> true (if a is not nullable)
+            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
+                if !info.nullable(&expr)? =>
+            {
+                lit(true)
+            }
 
-            // a IS NULL --> false, if a is not nullable
-            Expr::IsNull(expr) if !info.nullable(&expr)? => lit(false),
+            // a is null/unknown --> false (if a is not nullable)
+            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
+                lit(false)
+            }
 
             // no additional rewrites possible
             expr => expr,
@@ -2726,6 +2732,25 @@ mod tests {
         );
     }
 
+    #[test]
+    fn simplify_expr_is_unknown() {
+        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown());
+
+        // 'c2_non_null is unknown' is always false
+        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
+    }
+
+    #[test]
+    fn simplify_expr_is_not_unknown() {
+        assert_eq!(
+            simplify(col("c2").is_not_unknown()),
+            col("c2").is_not_unknown()
+        );
+
+        // 'c2_non_null is not unknown' is always true
+        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
+    }
+
     #[test]
     fn simplify_expr_eq() {
         let schema = expr_test_schema();

From 2983a7ff0fcd16130243e4d7e4fb9f0efe69978a Mon Sep 17 00:00:00 2001
From: Jonah Gao
Date: Tue, 4 Jul 2023 03:36:30 +0800
Subject: [PATCH 3/7] fix: incorrect nullability of `Like` expressions (#6829)

* fix: incorrect nullability of Like expr

* Improve the documentation for ScalarType

---
 datafusion/common/src/scalar.rs    |  2 +-
 datafusion/expr/src/expr_schema.rs | 24 +++++++++++++++++++++---
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index e84ef545198e..044d40534ec1 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -3802,7 +3802,7 @@ impl fmt::Debug for ScalarValue {
     }
 }
 
-/// Trait used to map a NativeTime to a ScalarType.
+/// Trait used to map a NativeType to a ScalarValue
 pub trait ScalarType<T: ArrowNativeType> {
     /// returns a scalar from an optional T
     fn scalar(r: Option<T>) -> ScalarValue;
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 30537d0fdd81..76f37e4d6cec 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -251,9 +251,11 @@ impl ExprSchemable for Expr {
                 ref right,
                 ..
             }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
-            Expr::Like(Like { expr, .. }) => expr.nullable(input_schema),
-            Expr::ILike(Like { expr, .. }) => expr.nullable(input_schema),
-            Expr::SimilarTo(Like { expr, .. }) => expr.nullable(input_schema),
+            Expr::Like(Like { expr, pattern, .. })
+            | Expr::ILike(Like { expr, pattern, .. })
+            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
+                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
+            }
             Expr::Wildcard => Err(DataFusionError::Internal(
                 "Wildcard expressions are not valid in a logical query plan".to_owned(),
             )),
@@ -437,6 +439,22 @@ mod tests {
         assert!(expr.nullable(&get_schema(false)).unwrap());
     }
 
+    #[test]
+    fn test_like_nullability() {
+        let get_schema = |nullable| {
+            MockExprSchema::new()
+                .with_data_type(DataType::Utf8)
+                .with_nullable(nullable)
+        };
+
+        let expr = col("foo").like(lit("bar"));
+        assert!(!expr.nullable(&get_schema(false)).unwrap());
+        assert!(expr.nullable(&get_schema(true)).unwrap());
+
+        let expr = col("foo").like(lit(ScalarValue::Utf8(None)));
+        assert!(expr.nullable(&get_schema(false)).unwrap());
+    }
+
     #[test]
     fn expr_schema_data_type() {
         let expr = col("foo");

From 4ff0a8ede5c853fcbc2b7cac078bb1a6fce87df6 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Mon, 3 Jul 2023 16:38:06 -0400
Subject: [PATCH 4/7] Minor: Add one more assert to `hash_array_primitive`
 (#6834)

---
 datafusion/physical-expr/src/hash_utils.rs | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/datafusion/physical-expr/src/hash_utils.rs b/datafusion/physical-expr/src/hash_utils.rs
index 44699285a915..cf0a3f5eab89 100644
--- a/datafusion/physical-expr/src/hash_utils.rs
+++ b/datafusion/physical-expr/src/hash_utils.rs
@@ -96,6 +96,12 @@ fn hash_array_primitive<T>(
     T: ArrowPrimitiveType,
     <T as ArrowPrimitiveType>::Native: HashValue,
 {
+    assert_eq!(
+        hashes_buffer.len(),
+        array.len(),
+        "hashes_buffer and array should be of equal length"
+    );
+
     if array.null_count() == 0 {
         if rehash {
             for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) {

From 137bf81a39d4e4279a79f31aadee5bd75612017a Mon Sep 17 00:00:00 2001
From: jakevin
Date: Tue, 4 Jul 2023 14:59:04 +0800
Subject: [PATCH 5/7] revert #6595 #6820 (#6827)

* revert: from_plan keep same schema Project in #6595

* revert: from_plan keep same schema Agg/Window in #6820

* revert type coercion

* add comment

---
 datafusion/common/src/dfschema.rs             |  8 +---
 datafusion/expr/src/utils.rs                  | 44 ++++++++++++-------
 .../optimizer/src/analyzer/type_coercion.rs   | 11 ++++-
 3 files changed, 38 insertions(+), 25 deletions(-)

diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index c490852c6ee3..cb07f15b9d26 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -384,12 +384,8 @@ impl DFSchema {
         let self_fields = self.fields().iter();
         let other_fields = other.fields().iter();
         self_fields.zip(other_fields).all(|(f1, f2)| {
-            // TODO: resolve field when exist alias
-            // f1.qualifier() == f2.qualifier()
-            //     && f1.name() == f2.name()
-            // column(t1.a) field is "t1"."a"
-            // column(x) as t1.a field is ""."t1.a"
-            f1.qualified_name() == f2.qualified_name()
+            f1.qualifier() == f2.qualifier()
+                && f1.name() == f2.name()
                 && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type())
         })
     }
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 069ce6df71bc..3111579246f2 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -724,16 +724,22 @@ where
 /// // create new plan using rewritten_exprs in same position
 /// let new_plan = from_plan(&plan, rewritten_exprs, new_inputs);
 /// ```
+///
+/// Notice: sometimes [from_plan] reuses the schema of the original plan; it does
+/// not change the schema. This is the case for `Projection`, `Aggregate`, and
+/// `Window`.
 pub fn from_plan(
     plan: &LogicalPlan,
     expr: &[Expr],
     inputs: &[LogicalPlan],
 ) -> Result<LogicalPlan> {
     match plan {
-        LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new(
-            expr.to_vec(),
-            Arc::new(inputs[0].clone()),
-        )?)),
+        LogicalPlan::Projection(Projection { schema, .. }) => {
+            Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
+                expr.to_vec(),
+                Arc::new(inputs[0].clone()),
+                schema.clone(),
+            )?))
+        }
         LogicalPlan::Dml(DmlStatement {
             table_name,
             table_schema,
@@ -818,19 +824,23 @@ pub fn from_plan(
                 input: Arc::new(inputs[0].clone()),
             })),
         },
-        LogicalPlan::Window(Window { window_expr, .. }) => {
-            Ok(LogicalPlan::Window(Window::try_new(
-                expr[0..window_expr.len()].to_vec(),
-                Arc::new(inputs[0].clone()),
-            )?))
-        }
-        LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => {
-            Ok(LogicalPlan::Aggregate(Aggregate::try_new(
-                Arc::new(inputs[0].clone()),
-                expr[0..group_expr.len()].to_vec(),
-                expr[group_expr.len()..].to_vec(),
-            )?))
-        }
+        LogicalPlan::Window(Window {
+            window_expr,
+            schema,
+            ..
+        }) => Ok(LogicalPlan::Window(Window {
+            input: Arc::new(inputs[0].clone()),
+            window_expr: expr[0..window_expr.len()].to_vec(),
+            schema: schema.clone(),
+        })),
+        LogicalPlan::Aggregate(Aggregate {
+            group_expr, schema, ..
+        }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+            Arc::new(inputs[0].clone()),
+            expr[0..group_expr.len()].to_vec(),
+            expr[group_expr.len()..].to_vec(),
+            schema.clone(),
+        )?)),
         LogicalPlan::Sort(SortPlan { fetch, ..
}) => Ok(LogicalPlan::Sort(SortPlan { expr: expr.to_vec(), input: Arc::new(inputs[0].clone()), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 8edf734b474f..5d1fef53520b 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -43,7 +43,7 @@ use datafusion_expr::utils::from_plan; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, LogicalPlan, Operator, - WindowFrame, WindowFrameBound, WindowFrameUnits, + Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; @@ -109,7 +109,14 @@ fn analyze_internal( }) .collect::>>()?; - from_plan(plan, &new_expr, &new_inputs) + // TODO: from_plan can't change the schema, so we need to do this here + match &plan { + LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new( + new_expr, + Arc::new(new_inputs[0].clone()), + )?)), + _ => from_plan(plan, &new_expr, &new_inputs), + } } pub(crate) struct TypeCoercionRewriter { From 07a721f67fd9501159f33ec20f5ac670584c8e9f Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 4 Jul 2023 13:11:44 +0100 Subject: [PATCH 6/7] Add Duration to ScalarValue (#6838) --- datafusion/common/src/scalar.rs | 102 +++++++++++++++++- datafusion/proto/proto/datafusion.proto | 6 ++ datafusion/proto/src/generated/pbjson.rs | 52 +++++++++ datafusion/proto/src/generated/prost.rs | 10 +- .../proto/src/logical_plan/from_proto.rs | 4 + datafusion/proto/src/logical_plan/to_proto.rs | 100 ++++++++++------- 6 files changed, 236 insertions(+), 38 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 044d40534ec1..4fef60020f77 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -132,6 +132,14 @@ pub enum ScalarValue { /// Months and days are encoded as 32-bit signed integers. /// Nanoseconds is encoded as a 64-bit signed integer (no leap seconds). 
IntervalMonthDayNano(Option), + /// Duration in seconds + DurationSecond(Option), + /// Duration in milliseconds + DurationMillisecond(Option), + /// Duration in microseconds + DurationMicrosecond(Option), + /// Duration in nanoseconds + DurationNanosecond(Option), /// struct of nested ScalarValue Struct(Option>, Fields), /// Dictionary type: index type and value @@ -210,6 +218,14 @@ impl PartialEq for ScalarValue { (TimestampMicrosecond(_, _), _) => false, (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), (TimestampNanosecond(_, _), _) => false, + (DurationSecond(v1), DurationSecond(v2)) => v1.eq(v2), + (DurationSecond(_), _) => false, + (DurationMillisecond(v1), DurationMillisecond(v2)) => v1.eq(v2), + (DurationMillisecond(_), _) => false, + (DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.eq(v2), + (DurationMicrosecond(_), _) => false, + (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.eq(v2), + (DurationNanosecond(_), _) => false, (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), (IntervalYearMonth(v1), IntervalDayTime(v2)) => { ym_to_milli(v1).eq(&dt_to_milli(v2)) @@ -357,6 +373,14 @@ impl PartialOrd for ScalarValue { mdn_to_nano(v1).partial_cmp(&dt_to_nano(v2)) } (IntervalMonthDayNano(_), _) => None, + (DurationSecond(v1), DurationSecond(v2)) => v1.partial_cmp(v2), + (DurationSecond(_), _) => None, + (DurationMillisecond(v1), DurationMillisecond(v2)) => v1.partial_cmp(v2), + (DurationMillisecond(_), _) => None, + (DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.partial_cmp(v2), + (DurationMicrosecond(_), _) => None, + (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), + (DurationNanosecond(_), _) => None, (Struct(v1, t1), Struct(v2, t2)) => { if t1.eq(t2) { v1.partial_cmp(v2) @@ -1508,6 +1532,10 @@ impl std::hash::Hash for ScalarValue { TimestampMillisecond(v, _) => v.hash(state), TimestampMicrosecond(v, _) => v.hash(state), TimestampNanosecond(v, _) => v.hash(state), + DurationSecond(v) => v.hash(state), + DurationMillisecond(v) => v.hash(state), + DurationMicrosecond(v) => v.hash(state), + DurationNanosecond(v) => v.hash(state), IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), IntervalMonthDayNano(v) => v.hash(state), @@ -1984,6 +2012,16 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(_) => { DataType::Interval(IntervalUnit::MonthDayNano) } + ScalarValue::DurationSecond(_) => DataType::Duration(TimeUnit::Second), + ScalarValue::DurationMillisecond(_) => { + DataType::Duration(TimeUnit::Millisecond) + } + ScalarValue::DurationMicrosecond(_) => { + DataType::Duration(TimeUnit::Microsecond) + } + ScalarValue::DurationNanosecond(_) => { + DataType::Duration(TimeUnit::Nanosecond) + } ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()), ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.get_datatype())) @@ -2118,6 +2156,10 @@ impl ScalarValue { ScalarValue::IntervalYearMonth(v) => v.is_none(), ScalarValue::IntervalDayTime(v) => v.is_none(), ScalarValue::IntervalMonthDayNano(v) => v.is_none(), + ScalarValue::DurationSecond(v) => v.is_none(), + ScalarValue::DurationMillisecond(v) => v.is_none(), + ScalarValue::DurationMicrosecond(v) => v.is_none(), + ScalarValue::DurationNanosecond(v) => v.is_none(), ScalarValue::Struct(v, _) => v.is_none(), ScalarValue::Dictionary(_, v) => v.is_null(), } @@ -2897,6 +2939,34 @@ impl ScalarValue { e, size ), + ScalarValue::DurationSecond(e) => build_array_from_option!( + Duration, + TimeUnit::Second, + 
DurationSecondArray, + e, + size + ), + ScalarValue::DurationMillisecond(e) => build_array_from_option!( + Duration, + TimeUnit::Millisecond, + DurationMillisecondArray, + e, + size + ), + ScalarValue::DurationMicrosecond(e) => build_array_from_option!( + Duration, + TimeUnit::Microsecond, + DurationMicrosecondArray, + e, + size + ), + ScalarValue::DurationNanosecond(e) => build_array_from_option!( + Duration, + TimeUnit::Nanosecond, + DurationNanosecondArray, + e, + size + ), ScalarValue::Struct(values, fields) => match values { Some(values) => { let field_values: Vec<_> = fields @@ -3264,6 +3334,18 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(val) => { eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) } + ScalarValue::DurationSecond(val) => { + eq_array_primitive!(array, index, DurationSecondArray, val) + } + ScalarValue::DurationMillisecond(val) => { + eq_array_primitive!(array, index, DurationMillisecondArray, val) + } + ScalarValue::DurationMicrosecond(val) => { + eq_array_primitive!(array, index, DurationMicrosecondArray, val) + } + ScalarValue::DurationNanosecond(val) => { + eq_array_primitive!(array, index, DurationNanosecondArray, val) + } ScalarValue::Struct(_, _) => unimplemented!(), ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { @@ -3313,7 +3395,11 @@ impl ScalarValue { | ScalarValue::Time64Nanosecond(_) | ScalarValue::IntervalYearMonth(_) | ScalarValue::IntervalDayTime(_) - | ScalarValue::IntervalMonthDayNano(_) => 0, + | ScalarValue::IntervalMonthDayNano(_) + | ScalarValue::DurationSecond(_) + | ScalarValue::DurationMillisecond(_) + | ScalarValue::DurationMicrosecond(_) + | ScalarValue::DurationNanosecond(_) => 0, ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => { s.as_ref().map(|s| s.capacity()).unwrap_or_default() } @@ -3699,6 +3785,10 @@ impl fmt::Display for ScalarValue { ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, + ScalarValue::DurationSecond(e) => format_option!(f, e)?, + ScalarValue::DurationMillisecond(e) => format_option!(f, e)?, + ScalarValue::DurationMicrosecond(e) => format_option!(f, e)?, + ScalarValue::DurationNanosecond(e) => format_option!(f, e)?, ScalarValue::Struct(e, fields) => match e { Some(l) => write!( f, @@ -3781,6 +3871,16 @@ impl fmt::Debug for ScalarValue { ScalarValue::IntervalMonthDayNano(_) => { write!(f, "IntervalMonthDayNano(\"{self}\")") } + ScalarValue::DurationSecond(_) => write!(f, "DurationSecond(\"{self}\")"), + ScalarValue::DurationMillisecond(_) => { + write!(f, "DurationMillisecond(\"{self}\")") + } + ScalarValue::DurationMicrosecond(_) => { + write!(f, "DurationMicrosecond(\"{self}\")") + } + ScalarValue::DurationNanosecond(_) => { + write!(f, "DurationNanosecond(\"{self}\")") + } ScalarValue::Struct(e, fields) => { // Use Debug representation of field values match e { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0d61cd2b3573..81a8bc6b2342 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -904,6 +904,12 @@ message ScalarValue{ int64 date_64_value = 21; int32 interval_yearmonth_value = 24; int64 interval_daytime_value = 25; + + int64 duration_second_value = 35; + int64 duration_millisecond_value = 36; + int64 duration_microsecond_value = 37; + int64 duration_nanosecond_value = 38; + ScalarTimestampValue 
timestamp_value = 26; ScalarDictionaryValue dictionary_value = 27; bytes binary_value = 28; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 831dd49618f7..3c7763a15463 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -18938,6 +18938,18 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::IntervalDaytimeValue(v) => { struct_ser.serialize_field("intervalDaytimeValue", ToString::to_string(&v).as_str())?; } + scalar_value::Value::DurationSecondValue(v) => { + struct_ser.serialize_field("durationSecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationMillisecondValue(v) => { + struct_ser.serialize_field("durationMillisecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationMicrosecondValue(v) => { + struct_ser.serialize_field("durationMicrosecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationNanosecondValue(v) => { + struct_ser.serialize_field("durationNanosecondValue", ToString::to_string(&v).as_str())?; + } scalar_value::Value::TimestampValue(v) => { struct_ser.serialize_field("timestampValue", v)?; } @@ -19016,6 +19028,14 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalYearmonthValue", "interval_daytime_value", "intervalDaytimeValue", + "duration_second_value", + "durationSecondValue", + "duration_millisecond_value", + "durationMillisecondValue", + "duration_microsecond_value", + "durationMicrosecondValue", + "duration_nanosecond_value", + "durationNanosecondValue", "timestamp_value", "timestampValue", "dictionary_value", @@ -19057,6 +19077,10 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Date64Value, IntervalYearmonthValue, IntervalDaytimeValue, + DurationSecondValue, + DurationMillisecondValue, + DurationMicrosecondValue, + DurationNanosecondValue, TimestampValue, DictionaryValue, BinaryValue, @@ -19107,6 +19131,10 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), "intervalYearmonthValue" | "interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue), "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue), + "durationSecondValue" | "duration_second_value" => Ok(GeneratedField::DurationSecondValue), + "durationMillisecondValue" | "duration_millisecond_value" => Ok(GeneratedField::DurationMillisecondValue), + "durationMicrosecondValue" | "duration_microsecond_value" => Ok(GeneratedField::DurationMicrosecondValue), + "durationNanosecondValue" | "duration_nanosecond_value" => Ok(GeneratedField::DurationNanosecondValue), "timestampValue" | "timestamp_value" => Ok(GeneratedField::TimestampValue), "dictionaryValue" | "dictionary_value" => Ok(GeneratedField::DictionaryValue), "binaryValue" | "binary_value" => Ok(GeneratedField::BinaryValue), @@ -19267,6 +19295,30 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { } value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalDaytimeValue(x.0)); } + GeneratedField::DurationSecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationSecondValue")); + } + value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationSecondValue(x.0)); + } + GeneratedField::DurationMillisecondValue => { + if value__.is_some() { + return 
Err(serde::de::Error::duplicate_field("durationMillisecondValue")); + } + value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMillisecondValue(x.0)); + } + GeneratedField::DurationMicrosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationMicrosecondValue")); + } + value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMicrosecondValue(x.0)); + } + GeneratedField::DurationNanosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationNanosecondValue")); + } + value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationNanosecondValue(x.0)); + } GeneratedField::TimestampValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("timestampValue")); diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e6c076e7d453..aca90c5f57b8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1087,7 +1087,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1142,6 +1142,14 @@ pub mod scalar_value { IntervalYearmonthValue(i32), #[prost(int64, tag = "25")] IntervalDaytimeValue(i64), + #[prost(int64, tag = "35")] + DurationSecondValue(i64), + #[prost(int64, tag = "36")] + DurationMillisecondValue(i64), + #[prost(int64, tag = "37")] + DurationMicrosecondValue(i64), + #[prost(int64, tag = "38")] + DurationNanosecondValue(i64), #[prost(message, tag = "26")] TimestampValue(super::ScalarTimestampValue), #[prost(message, tag = "27")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b4a49713c244..c4dc8eb9b256 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -680,6 +680,10 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } Value::IntervalYearmonthValue(v) => Self::IntervalYearMonth(Some(*v)), Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some(*v)), + Value::DurationSecondValue(v) => Self::DurationSecond(Some(*v)), + Value::DurationMillisecondValue(v) => Self::DurationMillisecond(Some(*v)), + Value::DurationMicrosecondValue(v) => Self::DurationMicrosecond(Some(*v)), + Value::DurationNanosecondValue(v) => Self::DurationNanosecond(Some(*v)), Value::TimestampValue(v) => { let timezone = if v.timezone.is_empty() { None diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 5b09aee91095..d81e92c3f3d3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1013,63 +1013,62 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { type Error = Error; fn try_from(val: &ScalarValue) -> Result { - use datafusion_common::scalar; use protobuf::scalar_value::Value; let data_type = val.get_datatype(); match val { - scalar::ScalarValue::Boolean(val) => { + ScalarValue::Boolean(val) => { 
create_proto_scalar(val.as_ref(), &data_type, |s| Value::BoolValue(*s)) } - scalar::ScalarValue::Float32(val) => { + ScalarValue::Float32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float32Value(*s)) } - scalar::ScalarValue::Float64(val) => { + ScalarValue::Float64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float64Value(*s)) } - scalar::ScalarValue::Int8(val) => { + ScalarValue::Int8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Int8Value(*s as i32) }) } - scalar::ScalarValue::Int16(val) => { + ScalarValue::Int16(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Int16Value(*s as i32) }) } - scalar::ScalarValue::Int32(val) => { + ScalarValue::Int32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int32Value(*s)) } - scalar::ScalarValue::Int64(val) => { + ScalarValue::Int64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int64Value(*s)) } - scalar::ScalarValue::UInt8(val) => { + ScalarValue::UInt8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Uint8Value(*s as u32) }) } - scalar::ScalarValue::UInt16(val) => { + ScalarValue::UInt16(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Uint16Value(*s as u32) }) } - scalar::ScalarValue::UInt32(val) => { + ScalarValue::UInt32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint32Value(*s)) } - scalar::ScalarValue::UInt64(val) => { + ScalarValue::UInt64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint64Value(*s)) } - scalar::ScalarValue::Utf8(val) => { + ScalarValue::Utf8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Utf8Value(s.to_owned()) }) } - scalar::ScalarValue::LargeUtf8(val) => { + ScalarValue::LargeUtf8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::LargeUtf8Value(s.to_owned()) }) } - scalar::ScalarValue::List(values, boxed_field) => { + ScalarValue::List(values, boxed_field) => { let is_null = values.is_none(); let values = if let Some(values) = values.as_ref() { @@ -1093,10 +1092,10 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { )), }) } - datafusion::scalar::ScalarValue::Date32(val) => { + ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) } - datafusion::scalar::ScalarValue::TimestampMicrosecond(val, tz) => { + ScalarValue::TimestampMicrosecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1108,7 +1107,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::TimestampNanosecond(val, tz) => { + ScalarValue::TimestampNanosecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1120,7 +1119,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::Decimal128(val, p, s) => match *val { + ScalarValue::Decimal128(val, p, s) => match *val { Some(v) => { let array = v.to_be_bytes(); let vec_val: Vec = array.to_vec(); @@ -1138,10 +1137,10 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { )), }), }, - datafusion::scalar::ScalarValue::Date64(val) => { + ScalarValue::Date64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date64Value(*s)) } - 
datafusion::scalar::ScalarValue::TimestampSecond(val, tz) => { + ScalarValue::TimestampSecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1151,7 +1150,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::TimestampMillisecond(val, tz) => { + ScalarValue::TimestampMillisecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1163,31 +1162,31 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::IntervalYearMonth(val) => { + ScalarValue::IntervalYearMonth(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::IntervalYearmonthValue(*s) }) } - datafusion::scalar::ScalarValue::IntervalDayTime(val) => { + ScalarValue::IntervalDayTime(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::IntervalDaytimeValue(*s) }) } - datafusion::scalar::ScalarValue::Null => Ok(protobuf::ScalarValue { + ScalarValue::Null => Ok(protobuf::ScalarValue { value: Some(Value::NullValue((&data_type).try_into()?)), }), - scalar::ScalarValue::Binary(val) => { + ScalarValue::Binary(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::BinaryValue(s.to_owned()) }) } - scalar::ScalarValue::LargeBinary(val) => { + ScalarValue::LargeBinary(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::LargeBinaryValue(s.to_owned()) }) } - scalar::ScalarValue::FixedSizeBinary(length, val) => { + ScalarValue::FixedSizeBinary(length, val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::FixedSizeBinaryValue(protobuf::ScalarFixedSizeBinary { values: s.to_owned(), @@ -1196,7 +1195,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time32Second(v) => { + ScalarValue::Time32Second(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time32Value(protobuf::ScalarTime32Value { value: Some( @@ -1206,7 +1205,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time32Millisecond(v) => { + ScalarValue::Time32Millisecond(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time32Value(protobuf::ScalarTime32Value { value: Some( @@ -1218,7 +1217,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time64Microsecond(v) => { + ScalarValue::Time64Microsecond(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time64Value(protobuf::ScalarTime64Value { value: Some( @@ -1230,7 +1229,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time64Nanosecond(v) => { + ScalarValue::Time64Nanosecond(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time64Value(protobuf::ScalarTime64Value { value: Some( @@ -1242,7 +1241,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::IntervalMonthDayNano(v) => { + ScalarValue::IntervalMonthDayNano(v) => { let value = if let Some(v) = v { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); Value::IntervalMonthDayNano(protobuf::IntervalMonthDayNanoValue { @@ -1251,13 +1250,42 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { nanos, }) } else { - 
protobuf::scalar_value::Value::NullValue((&data_type).try_into()?) + Value::NullValue((&data_type).try_into()?) }; Ok(protobuf::ScalarValue { value: Some(value) }) } - datafusion::scalar::ScalarValue::Struct(values, fields) => { + ScalarValue::DurationSecond(v) => { + let value = match v { + Some(v) => Value::DurationSecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationMillisecond(v) => { + let value = match v { + Some(v) => Value::DurationMillisecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationMicrosecond(v) => { + let value = match v { + Some(v) => Value::DurationMicrosecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationNanosecond(v) => { + let value = match v { + Some(v) => Value::DurationNanosecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + + ScalarValue::Struct(values, fields) => { // encode null as empty field values list let field_values = if let Some(values) = values { if values.is_empty() { @@ -1284,7 +1312,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Dictionary(index_type, val) => { + ScalarValue::Dictionary(index_type, val) => { let value: protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { value: Some(Value::DictionaryValue(Box::new( From 02a470f6061cce8ee8e57f7af8a6a0e0ddc1571b Mon Sep 17 00:00:00 2001 From: Armin Primadi Date: Tue, 4 Jul 2023 20:00:33 +0700 Subject: [PATCH 7/7] Replace AbortOnDrop / AbortDropOnMany with tokio JoinSet (#6750) * Use JoinSet in MemTable * Fix error handling * Refactor AbortOnDropSingle in csv physical plan * Fix csv write physical plan error propagation * Refactor json write physical plan to use JoinSet * Refactor parquet write physical plan to use JoinSet * Refactor collect_partitioned to use JoinSet * Refactor pull_from_input method to make it easier to read * Fix typo --- datafusion/core/src/datasource/memory.rs | 39 +++++++++-------- .../core/src/datasource/physical_plan/csv.rs | 32 ++++++++------ .../core/src/datasource/physical_plan/json.rs | 32 ++++++++------ .../src/datasource/physical_plan/parquet.rs | 43 ++++++++++--------- datafusion/core/src/physical_plan/mod.rs | 41 ++++++++++++------ .../core/src/physical_plan/repartition/mod.rs | 33 +++++++------- 6 files changed, 127 insertions(+), 93 deletions(-) diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 784aa2aff232..5398bb0903ca 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -29,12 +29,12 @@ use async_trait::async_trait; use datafusion_common::SchemaExt; use datafusion_execution::TaskContext; use tokio::sync::RwLock; +use tokio::task::JoinSet; use crate::datasource::{TableProvider, TableType}; use crate::error::{DataFusionError, Result}; use crate::execution::context::SessionState; use crate::logical_expr::Expr; -use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::insert::{DataSink, InsertExec}; use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::{common, SendableRecordBatchStream}; @@ -89,26 +89,31 @@ impl MemTable { let exec = t.scan(state, None, &[], None).await?; let 
partition_count = exec.output_partitioning().partition_count(); - let tasks = (0..partition_count) - .map(|part_i| { - let task = state.task_ctx(); - let exec = exec.clone(); - let task = tokio::spawn(async move { - let stream = exec.execute(part_i, task)?; - common::collect(stream).await - }); - - AbortOnDropSingle::new(task) - }) - // this collect *is needed* so that the join below can - // switch between tasks - .collect::>(); + let mut join_set = JoinSet::new(); + + for part_idx in 0..partition_count { + let task = state.task_ctx(); + let exec = exec.clone(); + join_set.spawn(async move { + let stream = exec.execute(part_idx, task)?; + common::collect(stream).await + }); + } let mut data: Vec> = Vec::with_capacity(exec.output_partitioning().partition_count()); - for result in futures::future::join_all(tasks).await { - data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??) + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => data.push(res?), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } } let exec = MemoryExec::try_new(&data, schema.clone(), None)?; diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 027bd1945be6..eba51615cddf 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -23,7 +23,6 @@ use crate::datasource::physical_plan::file_stream::{ }; use crate::datasource::physical_plan::FileMeta; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ @@ -46,7 +45,7 @@ use std::fs; use std::path::Path; use std::sync::Arc; use std::task::Poll; -use tokio::task::{self, JoinHandle}; +use tokio::task::JoinSet; /// Execution plan for scanning a CSV file #[derive(Debug, Clone)] @@ -331,7 +330,7 @@ pub async fn plan_to_csv( ))); } - let mut tasks = vec![]; + let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let plan = plan.clone(); let filename = format!("part-{i}.csv"); @@ -340,22 +339,29 @@ pub async fn plan_to_csv( let mut writer = csv::Writer::new(file); let stream = plan.execute(i, task_ctx.clone())?; - let handle: JoinHandle> = task::spawn(async move { - stream + join_set.spawn(async move { + let result: Result<()> = stream .map(|batch| writer.write(&batch?)) .try_collect() .await - .map_err(DataFusionError::from) + .map_err(DataFusionError::from); + result }); - tasks.push(AbortOnDropSingle::new(handle)); } - futures::future::join_all(tasks) - .await - .into_iter() - .try_for_each(|result| { - result.map_err(|e| DataFusionError::Execution(format!("{e}")))? 
- })?; + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, // propagate DataFusion error + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + Ok(()) } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index b736fd783999..64f70776606a 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -22,7 +22,6 @@ use crate::datasource::physical_plan::file_stream::{ }; use crate::datasource::physical_plan::FileMeta; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ @@ -44,7 +43,7 @@ use std::io::BufReader; use std::path::Path; use std::sync::Arc; use std::task::Poll; -use tokio::task::{self, JoinHandle}; +use tokio::task::JoinSet; use super::FileScanConfig; @@ -266,7 +265,7 @@ pub async fn plan_to_json( ))); } - let mut tasks = vec![]; + let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let plan = plan.clone(); let filename = format!("part-{i}.json"); @@ -274,22 +273,29 @@ pub async fn plan_to_json( let file = fs::File::create(path)?; let mut writer = json::LineDelimitedWriter::new(file); let stream = plan.execute(i, task_ctx.clone())?; - let handle: JoinHandle> = task::spawn(async move { - stream + join_set.spawn(async move { + let result: Result<()> = stream .map(|batch| writer.write(&batch?)) .try_collect() .await - .map_err(DataFusionError::from) + .map_err(DataFusionError::from); + result }); - tasks.push(AbortOnDropSingle::new(handle)); } - futures::future::join_all(tasks) - .await - .into_iter() - .try_for_each(|result| { - result.map_err(|e| DataFusionError::Execution(format!("{e}")))? 
- })?; + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, // propagate DataFusion error + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + Ok(()) } diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index f538255bc20d..96e5ce9fa0fd 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -31,7 +31,6 @@ use crate::{ execution::context::TaskContext, physical_optimizer::pruning::PruningPredicate, physical_plan::{ - common::AbortOnDropSingle, metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, @@ -64,6 +63,7 @@ use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMas use parquet::basic::{ConvertedType, LogicalType}; use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties}; use parquet::schema::types::ColumnDescriptor; +use tokio::task::JoinSet; mod metrics; pub mod page_filter; @@ -701,7 +701,7 @@ pub async fn plan_to_parquet( ))); } - let mut tasks = vec![]; + let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let plan = plan.clone(); let filename = format!("part-{i}.parquet"); @@ -710,27 +710,30 @@ pub async fn plan_to_parquet( let mut writer = ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?; let stream = plan.execute(i, task_ctx.clone())?; - let handle: tokio::task::JoinHandle> = - tokio::task::spawn(async move { - stream - .map(|batch| { - writer.write(&batch?).map_err(DataFusionError::ParquetError) - }) - .try_collect() - .await - .map_err(DataFusionError::from)?; + join_set.spawn(async move { + stream + .map(|batch| writer.write(&batch?).map_err(DataFusionError::ParquetError)) + .try_collect() + .await + .map_err(DataFusionError::from)?; + + writer.close().map_err(DataFusionError::from).map(|_| ()) + }); + } - writer.close().map_err(DataFusionError::from).map(|_| ()) - }); - tasks.push(AbortOnDropSingle::new(handle)); + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } } - futures::future::join_all(tasks) - .await - .into_iter() - .try_for_each(|result| { - result.map_err(|e| DataFusionError::Execution(format!("{e}")))? - })?; Ok(()) } diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index 5abecf6b167c..7efd5a19eeac 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -38,6 +38,7 @@ pub use display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; use futures::stream::{Stream, TryStreamExt}; use std::fmt; use std::fmt::Debug; +use tokio::task::JoinSet; use datafusion_common::tree_node::Transformed; use datafusion_common::DataFusionError; @@ -445,20 +446,37 @@ pub async fn collect_partitioned( ) -> Result>> { let streams = execute_stream_partitioned(plan, context)?; + let mut join_set = JoinSet::new(); // Execute the plan and collect the results into batches. 
- let handles = streams - .into_iter() - .enumerate() - .map(|(idx, stream)| async move { - let handle = tokio::task::spawn(stream.try_collect()); - AbortOnDropSingle::new(handle).await.map_err(|e| { - DataFusionError::Execution(format!( - "collect_partitioned partition {idx} panicked: {e}" - )) - })? + streams.into_iter().enumerate().for_each(|(idx, stream)| { + join_set.spawn(async move { + let result: Result> = stream.try_collect().await; + (idx, result) }); + }); + + let mut batches = vec![]; + // Note that currently this doesn't identify the thread that panicked + // + // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id + // once it is stable + while let Some(result) = join_set.join_next().await { + match result { + Ok((idx, res)) => batches.push((idx, res?)), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + batches.sort_by_key(|(idx, _)| *idx); + let batches = batches.into_iter().map(|(_, batch)| batch).collect(); - futures::future::try_join_all(handles).await + Ok(batches) } /// Execute the [ExecutionPlan] and return a vec with one stream per output partition @@ -713,7 +731,6 @@ pub mod unnest; pub mod values; pub mod windows; -use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_execution::TaskContext; diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs index 85225eb47176..3c689e97ab29 100644 --- a/datafusion/core/src/physical_plan/repartition/mod.rs +++ b/datafusion/core/src/physical_plan/repartition/mod.rs @@ -263,7 +263,7 @@ struct RepartitionMetrics { /// Time in nanos to execute child operator and fetch batches fetch_time: metrics::Time, /// Time in nanos to perform repartitioning - repart_time: metrics::Time, + repartition_time: metrics::Time, /// Time in nanos for sending resulting batches to channels send_time: metrics::Time, } @@ -293,7 +293,7 @@ impl RepartitionMetrics { Self { fetch_time, - repart_time, + repartition_time: repart_time, send_time, } } @@ -407,7 +407,7 @@ impl ExecutionPlan for RepartitionExec { // note we use a custom channel that ensures there is always data for each receiver // but limits the amount of buffering if required. 
let (txs, rxs) = channels(num_output_partitions); - // Clone sender for ech input partitions + // Clone sender for each input partitions let txs = txs .into_iter() .map(|item| vec![item; num_input_partitions]) @@ -565,34 +565,31 @@ impl RepartitionExec { /// Pulls data from the specified input plan, feeding it to the /// output partitions based on the desired partitioning /// - /// i is the input partition index - /// /// txs hold the output sending channels for each output partition async fn pull_from_input( input: Arc, - i: usize, - mut txs: HashMap< + partition: usize, + mut output_channels: HashMap< usize, (DistributionSender, SharedMemoryReservation), >, partitioning: Partitioning, - r_metrics: RepartitionMetrics, + metrics: RepartitionMetrics, context: Arc, ) -> Result<()> { let mut partitioner = - BatchPartitioner::try_new(partitioning, r_metrics.repart_time.clone())?; + BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?; // execute the child operator - let timer = r_metrics.fetch_time.timer(); - let mut stream = input.execute(i, context)?; + let timer = metrics.fetch_time.timer(); + let mut stream = input.execute(partition, context)?; timer.done(); - // While there are still outputs to send to, keep - // pulling inputs + // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions(); - while !txs.is_empty() { + while !output_channels.is_empty() { // fetch the next batch - let timer = r_metrics.fetch_time.timer(); + let timer = metrics.fetch_time.timer(); let result = stream.next().await; timer.done(); @@ -606,15 +603,15 @@ impl RepartitionExec { let (partition, batch) = res?; let size = batch.get_array_memory_size(); - let timer = r_metrics.send_time.timer(); + let timer = metrics.send_time.timer(); // if there is still a receiver, send to it - if let Some((tx, reservation)) = txs.get_mut(&partition) { + if let Some((tx, reservation)) = output_channels.get_mut(&partition) { reservation.lock().try_grow(size)?; if tx.send(Some(Ok(batch))).await.is_err() { // If the other end has hung up, it was an early shutdown (e.g. LIMIT) reservation.lock().shrink(size); - txs.remove(&partition); + output_channels.remove(&partition); } } timer.done();
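
A note on the pattern that recurs throughout patch 7: `tokio::task::JoinSet` owns every task it spawns, aborts any still-running tasks when it is dropped (the guarantee the hand-rolled `AbortOnDropSingle` wrapper used to provide), and reports panics through `JoinError`. The following stand-alone sketch of the spawn/drain idiom the patches converge on is illustrative only; it is not part of the series and assumes nothing beyond the `tokio` crate with the `full` feature:

    use tokio::task::JoinSet;

    #[tokio::main]
    async fn main() {
        let mut join_set = JoinSet::new();

        // Spawn one task per unit of work, as the refactored plans
        // do per output partition.
        for partition in 0..4usize {
            join_set.spawn(async move { partition * 2 });
        }

        // Drain results in completion order. Panics are resumed on the
        // collecting task; cancellation cannot happen here because the
        // set is never aborted before it is drained.
        let mut results = Vec::new();
        while let Some(result) = join_set.join_next().await {
            match result {
                Ok(value) => results.push(value),
                Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
                Err(e) => unreachable!("unexpected cancellation: {e}"),
            }
        }

        // Completion order is nondeterministic, hence the sort, mirroring
        // the index sort in collect_partitioned above.
        results.sort_unstable();
        assert_eq!(results, vec![0, 2, 4, 6]);
    }

Because dropping the `JoinSet` aborts whatever is still queued, an early return (for example from `?` on a partition error) cleans up the remaining tasks automatically, which is what allows these patches to delete the explicit abort-on-drop machinery.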