Prototype of what a user-defined SQL planner might look like
alamb committed Jun 28, 2024
1 parent 47db63f commit a7c9e7f
Showing 3 changed files with 113 additions and 54 deletions.
1 change: 0 additions & 1 deletion datafusion/functions-array/src/lib.rs
@@ -51,7 +51,6 @@ pub mod set_ops;
 pub mod sort;
 pub mod string;
 pub mod utils;
-
 use datafusion_common::Result;
 use datafusion_execution::FunctionRegistry;
 use datafusion_expr::ScalarUDF;
133 changes: 80 additions & 53 deletions datafusion/sql/src/expr/mod.rs
@@ -19,6 +19,7 @@ use arrow_schema::DataType;
 use arrow_schema::TimeUnit;
 use datafusion_common::utils::list_ndims;
 use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value};
+use std::sync::Arc;
 
 use datafusion_common::{
     internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result,
@@ -28,10 +29,10 @@ use datafusion_expr::expr::InList;
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::{
     lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable,
-    GetFieldAccess, Like, Literal, Operator, TryCast,
+    GetFieldAccess, Like, Literal, Operator, ScalarUDF, TryCast,
 };
 
-use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
+use crate::planner::{ContextProvider, PlannerContext, SqlToRel, UserDefinedPlanner};
 
 mod binary_op;
 mod function;
@@ -52,7 +53,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     ) -> Result<Expr> {
         enum StackEntry {
             SQLExpr(Box<SQLExpr>),
-            Operator(Operator),
+            Operator(sqlparser::ast::BinaryOperator),
         }
 
         // Virtual stack machine to convert SQLExpr to Expr
@@ -69,7 +70,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 SQLExpr::BinaryOp { left, op, right } => {
                     // Note the order that we push the entries to the stack
                     // is important. We want to visit the left node first.
-                    let op = self.parse_sql_binary_op(op)?;
                     stack.push(StackEntry::Operator(op));
                     stack.push(StackEntry::SQLExpr(right));
                     stack.push(StackEntry::SQLExpr(left));
@@ -100,63 +100,25 @@
 
     fn build_logical_expr(
         &self,
-        op: Operator,
+        op: sqlparser::ast::BinaryOperator,
         left: Expr,
         right: Expr,
         schema: &DFSchema,
     ) -> Result<Expr> {
-        // Rewrite string concat operator to function based on types
-        // if we get list || list then we rewrite it to array_concat()
-        // if we get list || non-list then we rewrite it to array_append()
-        // if we get non-list || list then we rewrite it to array_prepend()
-        // if we get string || string then we rewrite it to concat()
-        if op == Operator::StringConcat {
-            let left_type = left.get_type(schema)?;
-            let right_type = right.get_type(schema)?;
-            let left_list_ndims = list_ndims(&left_type);
-            let right_list_ndims = list_ndims(&right_type);
-
-            // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient.
-            // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite.
-            if left_list_ndims + right_list_ndims == 0 {
-                // TODO: concat function ignore null, but string concat takes null into consideration
-                // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator`
-            } else if left_list_ndims == right_list_ndims {
-                if let Some(udf) = self.context_provider.get_function_meta("array_concat")
-                {
-                    return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
-                        udf,
-                        vec![left, right],
-                    )));
-                } else {
-                    return internal_err!("array_concat not found");
-                }
-            } else if left_list_ndims > right_list_ndims {
-                if let Some(udf) = self.context_provider.get_function_meta("array_append")
-                {
-                    return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
-                        udf,
-                        vec![left, right],
-                    )));
-                } else {
-                    return internal_err!("array_append not found");
-                }
-            } else if left_list_ndims < right_list_ndims {
-                if let Some(udf) =
-                    self.context_provider.get_function_meta("array_prepend")
-                {
-                    return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
-                        udf,
-                        vec![left, right],
-                    )));
-                } else {
-                    return internal_err!("array_prepend not found");
-                }
-            }
-        }
+        // try extension planners
+        for planner in self.planners.iter() {
+            if let Some(expr) =
+                planner.plan_binary_op(op.clone(), left.clone(), right.clone(), schema)?
+            {
+                return Ok(expr);
+            }
+        }
 
+        // by default, convert to a DataFusion operator
         Ok(Expr::BinaryExpr(BinaryExpr::new(
             Box::new(left),
-            op,
+            self.parse_sql_binary_op(op)?,
             Box::new(right),
         )))
     }
@@ -1017,6 +979,71 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     }
 }
 
+pub struct ArrayFunctionPlanner {
+    array_concat: Arc<ScalarUDF>,
+    array_append: Arc<ScalarUDF>,
+    array_prepend: Arc<ScalarUDF>,
+}
+
+impl ArrayFunctionPlanner {
+    pub fn try_new(context_provider: &dyn ContextProvider) -> Result<Self> {
+        let Some(array_concat) = context_provider.get_function_meta("array_concat")
+        else {
+            return internal_err!("array_concat not found");
+        };
+        let Some(array_append) = context_provider.get_function_meta("array_append")
+        else {
+            return internal_err!("array_append not found");
+        };
+        let Some(array_prepend) = context_provider.get_function_meta("array_prepend")
+        else {
+            return internal_err!("array_prepend not found");
+        };
+
+        Ok(Self {
+            array_concat,
+            array_append,
+            array_prepend,
+        })
+    }
+}
+
+impl UserDefinedPlanner for ArrayFunctionPlanner {
+    fn plan_binary_op(
+        &self,
+        op: sqlparser::ast::BinaryOperator,
+        left: Expr,
+        right: Expr,
+        schema: &DFSchema,
+    ) -> Result<Option<Expr>> {
+        // Rewrite the string concat operator to a function call based on the operand types:
+        //   list || list         is rewritten to array_concat()
+        //   list || non-list     is rewritten to array_append()
+        //   non-list || list     is rewritten to array_prepend()
+        //   string || string     is rewritten to concat()
+        if op == sqlparser::ast::BinaryOperator::StringConcat {
+            let left_type = left.get_type(schema)?;
+            let right_type = right.get_type(schema)?;
+            let left_list_ndims = list_ndims(&left_type);
+            let right_list_ndims = list_ndims(&right_type);
+
+            // We choose the target function based on the list n-dimensions. The check is not exact but sufficient:
+            // the exact validity check is handled in the function itself, so a rewrite such as appending a 1-d list to a 3-d list is also fine.
+            if left_list_ndims + right_list_ndims == 0 {
+                // TODO: the concat function ignores nulls, but the string concat operator takes them into consideration.
+                // We can rewrite this case to concat once concat's null handling can be configured to match the `string concat operator`.
+            } else if left_list_ndims == right_list_ndims {
+                return Ok(Some(self.array_concat.call(vec![left, right])));
+            } else if left_list_ndims > right_list_ndims {
+                return Ok(Some(self.array_append.call(vec![left, right])));
+            } else if left_list_ndims < right_list_ndims {
+                return Ok(Some(self.array_prepend.call(vec![left, right])));
+            }
+        }
+
+        Ok(None)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
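
To make the extension point concrete, here is a minimal sketch of a downstream planner written against the UserDefinedPlanner trait this commit introduces. It is not part of the commit: ContainsPlanner, the `@>` (AtArrow) rewrite, and the array_has_all UDF handle are all hypothetical, with the handle assumed to be resolved from the ContextProvider the same way ArrayFunctionPlanner::try_new resolves its functions.

use std::sync::Arc;

use datafusion_common::{DFSchema, Result};
use datafusion_expr::{Expr, ScalarUDF};
use datafusion_sql::planner::UserDefinedPlanner;
use sqlparser::ast::BinaryOperator;

/// Hypothetical planner that rewrites `a @> b` to `array_has_all(a, b)`
struct ContainsPlanner {
    array_has_all: Arc<ScalarUDF>,
}

impl UserDefinedPlanner for ContainsPlanner {
    fn plan_binary_op(
        &self,
        op: BinaryOperator,
        left: Expr,
        right: Expr,
        _schema: &DFSchema,
    ) -> Result<Option<Expr>> {
        if op == BinaryOperator::AtArrow {
            // Returning Some(..) claims the operator, so build_logical_expr
            // stops consulting further planners
            Ok(Some(self.array_has_all.call(vec![left, right])))
        } else {
            // Returning None defers to the next planner, or to the default
            // binary operator conversion
            Ok(None)
        }
    }
}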
33 changes: 33 additions & 0 deletions datafusion/sql/src/planner.rs
@@ -31,6 +31,7 @@ use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
 use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
 use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias};
 
+use crate::expr::ArrayFunctionPlanner;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::TableReference;
 use datafusion_common::{
@@ -236,11 +237,28 @@ impl PlannerContext {
     }
 }
 
+/// This trait allows users to customize the behavior of the SQL planner
+pub trait UserDefinedPlanner {
+    /// Plan the binary operation between two expressions; returns None if this planner cannot plan it
+    /// TODO make an API that avoids the need to clone the expressions
+    fn plan_binary_op(
+        &self,
+        _op: sqlparser::ast::BinaryOperator,
+        _left: Expr,
+        _right: Expr,
+        _schema: &DFSchema,
+    ) -> Result<Option<Expr>> {
+        Ok(None)
+    }
+}
+
 /// SQL query planner
 pub struct SqlToRel<'a, S: ContextProvider> {
     pub(crate) context_provider: &'a S,
     pub(crate) options: ParserOptions,
     pub(crate) normalizer: IdentNormalizer,
+    /// user defined planner extensions
+    pub(crate) planners: Vec<Arc<dyn UserDefinedPlanner>>,
 }
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -249,14 +267,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         Self::new_with_options(context_provider, ParserOptions::default())
     }
 
+    /// Add a user defined planner
+    pub fn with_user_defined_planner(
+        mut self,
+        planner: Arc<dyn UserDefinedPlanner>,
+    ) -> Self {
+        self.planners.push(planner);
+        self
+    }
+
     /// Create a new query planner
     pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self {
         let normalize = options.enable_ident_normalization;
+        let array_planner =
+            Arc::new(ArrayFunctionPlanner::try_new(context_provider).unwrap()) as _;
+
         SqlToRel {
             context_provider,
             options,
             normalizer: IdentNormalizer::new(normalize),
+            planners: vec![],
         }
+        // TODO: put this somewhere else
+        .with_user_defined_planner(array_planner)
     }
 
     pub fn build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema> {
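
A usage sketch for wiring an extra planner into SqlToRel. This is hypothetical: it assumes the ContextProvider registers array_concat, array_append, and array_prepend (since in this prototype new_with_options unwraps ArrayFunctionPlanner::try_new), and reuses the ContainsPlanner sketched above. Extra planners are consulted by build_logical_expr in registration order, after the built-in ArrayFunctionPlanner.

use std::sync::Arc;

use datafusion_sql::planner::{ContextProvider, SqlToRel, UserDefinedPlanner};

/// Hypothetical helper: build a SqlToRel with one extra planner installed
fn planner_with_extension<S: ContextProvider>(
    provider: &S,
    extension: Arc<dyn UserDefinedPlanner>,
) -> SqlToRel<'_, S> {
    // SqlToRel::new installs ArrayFunctionPlanner by default; the extra
    // planner is appended and tried after it
    SqlToRel::new(provider).with_user_defined_planner(extension)
}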
