From 2511c5da43cb934b3c437aae764647f1e0c4eb72 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Mon, 22 Jan 2024 16:47:33 +0800 Subject: [PATCH 01/10] support list stride --- datafusion/core/src/physical_planner.rs | 6 + datafusion/expr/src/expr.rs | 23 ++ datafusion/expr/src/expr_schema.rs | 9 + datafusion/expr/src/field_util.rs | 15 + .../src/expressions/get_indexed_field.rs | 94 +++++- datafusion/physical-expr/src/planner.rs | 13 + datafusion/proto/proto/datafusion.proto | 14 + datafusion/proto/src/generated/pbjson.rs | 278 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 28 +- .../proto/src/logical_plan/from_proto.rs | 19 ++ datafusion/proto/src/logical_plan/to_proto.rs | 9 + .../proto/src/physical_plan/from_proto.rs | 23 +- .../proto/src/physical_plan/to_proto.rs | 9 + datafusion/sql/src/expr/mod.rs | 44 ++- 14 files changed, 569 insertions(+), 15 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ac3b7ebaeac1..2bd91cc0d04c 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -214,6 +214,12 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let stop = create_physical_name(stop, false)?; format!("{expr}[{start}:{stop}]") } + GetFieldAccess::ListStride { start, stop, stride } => { + let start = create_physical_name(start, false)?; + let stop = create_physical_name(stop, false)?; + let stride = create_physical_name(stride, false)?; + format!("{expr}[{start}:{stop}:{stride}]") + } }; Ok(name) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c5d158d87638..19bc50d8f241 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -423,6 +423,12 @@ pub enum GetFieldAccess { ListIndex { key: Box }, /// List range, for example `list[i:j]` ListRange { start: Box, stop: Box }, + /// List stride, for example `list[i:j:k]` + ListStride { + start: Box, + stop: Box, + stride: Box, + }, } /// 
Returns the field of a [`arrow::array::ListArray`] or @@ -1533,6 +1539,13 @@ impl fmt::Display for Expr { GetFieldAccess::ListRange { start, stop } => { write!(f, "({expr})[{start}:{stop}]") } + GetFieldAccess::ListStride { + start, + stop, + stride, + } => { + write!(f, "({expr})[{start}:{stop}:{stride}]") + } }, Expr::GroupingSet(grouping_sets) => match grouping_sets { GroupingSet::Rollup(exprs) => { @@ -1737,6 +1750,16 @@ fn create_name(e: &Expr) -> Result { let stop = create_name(stop)?; Ok(format!("{expr}[{start}:{stop}]")) } + GetFieldAccess::ListStride { + start, + stop, + stride, + } => { + let start = create_name(start)?; + let stop = create_name(stop)?; + let stride = create_name(stride)?; + Ok(format!("{expr}[{start}:{stop}:{stride}]")) + } } } Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args), diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index ba21d09f0619..b1777ee5cfce 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -378,6 +378,15 @@ fn field_for_index( start_dt: start.get_type(schema)?, stop_dt: stop.get_type(schema)?, }, + GetFieldAccess::ListStride { + start, + stop, + stride, + } => GetFieldAccessSchema::ListStride { + start_dt: start.get_type(schema)?, + stop_dt: stop.get_type(schema)?, + stride_dt: stride.get_type(schema)?, + }, } .get_accessed_field(&expr_dt) } diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index 3829a2086b26..84e3b06f7888 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -33,6 +33,12 @@ pub enum GetFieldAccessSchema { start_dt: DataType, stop_dt: DataType, }, + /// List stride, for example `list[i:j:k]` + ListStride { + start_dt: DataType, + stop_dt: DataType, + stride_dt: DataType, + }, } impl GetFieldAccessSchema { @@ -94,6 +100,15 @@ impl GetFieldAccessSchema { (other, _, _) => plan_err!("The expression to get an indexed field is only 
valid for `List` or `Struct` types, got {other}"), } } + Self::ListStride { start_dt, stop_dt, stride_dt } => { + match (data_type, start_dt, stop_dt, stride_dt) { + (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), + (DataType::List(_), _, _, _) => plan_err!( + "Only ints are valid as an indexed field in a list" + ), + (other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } + } } } } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 43fd5a812a16..d39b24ccae57 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -48,6 +48,12 @@ pub enum GetFieldAccessExpr { start: Arc, stop: Arc, }, + /// List stride, for example `list[i:j:k]` + ListStride { + start: Arc, + stop: Arc, + stride: Arc, + }, } impl std::fmt::Display for GetFieldAccessExpr { @@ -58,13 +64,20 @@ impl std::fmt::Display for GetFieldAccessExpr { GetFieldAccessExpr::ListRange { start, stop } => { write!(f, "[{}:{}]", start, stop) } + GetFieldAccessExpr::ListStride { + start, + stop, + stride, + } => { + write!(f, "[{}:{}:{}]", start, stop, stride) + } } } } impl PartialEq for GetFieldAccessExpr { fn eq(&self, other: &dyn Any) -> bool { - use GetFieldAccessExpr::{ListIndex, ListRange, NamedStructField}; + use GetFieldAccessExpr::{ListIndex, ListRange, ListStride, NamedStructField}; down_cast_any_ref(other) .downcast_ref::() .map(|x| match (self, x) { @@ -82,9 +95,38 @@ impl PartialEq for GetFieldAccessExpr { stop: stop_rhs, }, ) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs), - (NamedStructField { .. }, ListIndex { .. } | ListRange { .. }) => false, - (ListIndex { .. }, NamedStructField { .. } | ListRange { .. }) => false, - (ListRange { .. }, NamedStructField { .. 
} | ListIndex { .. }) => false, + ( + ListStride { + start: start_lhs, + stop: stop_lhs, + stride: stride_lhs, + }, + ListStride { + start: start_rhs, + stop: stop_rhs, + stride: stride_rhs, + }, + ) => { + start_lhs.eq(start_rhs) + && stop_lhs.eq(stop_rhs) + && stride_lhs.eq(stride_rhs) + } + ( + NamedStructField { .. }, + ListIndex { .. } | ListRange { .. } | ListStride { .. }, + ) => false, + ( + ListIndex { .. }, + NamedStructField { .. } | ListRange { .. } | ListStride { .. }, + ) => false, + ( + ListRange { .. }, + NamedStructField { .. } | ListIndex { .. } | ListStride { .. }, + ) => false, + ( + ListStride { .. }, + NamedStructField { .. } | ListIndex { .. } | ListRange { .. }, + ) => false, }) .unwrap_or(false) } @@ -129,6 +171,23 @@ impl GetIndexedFieldExpr { Self::new(arg, GetFieldAccessExpr::ListRange { start, stop }) } + /// Create a new [`GetIndexedFieldExpr`] for accessing the stride + pub fn new_stride( + arg: Arc, + start: Arc, + stop: Arc, + stride: Arc, + ) -> Self { + Self::new( + arg, + GetFieldAccessExpr::ListStride { + start, + stop, + stride, + }, + ) + } + /// Get the description of what field should be accessed pub fn field(&self) -> &GetFieldAccessExpr { &self.field @@ -153,6 +212,15 @@ impl GetIndexedFieldExpr { stop_dt: stop.data_type(input_schema)?, } } + GetFieldAccessExpr::ListStride { + start, + stop, + stride, + } => GetFieldAccessSchema::ListStride { + start_dt: start.data_type(input_schema)?, + stop_dt: stop.data_type(input_schema)?, + stride_dt: stride.data_type(input_schema)?, + }, }) } } @@ -238,6 +306,24 @@ impl PhysicalExpr for GetIndexedFieldExpr { with utf8 indexes. 
Tried {dt:?} with {start:?} and {stop:?} indices"), } }, + GetFieldAccessExpr::ListStride { start, stop, stride } => { + let start = start.evaluate(batch)?.into_array(batch.num_rows())?; + let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; + let stride = stride.evaluate(batch)?.into_array(batch.num_rows())?; + match (array.data_type(), start.data_type(), stop.data_type(), stride.data_type()) { + (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => { + Ok(ColumnarValue::Array((array_slice(&[ + array, start, stop, stride + ]))?)) + }, + (DataType::List(_), start, stop, stride) => exec_err!( + "get indexed field is only possible on lists with int64 indexes. \ + Tried with {start:?}, {stop:?} and {stride:?} indices"), + (dt, start, stop, stride) => exec_err!( + "get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {start:?}, {stop:?} and {stride:?} indices"), + } + } } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 09b8da836c30..b3efff04bf71 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -252,6 +252,19 @@ pub fn create_physical_expr( )?, } } + GetFieldAccess::ListStride { + start, + stop, + stride, + } => GetFieldAccessExpr::ListStride { + start: create_physical_expr(start, input_dfschema, execution_props)?, + stop: create_physical_expr(stop, input_dfschema, execution_props)?, + stride: create_physical_expr( + stride, + input_dfschema, + execution_props, + )?, + }, }; Ok(Arc::new(GetIndexedFieldExpr::new( create_physical_expr(expr, input_dfschema, execution_props)?, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 66c1271e65c1..2487fad1bc1d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -468,12 +468,19 @@ message ListRange { LogicalExprNode stop = 2; } +message 
ListStride { + LogicalExprNode start = 1; + LogicalExprNode stop = 2; + LogicalExprNode stride = 3; +} + message GetIndexedField { LogicalExprNode expr = 1; oneof field { NamedStructField named_struct_field = 2; ListIndex list_index = 3; ListRange list_range = 4; + ListStride list_stride = 5; } } @@ -1775,11 +1782,18 @@ message ListRangeExpr { PhysicalExprNode stop = 2; } +message ListStrideExpr { + PhysicalExprNode start = 1; + PhysicalExprNode stop = 2; + PhysicalExprNode stride = 3; +} + message PhysicalGetIndexedFieldExprNode { PhysicalExprNode arg = 1; oneof field { NamedStructFieldExpr named_struct_field_expr = 2; ListIndexExpr list_index_expr = 3; ListRangeExpr list_range_expr = 4; + ListStrideExpr list_stride_expr = 5; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 39a8678ef250..b76d2dc8ca60 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9028,6 +9028,9 @@ impl serde::Serialize for GetIndexedField { get_indexed_field::Field::ListRange(v) => { struct_ser.serialize_field("listRange", v)?; } + get_indexed_field::Field::ListStride(v) => { + struct_ser.serialize_field("listStride", v)?; + } } } struct_ser.end() @@ -9047,6 +9050,8 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { "listIndex", "list_range", "listRange", + "list_stride", + "listStride", ]; #[allow(clippy::enum_variant_names)] @@ -9055,6 +9060,7 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { NamedStructField, ListIndex, ListRange, + ListStride, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9080,6 +9086,7 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), "listRange" | "list_range" => Ok(GeneratedField::ListRange), + "listStride" | "list_stride" 
=> Ok(GeneratedField::ListStride), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9128,6 +9135,13 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { return Err(serde::de::Error::duplicate_field("listRange")); } field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) +; + } + GeneratedField::ListStride => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listStride")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListStride) ; } } @@ -12744,6 +12758,256 @@ impl<'de> serde::Deserialize<'de> for ListRangeExpr { deserializer.deserialize_struct("datafusion.ListRangeExpr", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ListStride { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start.is_some() { + len += 1; + } + if self.stop.is_some() { + len += 1; + } + if self.stride.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListStride", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + if let Some(v) = self.stride.as_ref() { + struct_ser.serialize_field("stride", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListStride { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "stop", + "stride", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + Stop, + Stride, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> 
serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + "stride" => Ok(GeneratedField::Stride), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListStride; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListStride") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut stop__ = None; + let mut stride__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = map_.next_value()?; + } + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); + } + stop__ = map_.next_value()?; + } + GeneratedField::Stride => { + if stride__.is_some() { + return Err(serde::de::Error::duplicate_field("stride")); + } + stride__ = map_.next_value()?; + } + } + } + Ok(ListStride { + start: start__, + stop: stop__, + stride: stride__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListStride", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListStrideExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start.is_some() { + len += 1; + } + if self.stop.is_some() { + len += 1; + } + if self.stride.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListStrideExpr", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + if let Some(v) = self.stride.as_ref() { + struct_ser.serialize_field("stride", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListStrideExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "stop", + "stride", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + Stop, + Stride, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; 
+ + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + "stride" => Ok(GeneratedField::Stride), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListStrideExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListStrideExpr") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut stop__ = None; + let mut stride__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = map_.next_value()?; + } + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); + } + stop__ = map_.next_value()?; + } + GeneratedField::Stride => { + if stride__.is_some() { + return Err(serde::de::Error::duplicate_field("stride")); + } + stride__ = map_.next_value()?; + } + } + } + Ok(ListStrideExpr { + start: start__, + stop: stop__, + stride: stride__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListStrideExpr", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ListingTableScanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -18113,6 +18377,9 @@ impl serde::Serialize for PhysicalGetIndexedFieldExprNode { physical_get_indexed_field_expr_node::Field::ListRangeExpr(v) => { struct_ser.serialize_field("listRangeExpr", v)?; } + physical_get_indexed_field_expr_node::Field::ListStrideExpr(v) => { + struct_ser.serialize_field("listStrideExpr", v)?; + } } } struct_ser.end() @@ -18132,6 +18399,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { "listIndexExpr", "list_range_expr", "listRangeExpr", + "list_stride_expr", + "listStrideExpr", ]; #[allow(clippy::enum_variant_names)] @@ -18140,6 +18409,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { NamedStructFieldExpr, ListIndexExpr, ListRangeExpr, + ListStrideExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18165,6 +18435,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), "listIndexExpr" | "list_index_expr" => Ok(GeneratedField::ListIndexExpr), "listRangeExpr" | "list_range_expr" => Ok(GeneratedField::ListRangeExpr), + 
"listStrideExpr" | "list_stride_expr" => Ok(GeneratedField::ListStrideExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18213,6 +18484,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { return Err(serde::de::Error::duplicate_field("listRangeExpr")); } field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListRangeExpr) +; + } + GeneratedField::ListStrideExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listStrideExpr")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListStrideExpr) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7bf1d8ed0450..7a8cd1cf0111 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -734,10 +734,20 @@ pub struct ListRange { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListStride { + #[prost(message, optional, boxed, tag = "1")] + pub start: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "3")] + pub stride: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct GetIndexedField { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4")] + #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4, 5")] pub field: ::core::option::Option, } /// Nested message and enum types in `GetIndexedField`. 
@@ -751,6 +761,8 @@ pub mod get_indexed_field { ListIndex(::prost::alloc::boxed::Box), #[prost(message, tag = "4")] ListRange(::prost::alloc::boxed::Box), + #[prost(message, tag = "5")] + ListStride(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2541,10 +2553,20 @@ pub struct ListRangeExpr { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListStrideExpr { + #[prost(message, optional, boxed, tag = "1")] + pub start: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "3")] + pub stride: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalGetIndexedFieldExprNode { #[prost(message, optional, boxed, tag = "1")] pub arg: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "physical_get_indexed_field_expr_node::Field", tags = "2, 3, 4")] + #[prost(oneof = "physical_get_indexed_field_expr_node::Field", tags = "2, 3, 4, 5")] pub field: ::core::option::Option, } /// Nested message and enum types in `PhysicalGetIndexedFieldExprNode`. 
@@ -2558,6 +2580,8 @@ pub mod physical_get_indexed_field_expr_node { ListIndexExpr(::prost::alloc::boxed::Box), #[prost(message, tag = "4")] ListRangeExpr(::prost::alloc::boxed::Box), + #[prost(message, tag = "5")] + ListStrideExpr(::prost::alloc::boxed::Box), } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 42d39b5c5139..b085a7cae0dc 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1069,6 +1069,25 @@ pub fn parse_expr( )?), } } + Some(protobuf::get_indexed_field::Field::ListStride(list_stride)) => { + GetFieldAccess::ListStride { + start: Box::new(parse_required_expr( + list_stride.start.as_deref(), + registry, + "start", + )?), + stop: Box::new(parse_required_expr( + list_stride.stop.as_deref(), + registry, + "stop", + )?), + stride: Box::new(parse_required_expr( + list_stride.stride.as_deref(), + registry, + "stride", + )?), + } + } None => return Err(proto_error("Field must not be None")), }; diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index dbb52eced36c..48d2a39d5d1e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1041,6 +1041,15 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }, )) } + GetFieldAccess::ListStride { start, stop, stride } => { + protobuf::get_indexed_field::Field::ListStride(Box::new( + protobuf::ListStride { + start: Some(Box::new(start.as_ref().try_into()?)), + stop: Some(Box::new(stop.as_ref().try_into()?)), + stride: Some(Box::new(stride.as_ref().try_into()?)), + }, + )) + } }; Self { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index dc827d02bf25..98086ce91ec4 100644 --- 
a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -412,7 +412,28 @@ pub fn parse_physical_expr( input_schema )?, }, - None => return Err(proto_error( + Some(protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(list_stride_expr)) => GetFieldAccessExpr::ListStride{ + start: parse_required_physical_expr( + list_stride_expr.start.as_deref(), + registry, + "start", + input_schema, + )?, + stop: parse_required_physical_expr( + list_stride_expr.stop.as_deref(), + registry, + "stop", + input_schema + )?, + stride: parse_required_physical_expr( + list_stride_expr.stride.as_deref(), + registry, + "stride", + input_schema + )?, + }, + None => + return Err(proto_error( "Field must not be None", )), }; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index cff32ca2f8c9..430739f38231 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -568,6 +568,15 @@ impl TryFrom> for protobuf::PhysicalExprNode { stop: Some(Box::new(stop.to_owned().try_into()?)), })) ), + GetFieldAccessExpr::ListStride { start, stop, stride } => { + Some( + protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(Box::new(protobuf::ListStrideExpr { + start: Some(Box::new(start.to_owned().try_into()?)), + stop: Some(Box::new(stop.to_owned().try_into()?)), + stride: Some(Box::new(stride.to_owned().try_into()?)), + })) + ) + } }; Ok(protobuf::PhysicalExprNode { diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 9fded63af3fc..ba372b4a117f 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -753,18 +753,46 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { operator: JsonOperator::Colon, right, } => { - let start = Box::new(self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?); - let stop = 
Box::new(self.sql_expr_to_logical_expr( - *right, + // the last value could represent stop or stride + let last = Box::new(self.sql_expr_to_logical_expr( + *right.clone(), schema, planner_context, )?); - GetFieldAccess::ListRange { start, stop } + match *left { + SQLExpr::JsonAccess { + left, + operator: JsonOperator::Colon, + right, + } => { + let start = Box::new(self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?); + let stop = Box::new(self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?); + + GetFieldAccess::ListStride { + start, + stop, + stride: last, + } + } + _ => { + let start = Box::new(self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?); + + GetFieldAccess::ListRange { start, stop: last } + } + } } _ => GetFieldAccess::ListIndex { key: Box::new(self.sql_expr_to_logical_expr( From b08d588d077e2280322301f26a5ce1cfbe54121b Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Mon, 22 Jan 2024 17:18:45 +0800 Subject: [PATCH 02/10] add test --- datafusion/sqllogictest/test_files/array.slt | 45 ++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b7d92aec88e6..e072e4146f13 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -759,6 +759,51 @@ select column1[0:5], column2[0:3], column3[0:9] from arrays; # select column1[column2:column3] from arrays_with_repeating_elements; # ---- +# array[i:j:k] + +# multiple index with columns #1 (positive index) +query ??? +select make_array(1, 2, 3)[1:2:2], make_array(1.0, 2.0, 3.0)[2:3:2], make_array('h', 'e', 'l', 'l', 'o')[2:4:2]; +---- +[1] [2.0] [e, l] + +# multiple index with columns #2 (zero index) +query ??? 
+select make_array(1, 2, 3)[0:0:2], make_array(1.0, 2.0, 3.0)[0:2:2], make_array('h', 'e', 'l', 'l', 'o')[0:6:2]; +---- +[] [1.0] [h, l, o] + +#TODO: sqlparser does not support negative index +## multiple index with columns #3 (negative index) +#query ??? +#select make_array(1, 2, 3)[-1:-2:-2], make_array(1.0, 2.0, 3.0)[-2:-3:-2], make_array('h', 'e', 'l', 'l', 'o')[-2:-4:-2]; +#---- +#[1] [2.0] [e, l] + +# multiple index with columns #1 (positive index) +query ??? +select column1[2:4:2], column2[1:4:2], column3[3:4:2] from arrays; +---- +[[3, ]] [1.1, 3.3] [r] +[[5, 6]] [, 6.6] [] +[[7, 8]] [7.7, 9.9] [l] +[[9, 10]] [10.1, 12.2] [t] +[] [13.3, 15.5] [e] +[[13, 14]] [] [] +[[, 18]] [16.6, 18.8] [] + +# multiple index with columns #2 (zero index) +query ??? +select column1[0:5:2], column2[0:3:2], column3[0:9:2] from arrays; +---- +[[, 2]] [1.1, 3.3] [L, r, m] +[[3, 4]] [, 6.6] [i, , m] +[[5, 6]] [7.7, 9.9] [d, l, r] +[[7, ]] [10.1, 12.2] [s, t] +[] [13.3, 15.5] [a, e] +[[11, 12]] [] [,] +[[15, 16]] [16.6, 18.8] [] + ### Array function tests From 2046afed992e2be124e49ca23b1b63af125fda34 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Mon, 22 Jan 2024 17:42:45 +0800 Subject: [PATCH 03/10] fix fmt --- datafusion/core/src/physical_planner.rs | 6 +++++- datafusion/proto/src/logical_plan/to_proto.rs | 20 ++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2bd91cc0d04c..0564de8904cd 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -214,7 +214,11 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let stop = create_physical_name(stop, false)?; format!("{expr}[{start}:{stop}]") } - GetFieldAccess::ListStride { start, stop, stride } => { + GetFieldAccess::ListStride { + start, + stop, + stride, + } => { let start = create_physical_name(start, false)?; let stop = create_physical_name(stop, 
false)?; let stride = create_physical_name(stride, false)?; diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 48d2a39d5d1e..3fc0fe258a9a 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1041,15 +1041,17 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }, )) } - GetFieldAccess::ListStride { start, stop, stride } => { - protobuf::get_indexed_field::Field::ListStride(Box::new( - protobuf::ListStride { - start: Some(Box::new(start.as_ref().try_into()?)), - stop: Some(Box::new(stop.as_ref().try_into()?)), - stride: Some(Box::new(stride.as_ref().try_into()?)), - }, - )) - } + GetFieldAccess::ListStride { + start, + stop, + stride, + } => protobuf::get_indexed_field::Field::ListStride(Box::new( + protobuf::ListStride { + start: Some(Box::new(start.as_ref().try_into()?)), + stop: Some(Box::new(stop.as_ref().try_into()?)), + stride: Some(Box::new(stride.as_ref().try_into()?)), + }, + )), }; Self { From 10ba8083456d8eafff94f305d700604ea30c1b12 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Thu, 25 Jan 2024 11:11:46 +0800 Subject: [PATCH 04/10] rename and extend ListRange to ListStride --- datafusion/core/src/physical_planner.rs | 5 - datafusion/expr/src/expr.rs | 13 +- datafusion/expr/src/expr_schema.rs | 4 - datafusion/expr/src/field_util.rs | 14 - .../src/expressions/get_indexed_field.rs | 71 +---- datafusion/physical-expr/src/planner.rs | 14 - datafusion/proto/proto/datafusion.proto | 16 +- datafusion/proto/src/generated/pbjson.rs | 244 ------------------ datafusion/proto/src/generated/prost.rs | 24 +- .../proto/src/logical_plan/from_proto.rs | 14 - datafusion/proto/src/logical_plan/to_proto.rs | 8 - .../proto/src/physical_plan/from_proto.rs | 14 - .../proto/src/physical_plan/to_proto.rs | 6 - .../tests/cases/roundtrip_physical_plan.rs | 1 + datafusion/sql/src/expr/mod.rs | 9 +- 15 files changed, 27 insertions(+), 430 deletions(-) diff 
--git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 0564de8904cd..7d8b1a862288 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -209,11 +209,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let key = create_physical_name(key, false)?; format!("{expr}[{key}]") } - GetFieldAccess::ListRange { start, stop } => { - let start = create_physical_name(start, false)?; - let stop = create_physical_name(stop, false)?; - format!("{expr}[{start}:{stop}]") - } GetFieldAccess::ListStride { start, stop, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 19bc50d8f241..56a3d2bbc567 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -421,8 +421,6 @@ pub enum GetFieldAccess { NamedStructField { name: ScalarValue }, /// Single list index, for example: `list[i]` ListIndex { key: Box }, - /// List range, for example `list[i:j]` - ListRange { start: Box, stop: Box }, /// List stride, for example `list[i:j:k]` ListStride { start: Box, @@ -1220,9 +1218,10 @@ impl Expr { pub fn range(self, start: Expr, stop: Expr) -> Self { Expr::GetIndexedField(GetIndexedField { expr: Box::new(self), - field: GetFieldAccess::ListRange { + field: GetFieldAccess::ListStride { start: Box::new(start), stop: Box::new(stop), + stride: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))), }, }) } @@ -1536,9 +1535,6 @@ impl fmt::Display for Expr { write!(f, "({expr})[{name}]") } GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"), - GetFieldAccess::ListRange { start, stop } => { - write!(f, "({expr})[{start}:{stop}]") - } GetFieldAccess::ListStride { start, stop, @@ -1745,11 +1741,6 @@ fn create_name(e: &Expr) -> Result { let key = create_name(key)?; Ok(format!("{expr}[{key}]")) } - GetFieldAccess::ListRange { start, stop } => { - let start = create_name(start)?; - let stop = create_name(stop)?; - 
Ok(format!("{expr}[{start}:{stop}]")) - } GetFieldAccess::ListStride { start, stop, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index b1777ee5cfce..14437a704932 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -374,10 +374,6 @@ fn field_for_index( GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex { key_dt: key.get_type(schema)?, }, - GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange { - start_dt: start.get_type(schema)?, - stop_dt: stop.get_type(schema)?, - }, GetFieldAccess::ListStride { start, stop, diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index 84e3b06f7888..a24cb42bbb58 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -28,11 +28,6 @@ pub enum GetFieldAccessSchema { NamedStructField { name: ScalarValue }, /// Single list index, for example: `list[i]` ListIndex { key_dt: DataType }, - /// List range, for example `list[i:j]` - ListRange { - start_dt: DataType, - stop_dt: DataType, - }, /// List stride, for example `list[i:j:k]` ListStride { start_dt: DataType, @@ -91,15 +86,6 @@ impl GetFieldAccessSchema { (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), } } - Self::ListRange{ start_dt, stop_dt } => { - match (data_type, start_dt, stop_dt) { - (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), - (DataType::List(_), _, _) => plan_err!( - "Only ints are valid as an indexed field in a list" - ), - (other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), - } - } Self::ListStride { start_dt, stop_dt, stride_dt } => { match (data_type, start_dt, stop_dt, stride_dt) { (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", 
data_type.clone(), true)), diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index d39b24ccae57..b67e06475f43 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -21,6 +21,7 @@ use crate::PhysicalExpr; use datafusion_common::exec_err; use crate::array_expressions::{array_element, array_slice}; +use crate::expressions::Literal; use crate::physical_expr::down_cast_any_ref; use arrow::{ array::{Array, Scalar, StringArray}, @@ -43,11 +44,6 @@ pub enum GetFieldAccessExpr { NamedStructField { name: ScalarValue }, /// Single list index, for example: `list[i]` ListIndex { key: Arc }, - /// List range, for example `list[i:j]` - ListRange { - start: Arc, - stop: Arc, - }, /// List stride, for example `list[i:j:k]` ListStride { start: Arc, @@ -61,9 +57,6 @@ impl std::fmt::Display for GetFieldAccessExpr { match self { GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name), GetFieldAccessExpr::ListIndex { key } => write!(f, "[{}]", key), - GetFieldAccessExpr::ListRange { start, stop } => { - write!(f, "[{}:{}]", start, stop) - } GetFieldAccessExpr::ListStride { start, stop, @@ -77,7 +70,7 @@ impl std::fmt::Display for GetFieldAccessExpr { impl PartialEq for GetFieldAccessExpr { fn eq(&self, other: &dyn Any) -> bool { - use GetFieldAccessExpr::{ListIndex, ListRange, ListStride, NamedStructField}; + use GetFieldAccessExpr::{ListIndex, ListStride, NamedStructField}; down_cast_any_ref(other) .downcast_ref::() .map(|x| match (self, x) { @@ -85,16 +78,6 @@ impl PartialEq for GetFieldAccessExpr { lhs.eq(rhs) } (ListIndex { key: lhs }, ListIndex { key: rhs }) => lhs.eq(rhs), - ( - ListRange { - start: start_lhs, - stop: stop_lhs, - }, - ListRange { - start: start_rhs, - stop: stop_rhs, - }, - ) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs), ( ListStride { start: start_lhs, @@ 
-111,22 +94,9 @@ impl PartialEq for GetFieldAccessExpr { && stop_lhs.eq(stop_rhs) && stride_lhs.eq(stride_rhs) } - ( - NamedStructField { .. }, - ListIndex { .. } | ListRange { .. } | ListStride { .. }, - ) => false, - ( - ListIndex { .. }, - NamedStructField { .. } | ListRange { .. } | ListStride { .. }, - ) => false, - ( - ListRange { .. }, - NamedStructField { .. } | ListIndex { .. } | ListStride { .. }, - ) => false, - ( - ListStride { .. }, - NamedStructField { .. } | ListIndex { .. } | ListRange { .. }, - ) => false, + (NamedStructField { .. }, ListIndex { .. } | ListStride { .. }) => false, + (ListIndex { .. }, NamedStructField { .. } | ListStride { .. }) => false, + (ListStride { .. }, NamedStructField { .. } | ListIndex { .. }) => false, }) .unwrap_or(false) } @@ -168,7 +138,15 @@ impl GetIndexedFieldExpr { start: Arc, stop: Arc, ) -> Self { - Self::new(arg, GetFieldAccessExpr::ListRange { start, stop }) + Self::new( + arg, + GetFieldAccessExpr::ListStride { + start, + stop, + stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))) + as Arc, + }, + ) } /// Create a new [`GetIndexedFieldExpr`] for accessing the stride @@ -206,12 +184,6 @@ impl GetIndexedFieldExpr { GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex { key_dt: key.data_type(input_schema)?, }, - GetFieldAccessExpr::ListRange { start, stop } => { - GetFieldAccessSchema::ListRange { - start_dt: start.data_type(input_schema)?, - stop_dt: stop.data_type(input_schema)?, - } - } GetFieldAccessExpr::ListStride { start, stop, @@ -291,21 +263,6 @@ impl PhysicalExpr for GetIndexedFieldExpr { with utf8 indexes. 
Tried {dt:?} with {key:?} index"), } }, - GetFieldAccessExpr::ListRange{start, stop} => { - let start = start.evaluate(batch)?.into_array(batch.num_rows())?; - let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; - match (array.data_type(), start.data_type(), stop.data_type()) { - (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[ - array, start, stop - ])?)), - (DataType::List(_), start, stop) => exec_err!( - "get indexed field is only possible on lists with int64 indexes. \ - Tried with {start:?} and {stop:?} indices"), - (dt, start, stop) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {start:?} and {stop:?} indices"), - } - }, GetFieldAccessExpr::ListStride { start, stop, stride } => { let start = start.evaluate(batch)?.into_array(batch.num_rows())?; let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index b3efff04bf71..834c29a944fd 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -238,20 +238,6 @@ pub fn create_physical_expr( GetFieldAccess::ListIndex { key } => GetFieldAccessExpr::ListIndex { key: create_physical_expr(key, input_dfschema, execution_props)?, }, - GetFieldAccess::ListRange { start, stop } => { - GetFieldAccessExpr::ListRange { - start: create_physical_expr( - start, - input_dfschema, - execution_props, - )?, - stop: create_physical_expr( - stop, - input_dfschema, - execution_props, - )?, - } - } GetFieldAccess::ListStride { start, stop, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2487fad1bc1d..43d0d92307ce 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -463,11 +463,6 @@ message ListIndex { LogicalExprNode key = 1; } -message ListRange { - 
LogicalExprNode start = 1; - LogicalExprNode stop = 2; -} - message ListStride { LogicalExprNode start = 1; LogicalExprNode stop = 2; @@ -479,8 +474,7 @@ message GetIndexedField { oneof field { NamedStructField named_struct_field = 2; ListIndex list_index = 3; - ListRange list_range = 4; - ListStride list_stride = 5; + ListStride list_stride = 4; } } @@ -1777,11 +1771,6 @@ message ListIndexExpr { PhysicalExprNode key = 1; } -message ListRangeExpr { - PhysicalExprNode start = 1; - PhysicalExprNode stop = 2; -} - message ListStrideExpr { PhysicalExprNode start = 1; PhysicalExprNode stop = 2; @@ -1793,7 +1782,6 @@ message PhysicalGetIndexedFieldExprNode { oneof field { NamedStructFieldExpr named_struct_field_expr = 2; ListIndexExpr list_index_expr = 3; - ListRangeExpr list_range_expr = 4; - ListStrideExpr list_stride_expr = 5; + ListStrideExpr list_stride_expr = 4; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b76d2dc8ca60..ce40a7de8e11 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9025,9 +9025,6 @@ impl serde::Serialize for GetIndexedField { get_indexed_field::Field::ListIndex(v) => { struct_ser.serialize_field("listIndex", v)?; } - get_indexed_field::Field::ListRange(v) => { - struct_ser.serialize_field("listRange", v)?; - } get_indexed_field::Field::ListStride(v) => { struct_ser.serialize_field("listStride", v)?; } @@ -9048,8 +9045,6 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { "namedStructField", "list_index", "listIndex", - "list_range", - "listRange", "list_stride", "listStride", ]; @@ -9059,7 +9054,6 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { Expr, NamedStructField, ListIndex, - ListRange, ListStride, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -9085,7 +9079,6 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { "expr" => Ok(GeneratedField::Expr), "namedStructField" | "named_struct_field" => 
Ok(GeneratedField::NamedStructField), "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), - "listRange" | "list_range" => Ok(GeneratedField::ListRange), "listStride" | "list_stride" => Ok(GeneratedField::ListStride), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -9128,13 +9121,6 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { return Err(serde::de::Error::duplicate_field("listIndex")); } field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) -; - } - GeneratedField::ListRange => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listRange")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) ; } GeneratedField::ListStride => { @@ -12542,222 +12528,6 @@ impl<'de> serde::Deserialize<'de> for ListIndexExpr { deserializer.deserialize_struct("datafusion.ListIndexExpr", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListRange { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.start.is_some() { - len += 1; - } - if self.stop.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; - if let Some(v) = self.start.as_ref() { - struct_ser.serialize_field("start", v)?; - } - if let Some(v) = self.stop.as_ref() { - struct_ser.serialize_field("stop", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ListRange { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "start", - "stop", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Start, - Stop, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: 
serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "start" => Ok(GeneratedField::Start), - "stop" => Ok(GeneratedField::Stop), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListRange; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListRange") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut start__ = None; - let mut stop__ = None; - while let Some(k) = map_.next_key()? 
{ - match k { - GeneratedField::Start => { - if start__.is_some() { - return Err(serde::de::Error::duplicate_field("start")); - } - start__ = map_.next_value()?; - } - GeneratedField::Stop => { - if stop__.is_some() { - return Err(serde::de::Error::duplicate_field("stop")); - } - stop__ = map_.next_value()?; - } - } - } - Ok(ListRange { - start: start__, - stop: stop__, - }) - } - } - deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ListRangeExpr { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.start.is_some() { - len += 1; - } - if self.stop.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ListRangeExpr", len)?; - if let Some(v) = self.start.as_ref() { - struct_ser.serialize_field("start", v)?; - } - if let Some(v) = self.stop.as_ref() { - struct_ser.serialize_field("stop", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ListRangeExpr { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "start", - "stop", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Start, - Stop, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "start" => Ok(GeneratedField::Start), - "stop" => 
Ok(GeneratedField::Stop), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListRangeExpr; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListRangeExpr") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut start__ = None; - let mut stop__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Start => { - if start__.is_some() { - return Err(serde::de::Error::duplicate_field("start")); - } - start__ = map_.next_value()?; - } - GeneratedField::Stop => { - if stop__.is_some() { - return Err(serde::de::Error::duplicate_field("stop")); - } - stop__ = map_.next_value()?; - } - } - } - Ok(ListRangeExpr { - start: start__, - stop: stop__, - }) - } - } - deserializer.deserialize_struct("datafusion.ListRangeExpr", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for ListStride { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -18374,9 +18144,6 @@ impl serde::Serialize for PhysicalGetIndexedFieldExprNode { physical_get_indexed_field_expr_node::Field::ListIndexExpr(v) => { struct_ser.serialize_field("listIndexExpr", v)?; } - physical_get_indexed_field_expr_node::Field::ListRangeExpr(v) => { - struct_ser.serialize_field("listRangeExpr", v)?; - } physical_get_indexed_field_expr_node::Field::ListStrideExpr(v) => { struct_ser.serialize_field("listStrideExpr", v)?; } @@ -18397,8 +18164,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { "namedStructFieldExpr", "list_index_expr", "listIndexExpr", - "list_range_expr", - "listRangeExpr", "list_stride_expr", "listStrideExpr", ]; @@ -18408,7 +18173,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode 
{ Arg, NamedStructFieldExpr, ListIndexExpr, - ListRangeExpr, ListStrideExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -18434,7 +18198,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { "arg" => Ok(GeneratedField::Arg), "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), "listIndexExpr" | "list_index_expr" => Ok(GeneratedField::ListIndexExpr), - "listRangeExpr" | "list_range_expr" => Ok(GeneratedField::ListRangeExpr), "listStrideExpr" | "list_stride_expr" => Ok(GeneratedField::ListStrideExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -18477,13 +18240,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { return Err(serde::de::Error::duplicate_field("listIndexExpr")); } field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListIndexExpr) -; - } - GeneratedField::ListRangeExpr => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listRangeExpr")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListRangeExpr) ; } GeneratedField::ListStrideExpr => { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7a8cd1cf0111..9aaf06109153 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -726,14 +726,6 @@ pub struct ListIndex { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ListRange { - #[prost(message, optional, boxed, tag = "1")] - pub start: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, boxed, tag = "2")] - pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] pub struct ListStride { #[prost(message, optional, 
boxed, tag = "1")] pub start: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -747,7 +739,7 @@ pub struct ListStride { pub struct GetIndexedField { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4, 5")] + #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4")] pub field: ::core::option::Option, } /// Nested message and enum types in `GetIndexedField`. @@ -760,8 +752,6 @@ pub mod get_indexed_field { #[prost(message, tag = "3")] ListIndex(::prost::alloc::boxed::Box), #[prost(message, tag = "4")] - ListRange(::prost::alloc::boxed::Box), - #[prost(message, tag = "5")] ListStride(::prost::alloc::boxed::Box), } } @@ -2545,14 +2535,6 @@ pub struct ListIndexExpr { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ListRangeExpr { - #[prost(message, optional, boxed, tag = "1")] - pub start: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, boxed, tag = "2")] - pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] pub struct ListStrideExpr { #[prost(message, optional, boxed, tag = "1")] pub start: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -2566,7 +2548,7 @@ pub struct ListStrideExpr { pub struct PhysicalGetIndexedFieldExprNode { #[prost(message, optional, boxed, tag = "1")] pub arg: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "physical_get_indexed_field_expr_node::Field", tags = "2, 3, 4, 5")] + #[prost(oneof = "physical_get_indexed_field_expr_node::Field", tags = "2, 3, 4")] pub field: ::core::option::Option, } /// Nested message and enum types in `PhysicalGetIndexedFieldExprNode`. 
@@ -2579,8 +2561,6 @@ pub mod physical_get_indexed_field_expr_node { #[prost(message, tag = "3")] ListIndexExpr(::prost::alloc::boxed::Box), #[prost(message, tag = "4")] - ListRangeExpr(::prost::alloc::boxed::Box), - #[prost(message, tag = "5")] ListStrideExpr(::prost::alloc::boxed::Box), } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b085a7cae0dc..300f36203cf1 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1055,20 +1055,6 @@ pub fn parse_expr( )?), } } - Some(protobuf::get_indexed_field::Field::ListRange(list_range)) => { - GetFieldAccess::ListRange { - start: Box::new(parse_required_expr( - list_range.start.as_deref(), - registry, - "start", - )?), - stop: Box::new(parse_required_expr( - list_range.stop.as_deref(), - registry, - "stop", - )?), - } - } Some(protobuf::get_indexed_field::Field::ListStride(list_stride)) => { GetFieldAccess::ListStride { start: Box::new(parse_required_expr( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 3fc0fe258a9a..4aa668a40afc 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1033,14 +1033,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }, )) } - GetFieldAccess::ListRange { start, stop } => { - protobuf::get_indexed_field::Field::ListRange(Box::new( - protobuf::ListRange { - start: Some(Box::new(start.as_ref().try_into()?)), - stop: Some(Box::new(stop.as_ref().try_into()?)), - }, - )) - } GetFieldAccess::ListStride { start, stop, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 98086ce91ec4..4fced9da81f3 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -398,20 +398,6 @@ pub fn parse_physical_expr( "key", 
input_schema, )?}, - Some(protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(list_range_expr)) => GetFieldAccessExpr::ListRange{ - start: parse_required_physical_expr( - list_range_expr.start.as_deref(), - registry, - "start", - input_schema, - )?, - stop: parse_required_physical_expr( - list_range_expr.stop.as_deref(), - registry, - "stop", - input_schema - )?, - }, Some(protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(list_stride_expr)) => GetFieldAccessExpr::ListStride{ start: parse_required_physical_expr( list_stride_expr.start.as_deref(), diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 430739f38231..55259002b5c8 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -562,12 +562,6 @@ impl TryFrom> for protobuf::PhysicalExprNode { key: Some(Box::new(key.to_owned().try_into()?)) })) ), - GetFieldAccessExpr::ListRange{start, stop} => Some( - protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(Box::new(protobuf::ListRangeExpr { - start: Some(Box::new(start.to_owned().try_into()?)), - stop: Some(Box::new(stop.to_owned().try_into()?)), - })) - ), GetFieldAccessExpr::ListStride { start, stop, stride } => { Some( protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(Box::new(protobuf::ListStrideExpr { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 38eb39000317..636143d806e9 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -750,6 +750,7 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { GetFieldAccessExpr::ListRange { start: col_start, stop: col_stop, + stride: Box::new(ScalarValue::Int64(Some(1))), }, )); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 
ba372b4a117f..bb3dcf87833e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -776,7 +776,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?); - GetFieldAccess::ListStride { start, stop, @@ -789,8 +788,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?); - - GetFieldAccess::ListRange { start, stop: last } + let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); + GetFieldAccess::ListStride { + start, + stop: last, + stride, + } } } } From df4cf8064e8cac97caccc4a300ff6161228cd73d Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Thu, 25 Jan 2024 16:14:36 +0800 Subject: [PATCH 05/10] fix ci --- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 636143d806e9..b2e4f05928eb 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -747,10 +747,11 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { let col_stop = col("stop", &schema)?; let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( col_arg, - GetFieldAccessExpr::ListRange { + GetFieldAccessExpr::ListStride { start: col_start, stop: col_stop, - stride: Box::new(ScalarValue::Int64(Some(1))), + stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))) + as Arc, }, )); From 26f5b32c137d1331536464a8ce159a9b5d380012 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sat, 27 Jan 2024 16:42:38 +0800 Subject: [PATCH 06/10] fix doctest --- datafusion/expr/src/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 56a3d2bbc567..d5de254cb18c 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1213,7 +1213,7 @@ impl Expr { /// # use 
datafusion_expr::{lit, col}; /// let expr = col("c1") /// .range(lit(2), lit(4)); - /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4)]"); + /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4):Int64(1)]"); /// ``` pub fn range(self, start: Expr, stop: Expr) -> Self { Expr::GetIndexedField(GetIndexedField { From 293904086362b6eb2f9f4817d1f4d1ce3b8d3d20 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Mon, 29 Jan 2024 10:35:05 +0800 Subject: [PATCH 07/10] fix conflict and keep ListRange --- datafusion/core/src/physical_planner.rs | 2 +- datafusion/expr/src/expr.rs | 8 +-- datafusion/expr/src/expr_schema.rs | 4 +- datafusion/expr/src/field_util.rs | 4 +- datafusion/expr/src/tree_node/expr.rs | 4 +- .../src/expressions/get_indexed_field.rs | 26 +++---- datafusion/physical-expr/src/planner.rs | 4 +- datafusion/proto/proto/datafusion.proto | 8 +-- datafusion/proto/src/generated/pbjson.rs | 68 +++++++++---------- datafusion/proto/src/generated/prost.rs | 8 +-- .../proto/src/logical_plan/from_proto.rs | 4 +- datafusion/proto/src/logical_plan/to_proto.rs | 6 +- .../proto/src/physical_plan/from_proto.rs | 2 +- .../proto/src/physical_plan/to_proto.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 2 +- datafusion/sql/src/expr/mod.rs | 4 +- 16 files changed, 78 insertions(+), 78 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 7d8b1a862288..d383ddce9242 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -209,7 +209,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let key = create_physical_name(key, false)?; format!("{expr}[{key}]") } - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop, stride, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d5de254cb18c..9da1f4bb4df7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -422,7 +422,7 @@ 
pub enum GetFieldAccess { /// Single list index, for example: `list[i]` ListIndex { key: Box }, /// List stride, for example `list[i:j:k]` - ListStride { + ListRange { start: Box, stop: Box, stride: Box, @@ -1218,7 +1218,7 @@ impl Expr { pub fn range(self, start: Expr, stop: Expr) -> Self { Expr::GetIndexedField(GetIndexedField { expr: Box::new(self), - field: GetFieldAccess::ListStride { + field: GetFieldAccess::ListRange { start: Box::new(start), stop: Box::new(stop), stride: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))), @@ -1535,7 +1535,7 @@ impl fmt::Display for Expr { write!(f, "({expr})[{name}]") } GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"), - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop, stride, @@ -1741,7 +1741,7 @@ fn create_name(e: &Expr) -> Result { let key = create_name(key)?; Ok(format!("{expr}[{key}]")) } - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop, stride, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 14437a704932..4967e66fed40 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -374,11 +374,11 @@ fn field_for_index( GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex { key_dt: key.get_type(schema)?, }, - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop, stride, - } => GetFieldAccessSchema::ListStride { + } => GetFieldAccessSchema::ListRange { start_dt: start.get_type(schema)?, stop_dt: stop.get_type(schema)?, stride_dt: stride.get_type(schema)?, diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index a24cb42bbb58..c46ec50234dd 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -29,7 +29,7 @@ pub enum GetFieldAccessSchema { /// Single list index, for example: `list[i]` ListIndex { key_dt: DataType }, /// List stride, for example `list[i:j:k]` - ListStride { + ListRange { 
start_dt: DataType, stop_dt: DataType, stride_dt: DataType, @@ -86,7 +86,7 @@ impl GetFieldAccessSchema { (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), } } - Self::ListStride { start_dt, stop_dt, stride_dt } => { + Self::ListRange { start_dt, stop_dt, stride_dt } => { match (data_type, start_dt, stop_dt, stride_dt) { (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), (DataType::List(_), _, _, _) => plan_err!( diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 05464c96d05e..8b38d1cf01d6 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -52,8 +52,8 @@ impl TreeNode for Expr { let expr = expr.as_ref(); match field { GetFieldAccess::ListIndex {key} => vec![key.as_ref(), expr], - GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref(), stop.as_ref(), expr] + GetFieldAccess::ListRange {start, stop, stride} => { + vec![start.as_ref(), stop.as_ref(),stride.as_ref(), expr] } GetFieldAccess::NamedStructField { .. 
} => vec![expr], } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index b67e06475f43..58fe4728543d 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -45,7 +45,7 @@ pub enum GetFieldAccessExpr { /// Single list index, for example: `list[i]` ListIndex { key: Arc }, /// List stride, for example `list[i:j:k]` - ListStride { + ListRange { start: Arc, stop: Arc, stride: Arc, @@ -57,7 +57,7 @@ impl std::fmt::Display for GetFieldAccessExpr { match self { GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name), GetFieldAccessExpr::ListIndex { key } => write!(f, "[{}]", key), - GetFieldAccessExpr::ListStride { + GetFieldAccessExpr::ListRange { start, stop, stride, @@ -70,7 +70,7 @@ impl std::fmt::Display for GetFieldAccessExpr { impl PartialEq for GetFieldAccessExpr { fn eq(&self, other: &dyn Any) -> bool { - use GetFieldAccessExpr::{ListIndex, ListStride, NamedStructField}; + use GetFieldAccessExpr::{ListIndex, ListRange, NamedStructField}; down_cast_any_ref(other) .downcast_ref::() .map(|x| match (self, x) { @@ -79,12 +79,12 @@ impl PartialEq for GetFieldAccessExpr { } (ListIndex { key: lhs }, ListIndex { key: rhs }) => lhs.eq(rhs), ( - ListStride { + ListRange { start: start_lhs, stop: stop_lhs, stride: stride_lhs, }, - ListStride { + ListRange { start: start_rhs, stop: stop_rhs, stride: stride_rhs, @@ -94,9 +94,9 @@ impl PartialEq for GetFieldAccessExpr { && stop_lhs.eq(stop_rhs) && stride_lhs.eq(stride_rhs) } - (NamedStructField { .. }, ListIndex { .. } | ListStride { .. }) => false, - (ListIndex { .. }, NamedStructField { .. } | ListStride { .. }) => false, - (ListStride { .. }, NamedStructField { .. } | ListIndex { .. }) => false, + (NamedStructField { .. }, ListIndex { .. } | ListRange { .. }) => false, + (ListIndex { .. }, NamedStructField { .. } | ListRange { .. 
}) => false, + (ListRange { .. }, NamedStructField { .. } | ListIndex { .. }) => false, }) .unwrap_or(false) } @@ -140,7 +140,7 @@ impl GetIndexedFieldExpr { ) -> Self { Self::new( arg, - GetFieldAccessExpr::ListStride { + GetFieldAccessExpr::ListRange { start, stop, stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))) @@ -158,7 +158,7 @@ impl GetIndexedFieldExpr { ) -> Self { Self::new( arg, - GetFieldAccessExpr::ListStride { + GetFieldAccessExpr::ListRange { start, stop, stride, @@ -184,11 +184,11 @@ impl GetIndexedFieldExpr { GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex { key_dt: key.data_type(input_schema)?, }, - GetFieldAccessExpr::ListStride { + GetFieldAccessExpr::ListRange { start, stop, stride, - } => GetFieldAccessSchema::ListStride { + } => GetFieldAccessSchema::ListRange { start_dt: start.data_type(input_schema)?, stop_dt: stop.data_type(input_schema)?, stride_dt: stride.data_type(input_schema)?, @@ -263,7 +263,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { with utf8 indexes. 
Tried {dt:?} with {key:?} index"), } }, - GetFieldAccessExpr::ListStride { start, stop, stride } => { + GetFieldAccessExpr::ListRange { start, stop, stride } => { let start = start.evaluate(batch)?.into_array(batch.num_rows())?; let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; let stride = stride.evaluate(batch)?.into_array(batch.num_rows())?; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 834c29a944fd..ee5da05d1151 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -238,11 +238,11 @@ pub fn create_physical_expr( GetFieldAccess::ListIndex { key } => GetFieldAccessExpr::ListIndex { key: create_physical_expr(key, input_dfschema, execution_props)?, }, - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop, stride, - } => GetFieldAccessExpr::ListStride { + } => GetFieldAccessExpr::ListRange { start: create_physical_expr(start, input_dfschema, execution_props)?, stop: create_physical_expr(stop, input_dfschema, execution_props)?, stride: create_physical_expr( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 43d0d92307ce..c8468e1709c3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -463,7 +463,7 @@ message ListIndex { LogicalExprNode key = 1; } -message ListStride { +message ListRange { LogicalExprNode start = 1; LogicalExprNode stop = 2; LogicalExprNode stride = 3; @@ -474,7 +474,7 @@ message GetIndexedField { oneof field { NamedStructField named_struct_field = 2; ListIndex list_index = 3; - ListStride list_stride = 4; + ListRange list_range = 4; } } @@ -1771,7 +1771,7 @@ message ListIndexExpr { PhysicalExprNode key = 1; } -message ListStrideExpr { +message ListRangeExpr { PhysicalExprNode start = 1; PhysicalExprNode stop = 2; PhysicalExprNode stride = 3; @@ -1782,6 +1782,6 @@ message PhysicalGetIndexedFieldExprNode { oneof field { 
NamedStructFieldExpr named_struct_field_expr = 2; ListIndexExpr list_index_expr = 3; - ListStrideExpr list_stride_expr = 4; + ListRangeExpr list_range_expr = 4; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index ce40a7de8e11..47667fb68c43 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9025,8 +9025,8 @@ impl serde::Serialize for GetIndexedField { get_indexed_field::Field::ListIndex(v) => { struct_ser.serialize_field("listIndex", v)?; } - get_indexed_field::Field::ListStride(v) => { - struct_ser.serialize_field("listStride", v)?; + get_indexed_field::Field::ListRange(v) => { + struct_ser.serialize_field("listRange", v)?; } } } @@ -9045,8 +9045,8 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { "namedStructField", "list_index", "listIndex", - "list_stride", - "listStride", + "list_range", + "listRange", ]; #[allow(clippy::enum_variant_names)] @@ -9054,7 +9054,7 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { Expr, NamedStructField, ListIndex, - ListStride, + ListRange, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9079,7 +9079,7 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { "expr" => Ok(GeneratedField::Expr), "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), - "listStride" | "list_stride" => Ok(GeneratedField::ListStride), + "listRange" | "list_range" => Ok(GeneratedField::ListRange), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9123,11 +9123,11 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) ; } - GeneratedField::ListStride => { + GeneratedField::ListRange => { if field__.is_some() { - return 
Err(serde::de::Error::duplicate_field("listStride")); + return Err(serde::de::Error::duplicate_field("listRange")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListStride) + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) ; } } @@ -12528,7 +12528,7 @@ impl<'de> serde::Deserialize<'de> for ListIndexExpr { deserializer.deserialize_struct("datafusion.ListIndexExpr", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListStride { +impl serde::Serialize for ListRange { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12545,7 +12545,7 @@ impl serde::Serialize for ListStride { if self.stride.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ListStride", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; if let Some(v) = self.start.as_ref() { struct_ser.serialize_field("start", v)?; } @@ -12558,7 +12558,7 @@ impl serde::Serialize for ListStride { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ListStride { +impl<'de> serde::Deserialize<'de> for ListRange { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -12608,13 +12608,13 @@ impl<'de> serde::Deserialize<'de> for ListStride { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListStride; + type Value = ListRange; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListStride") + formatter.write_str("struct datafusion.ListRange") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -12643,17 +12643,17 @@ impl<'de> serde::Deserialize<'de> for ListStride { } } } - Ok(ListStride { + Ok(ListRange { start: start__, stop: stop__, stride: stride__, 
}) } } - deserializer.deserialize_struct("datafusion.ListStride", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListStrideExpr { +impl serde::Serialize for ListRangeExpr { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12670,7 +12670,7 @@ impl serde::Serialize for ListStrideExpr { if self.stride.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ListStrideExpr", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ListRangeExpr", len)?; if let Some(v) = self.start.as_ref() { struct_ser.serialize_field("start", v)?; } @@ -12683,7 +12683,7 @@ impl serde::Serialize for ListStrideExpr { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ListStrideExpr { +impl<'de> serde::Deserialize<'de> for ListRangeExpr { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -12733,13 +12733,13 @@ impl<'de> serde::Deserialize<'de> for ListStrideExpr { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListStrideExpr; + type Value = ListRangeExpr; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListStrideExpr") + formatter.write_str("struct datafusion.ListRangeExpr") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -12768,14 +12768,14 @@ impl<'de> serde::Deserialize<'de> for ListStrideExpr { } } } - Ok(ListStrideExpr { + Ok(ListRangeExpr { start: start__, stop: stop__, stride: stride__, }) } } - deserializer.deserialize_struct("datafusion.ListStrideExpr", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListRangeExpr", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ListingTableScanNode { @@ -18144,8 
+18144,8 @@ impl serde::Serialize for PhysicalGetIndexedFieldExprNode { physical_get_indexed_field_expr_node::Field::ListIndexExpr(v) => { struct_ser.serialize_field("listIndexExpr", v)?; } - physical_get_indexed_field_expr_node::Field::ListStrideExpr(v) => { - struct_ser.serialize_field("listStrideExpr", v)?; + physical_get_indexed_field_expr_node::Field::ListRangeExpr(v) => { + struct_ser.serialize_field("listRangeExpr", v)?; } } } @@ -18164,8 +18164,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { "namedStructFieldExpr", "list_index_expr", "listIndexExpr", - "list_stride_expr", - "listStrideExpr", + "list_range_expr", + "listRangeExpr", ]; #[allow(clippy::enum_variant_names)] @@ -18173,7 +18173,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { Arg, NamedStructFieldExpr, ListIndexExpr, - ListStrideExpr, + ListRangeExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18198,7 +18198,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { "arg" => Ok(GeneratedField::Arg), "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), "listIndexExpr" | "list_index_expr" => Ok(GeneratedField::ListIndexExpr), - "listStrideExpr" | "list_stride_expr" => Ok(GeneratedField::ListStrideExpr), + "listRangeExpr" | "list_range_expr" => Ok(GeneratedField::ListRangeExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18242,11 +18242,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListIndexExpr) ; } - GeneratedField::ListStrideExpr => { + GeneratedField::ListRangeExpr => { if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listStrideExpr")); + return Err(serde::de::Error::duplicate_field("listRangeExpr")); } - field__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListStrideExpr) + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListRangeExpr) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9aaf06109153..a5582cc2dc64 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -726,7 +726,7 @@ pub struct ListIndex { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ListStride { +pub struct ListRange { #[prost(message, optional, boxed, tag = "1")] pub start: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, boxed, tag = "2")] @@ -752,7 +752,7 @@ pub mod get_indexed_field { #[prost(message, tag = "3")] ListIndex(::prost::alloc::boxed::Box), #[prost(message, tag = "4")] - ListStride(::prost::alloc::boxed::Box), + ListRange(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2535,7 +2535,7 @@ pub struct ListIndexExpr { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ListStrideExpr { +pub struct ListRangeExpr { #[prost(message, optional, boxed, tag = "1")] pub start: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, boxed, tag = "2")] @@ -2561,7 +2561,7 @@ pub mod physical_get_indexed_field_expr_node { #[prost(message, tag = "3")] ListIndexExpr(::prost::alloc::boxed::Box), #[prost(message, tag = "4")] - ListStrideExpr(::prost::alloc::boxed::Box), + ListRangeExpr(::prost::alloc::boxed::Box), } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 300f36203cf1..aa4f223877f2 100644 --- 
a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1055,8 +1055,8 @@ pub fn parse_expr( )?), } } - Some(protobuf::get_indexed_field::Field::ListStride(list_stride)) => { - GetFieldAccess::ListStride { + Some(protobuf::get_indexed_field::Field::ListRange(list_stride)) => { + GetFieldAccess::ListRange { start: Box::new(parse_required_expr( list_stride.start.as_deref(), registry, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 4aa668a40afc..e1fc3f0c8525 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1033,12 +1033,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }, )) } - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop, stride, - } => protobuf::get_indexed_field::Field::ListStride(Box::new( - protobuf::ListStride { + } => protobuf::get_indexed_field::Field::ListRange(Box::new( + protobuf::ListRange { start: Some(Box::new(start.as_ref().try_into()?)), stop: Some(Box::new(stop.as_ref().try_into()?)), stride: Some(Box::new(stride.as_ref().try_into()?)), diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 4fced9da81f3..96cb98e18e31 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -398,7 +398,7 @@ pub fn parse_physical_expr( "key", input_schema, )?}, - Some(protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(list_stride_expr)) => GetFieldAccessExpr::ListStride{ + Some(protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(list_stride_expr)) => GetFieldAccessExpr::ListRange{ start: parse_required_physical_expr( list_stride_expr.start.as_deref(), registry, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 
55259002b5c8..e91fa5cffd52 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -562,7 +562,7 @@ impl TryFrom> for protobuf::PhysicalExprNode { key: Some(Box::new(key.to_owned().try_into()?)) })) ), - GetFieldAccessExpr::ListStride { start, stop, stride } => { + GetFieldAccessExpr::ListRange { start, stop, stride } => { Some( protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(Box::new(protobuf::ListStrideExpr { start: Some(Box::new(start.to_owned().try_into()?)), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index b2e4f05928eb..75025ed5afd8 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -747,7 +747,7 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { let col_stop = col("stop", &schema)?; let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( col_arg, - GetFieldAccessExpr::ListStride { + GetFieldAccessExpr::ListRange { start: col_start, stop: col_stop, stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index bb3dcf87833e..261e7b299e1e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -776,7 +776,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?); - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop, stride: last, @@ -789,7 +789,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?); let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); - GetFieldAccess::ListStride { + GetFieldAccess::ListRange { start, stop: last, stride, From 1ec7d43b1eb247e178f2f19272a8cb2c4bd6fe5d Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Mon, 29 Jan 2024 10:45:17 +0800 Subject: [PATCH 08/10] clean up the code --- 
.../proto/src/logical_plan/from_proto.rs | 8 +- .../proto/src/physical_plan/from_proto.rs | 8 +- .../proto/src/physical_plan/to_proto.rs | 2 +- datafusion/sql/src/expr/mod.rs | 82 +++++++++---------- 4 files changed, 49 insertions(+), 51 deletions(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index aa4f223877f2..eb72d1f9c3e8 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1055,20 +1055,20 @@ pub fn parse_expr( )?), } } - Some(protobuf::get_indexed_field::Field::ListRange(list_stride)) => { + Some(protobuf::get_indexed_field::Field::ListRange(list_range)) => { GetFieldAccess::ListRange { start: Box::new(parse_required_expr( - list_stride.start.as_deref(), + list_range.start.as_deref(), registry, "start", )?), stop: Box::new(parse_required_expr( - list_stride.stop.as_deref(), + list_range.stop.as_deref(), registry, "stop", )?), stride: Box::new(parse_required_expr( - list_stride.stride.as_deref(), + list_range.stride.as_deref(), registry, "stride", )?), diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 96cb98e18e31..454f74dfd132 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -398,21 +398,21 @@ pub fn parse_physical_expr( "key", input_schema, )?}, - Some(protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(list_stride_expr)) => GetFieldAccessExpr::ListRange{ + Some(protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(list_range_expr)) => GetFieldAccessExpr::ListRange{ start: parse_required_physical_expr( - list_stride_expr.start.as_deref(), + list_range_expr.start.as_deref(), registry, "start", input_schema, )?, stop: parse_required_physical_expr( - list_stride_expr.stop.as_deref(), + list_range_expr.stop.as_deref(), registry, "stop", input_schema )?, stride: 
parse_required_physical_expr( - list_stride_expr.stride.as_deref(), + list_range_expr.stride.as_deref(), registry, "stride", input_schema diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e91fa5cffd52..a67410da57f4 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -564,7 +564,7 @@ impl TryFrom> for protobuf::PhysicalExprNode { ), GetFieldAccessExpr::ListRange { start, stop, stride } => { Some( - protobuf::physical_get_indexed_field_expr_node::Field::ListStrideExpr(Box::new(protobuf::ListStrideExpr { + protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(Box::new(protobuf::ListRangeExpr { start: Some(Box::new(start.to_owned().try_into()?)), stop: Some(Box::new(stop.to_owned().try_into()?)), stride: Some(Box::new(stride.to_owned().try_into()?)), diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 261e7b299e1e..b22c458b6db6 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -753,48 +753,46 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { operator: JsonOperator::Colon, right, } => { - // the last value could represent stop or stride - let last = Box::new(self.sql_expr_to_logical_expr( - *right.clone(), - schema, - planner_context, - )?); - - match *left { - SQLExpr::JsonAccess { - left, - operator: JsonOperator::Colon, - right, - } => { - let start = Box::new(self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?); - let stop = Box::new(self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?); - GetFieldAccess::ListRange { - start, - stop, - stride: last, - } - } - _ => { - let start = Box::new(self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?); - let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); - GetFieldAccess::ListRange { - start, - stop: last, - stride, - } - } + let (start, 
stop, stride) = if let SQLExpr::JsonAccess { + left: l, + operator: JsonOperator::Colon, + right: r, + } = *left + { + let start = Box::new(self.sql_expr_to_logical_expr( + *l, + schema, + planner_context, + )?); + let stop = Box::new(self.sql_expr_to_logical_expr( + *r, + schema, + planner_context, + )?); + let stride = Box::new(self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?); + (start, stop, stride) + } else { + let start = Box::new(self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?); + let stop = Box::new(self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?); + let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); + (start, stop, stride) + }; + GetFieldAccess::ListRange { + start, + stop, + stride, } } _ => GetFieldAccess::ListIndex { From 1582b3ea04fe30d997f6a96c05c256eb7b7ffb68 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Mon, 29 Jan 2024 10:47:44 +0800 Subject: [PATCH 09/10] chore --- datafusion/sql/src/expr/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b22c458b6db6..a502bf259c10 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -786,7 +786,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?); - let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); + let stride = Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))); (start, stop, stride) }; GetFieldAccess::ListRange { From c344374baa9296ddbedc555b4539d89e2430f8d8 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Mon, 29 Jan 2024 11:08:39 +0800 Subject: [PATCH 10/10] fix ci --- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 1 + datafusion/sql/src/expr/mod.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 
75025ed5afd8..eba3db298f84 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -35,6 +35,7 @@ use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a502bf259c10..b22c458b6db6 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -786,7 +786,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?); - let stride = Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))); + let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); (start, stop, stride) }; GetFieldAccess::ListRange {