Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into alamb/stack_overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Oct 8, 2024
2 parents ea97277 + e00af2c commit 7342d4f
Show file tree
Hide file tree
Showing 69 changed files with 3,446 additions and 1,586 deletions.
2 changes: 1 addition & 1 deletion datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ impl<T> Transformed<T> {
}
}

/// Create a `Transformed` with `transformed and [`TreeNodeRecursion::Continue`].
/// Create a `Transformed` wrapping `data` with the given `transformed` flag
/// and [`TreeNodeRecursion::Continue`].
pub fn new_transformed(data: T, transformed: bool) -> Self {
Self::new(data, transformed, TreeNodeRecursion::Continue)
}
Expand Down
9 changes: 5 additions & 4 deletions datafusion/core/src/bin/print_functions_docs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,14 @@ fn print_docs(
.find(|f| f.get_name() == name || f.get_aliases().contains(&name))
.unwrap();

let name = f.get_name();
let aliases = f.get_aliases();
let documentation = f.get_documentation();

// if this name is an alias we need to display what it's an alias of
if aliases.contains(&name) {
let _ = write!(docs, "_Alias of [{name}](#{name})._");
let fname = f.get_name();
let _ = writeln!(docs, r#"### `{name}`"#);
let _ = writeln!(docs, "_Alias of [{fname}](#{fname})._");
continue;
}

Expand Down Expand Up @@ -183,10 +184,10 @@ fn print_docs(

// next, aliases
if !f.get_aliases().is_empty() {
let _ = write!(docs, "#### Aliases");
let _ = writeln!(docs, "#### Aliases");

for alias in f.get_aliases() {
let _ = writeln!(docs, "- {alias}");
let _ = writeln!(docs, "- {}", alias.replace("_", r#"\_"#));
}
}

Expand Down
17 changes: 17 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,26 @@ impl DataFrame {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let plan = LogicalPlanBuilder::from(self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
let plan = if is_grouping_set {
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
// For grouping sets we do a project to not expose the internal grouping id
let exprs = plan
.schema()
.columns()
.into_iter()
.enumerate()
.filter(|(idx, _)| *idx != grouping_id_pos)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
} else {
plan
};
Ok(DataFrame {
session_state: self.session_state,
plan,
Expand Down
14 changes: 2 additions & 12 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner {
physical_input_schema.clone(),
)?);

// update group column indices based on partial aggregate plan evaluation
let final_group: Vec<Arc<dyn PhysicalExpr>> =
initial_aggr.output_group_expr();

let can_repartition = !groups.is_empty()
&& session_state.config().target_partitions() > 1
&& session_state.config().repartition_aggregations();
Expand All @@ -716,13 +712,7 @@ impl DefaultPhysicalPlanner {
AggregateMode::Final
};

let final_grouping_set = PhysicalGroupBy::new_single(
final_group
.iter()
.enumerate()
.map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
.collect(),
);
let final_grouping_set = initial_aggr.group_expr().as_final();

Arc::new(AggregateExec::try_new(
next_partition_mode,
Expand Down Expand Up @@ -2345,7 +2335,7 @@ mod tests {
.expect("hash aggregate");
assert_eq!(
"sum(aggregate_test_100.c3)",
final_hash_agg.schema().field(2).name()
final_hash_agg.schema().field(3).name()
);
// we need access to the input to the partial aggregate so that other projects can
// implement serde
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ mod simplification;
fn test_octet_length() {
#[rustfmt::skip]
evaluate_expr_test(
octet_length(col("list")),
octet_length(col("id")),
vec![
"+------+",
"| expr |",
"+------+",
"| 5 |",
"| 18 |",
"| 6 |",
"| 1 |",
"| 1 |",
"| 1 |",
"+------+",
],
);
Expand Down
18 changes: 17 additions & 1 deletion datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ pub enum TypeSignature {
/// Fixed number of arguments of numeric types.
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
Numeric(usize),
/// Fixed number of arguments of all the same string types.
/// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8.
/// Null is considered as Utf8 by default
/// Dictionary with string value type is also handled.
String(usize),
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
Expand Down Expand Up @@ -190,8 +195,11 @@ impl TypeSignature {
.collect::<Vec<String>>()
.join(", ")]
}
TypeSignature::String(num) => {
vec![format!("String({num})")]
}
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
vec![format!("Numeric({num})")]
}
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
Expand Down Expand Up @@ -280,6 +288,14 @@ impl Signature {
}
}

/// A specified number of string arguments.
///
/// Builds a signature whose type signature is
/// [`TypeSignature::String`] with `arg_count` arguments and the given
/// `volatility`.
pub fn string(arg_count: usize, volatility: Volatility) -> Self {
    let type_signature = TypeSignature::String(arg_count);
    Self {
        type_signature,
        volatility,
    }
}

/// An arbitrary number of arguments of any type.
pub fn variadic_any(volatility: Volatility) -> Self {
Self {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ fn string_concat_internal_coercion(
/// based on the observation that StringArray to StringViewArray is cheap but not vice versa.
///
/// Between Utf8 and LargeUtf8, we coerce to LargeUtf8.
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// If Utf8View is in any side, we coerce to Utf8View.
Expand Down
56 changes: 55 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use super::dml::CopyTo;
use super::DdlStatement;
Expand Down Expand Up @@ -2965,6 +2965,15 @@ impl Aggregate {
.into_iter()
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
.collect::<Vec<_>>();
qualified_fields.push((
None,
Field::new(
Self::INTERNAL_GROUPING_ID,
Self::grouping_id_type(qualified_fields.len()),
false,
)
.into(),
));
}

qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
Expand Down Expand Up @@ -3016,9 +3025,19 @@ impl Aggregate {
})
}

/// Returns `true` when the group-by list consists of exactly one
/// expression and that expression is an [`Expr::GroupingSet`].
fn is_grouping_set(&self) -> bool {
    match self.group_expr.as_slice() {
        [sole] => matches!(sole, Expr::GroupingSet(_)),
        _ => false,
    }
}

/// Get the output expressions.
fn output_expressions(&self) -> Result<Vec<&Expr>> {
static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
if self.is_grouping_set() {
exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
}));
}
exprs.extend(self.aggr_expr.iter());
debug_assert!(exprs.len() == self.schema.fields().len());
Ok(exprs)
Expand All @@ -3030,6 +3049,41 @@ impl Aggregate {
pub fn group_expr_len(&self) -> Result<usize> {
grouping_set_expr_count(&self.group_expr)
}

/// Returns the data type of the grouping id.
///
/// The grouping ID value is a bitmask where each set bit indicates that the
/// corresponding grouping expression is null. The narrowest of `UInt8`,
/// `UInt16`, `UInt32` and `UInt64` with at least `group_exprs` bits is
/// chosen; anything above 32 expressions maps to `UInt64`.
pub fn grouping_id_type(group_exprs: usize) -> DataType {
    match group_exprs {
        0..=8 => DataType::UInt8,
        9..=16 => DataType::UInt16,
        17..=32 => DataType::UInt32,
        _ => DataType::UInt64,
    }
}

/// Name of the internal column used when the aggregation is a grouping set.
///
/// This column contains a bitmask where each bit represents a grouping
/// expression. The least significant bit corresponds to the rightmost
/// grouping expression. A bit value of 0 indicates that the corresponding
/// column is included in the grouping set, while a value of 1 means it is excluded.
///
/// For example, for the grouping expressions `CUBE(a, b)`, the grouping ID
/// column will have the following values:
///
/// - `0b00`: Both `a` and `b` are included
/// - `0b01`: `b` is excluded
/// - `0b10`: `a` is excluded
/// - `0b11`: Both `a` and `b` are excluded
///
/// This internal column is necessary because excluded columns are replaced
/// with `NULL` values. To handle these cases correctly, we must distinguish
/// between an actual `NULL` value in a column and a column being excluded from the set.
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
Expand Down
67 changes: 65 additions & 2 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ use datafusion_common::{
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
use datafusion_expr_common::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
use datafusion_expr_common::{
signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
type_coercion::binary::string_coercion,
};
use std::sync::Arc;

Expand Down Expand Up @@ -176,6 +177,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
type_signature,
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::String(_)
| TypeSignature::Coercible(_)
| TypeSignature::Any(_)
)
Expand Down Expand Up @@ -381,6 +383,67 @@ fn get_valid_types(
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::String(number) => {
    // Validate the declared arity and the number of supplied arguments.
    let expected = *number;
    let received = current_types.len();
    if expected < 1 {
        return plan_err!(
            "The signature expected at least one argument but received {}",
            received
        );
    }
    if expected != received {
        return plan_err!(
            "The signature expected {} arguments but received {}",
            expected,
            received
        );
    }

    /// Finds the common string type of two argument types.
    ///
    /// `Null` is treated as `Utf8`, dictionaries are unwrapped to their
    /// value type, and everything else is delegated to `string_coercion`.
    /// Returns a plan error when `string_coercion` yields no common type.
    fn coercion_rule(lhs_type: &DataType, rhs_type: &DataType) -> Result<DataType> {
        match (lhs_type, rhs_type) {
            (DataType::Null, DataType::Null) => Ok(DataType::Utf8),
            (DataType::Null, data_type) | (data_type, DataType::Null) => {
                coercion_rule(data_type, &DataType::Utf8)
            }
            (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
                coercion_rule(lhs, rhs)
            }
            (DataType::Dictionary(_, v), other)
            | (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
            _ => match string_coercion(lhs_type, rhs_type) {
                Some(coerced_type) => Ok(coerced_type),
                None => plan_err!(
                    "{} and {} are not coercible to a common string type",
                    lhs_type,
                    rhs_type
                ),
            },
        }
    }

    /// Strips dictionary wrappers down to the value type and maps `Null`
    /// to the default `Utf8`; any other type is returned unchanged.
    fn base_type_or_default_type(data_type: &DataType) -> DataType {
        match data_type {
            DataType::Null => DataType::Utf8,
            DataType::Dictionary(_, value) => base_type_or_default_type(value),
            other => other.to_owned(),
        }
    }

    // Fold every argument type into a single common type, left to right.
    // Length checked above, so there is always a first element.
    // NOTE(review): when `number` is 1 the fold makes no `coercion_rule`
    // call, so the lone argument type is returned (after Null/Dictionary
    // unwrapping) without being checked against the string types.
    // Confirm this pass-through is intended.
    let mut remaining = current_types.iter();
    let seed = remaining.next().unwrap().to_owned();
    let coerced_type =
        remaining.try_fold(seed, |acc, next| coercion_rule(&acc, next))?;

    // Every argument is coerced to the same resulting type.
    vec![vec![base_type_or_default_type(&coerced_type); *number]]
}
TypeSignature::Numeric(number) => {
if *number < 1 {
return plan_err!(
Expand Down
12 changes: 11 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
/// Count the number of distinct exprs in a list of group by expressions. If the
/// first element is a `GroupingSet` expression then it must be the only expr.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
if group_expr.len() > 1 {
return plan_err!(
"Invalid group by expressions, GroupingSet must be the only expression"
);
}
// Grouping sets have an additional internal column for the grouping id
Ok(grouping_set.distinct_expr().len() + 1)
} else {
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
}
}

/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
Expand Down
Loading

0 comments on commit 7342d4f

Please sign in to comment.