diff --git a/src/binder/create_function.rs b/src/binder/create_function.rs index 50272ff5..c0195348 100644 --- a/src/binder/create_function.rs +++ b/src/binder/create_function.rs @@ -8,14 +8,16 @@ use pretty_xmlish::Pretty; use serde::{Deserialize, Serialize}; use super::*; +use crate::types::DataType as RlDataType; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] pub struct CreateFunction { - name: String, - arg_types: Vec, - return_types: DataType, - language: String, - body: String, + pub schema_name: String, + pub name: String, + pub arg_types: Vec, + pub return_type: RlDataType, + pub language: String, + pub body: String, } impl fmt::Display for CreateFunction { @@ -46,11 +48,67 @@ impl CreateFunction { impl Binder { pub(super) fn bind_create_function( &mut self, - _name: ObjectName, - _args: Option>, - _return_type: Option, - _params: CreateFunctionBody, + name: ObjectName, + args: Option>, + return_type: Option, + params: CreateFunctionBody, ) -> Result { - todo!() + let Ok((schema_name, function_name)) = split_name(&name) else { + return Err(BindError::BindFunctionError( + "failed to parse the input function name".to_string(), + )); + }; + + let schema_name = schema_name.to_string(); + let name = function_name.to_string(); + + let Some(return_type) = return_type else { + return Err(BindError::BindFunctionError( + "`return type` must be specified".to_string(), + )); + }; + let return_type = RlDataType::new(DataTypeKind::from(&return_type), false); + + // TODO: language check (e.g., currently only support sql) + let Some(language) = params.language.clone() else { + return Err(BindError::BindFunctionError( + "`language` must be specified".to_string(), + )); + }; + let language = language.to_string(); + + // SQL udf function supports both single quote (i.e., as 'select $1 + $2') + // and double dollar (i.e., as $$select $1 + $2$$) for as clause + let body = match ¶ms.as_ { + Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(), + Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(), + None => { + if params.return_.is_none() { + return Err(BindError::BindFunctionError( + "AS or RETURN must be specified".to_string(), + )); + } + // Otherwise this is a return expression + // Note: this is a current work around, and we are assuming return sql udf + // will NOT involve complex syntax, so just reuse the logic for select definition + format!("select {}", ¶ms.return_.unwrap().to_string()) + } + }; + + let mut arg_types = vec![]; + for arg in args.unwrap_or_default() { + arg_types.push(RlDataType::new(DataTypeKind::from(&arg.data_type), false)); + } + + let f = self.egraph.add(Node::CreateFunction(CreateFunction { + schema_name, + name, + arg_types, + return_type, + language, + body, + })); + + Ok(f) } } diff --git a/src/binder/expr.rs b/src/binder/expr.rs index ad0b3716..17a0ae46 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -297,6 +297,16 @@ impl Binder { FunctionArgExpr::QualifiedWildcard(_) => todo!("support qualified wildcard"), } } + + // TODO: sql udf inlining + let _catalog = self.catalog(); + let Ok((_schema_name, _function_name)) = split_name(&func.name) else { + return Err(BindError::BindFunctionError(format!( + "failed to parse the function name {}", + func.name + ))); + }; + let node = match func.name.to_string().to_lowercase().as_str() { "count" if args.is_empty() => Node::RowCount, "count" => Node::Count(args[0]), diff --git a/src/binder/mod.rs b/src/binder/mod.rs index d7abaabd..20f493d8 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -9,7 +9,7 @@ use egg::{Id, Language}; use itertools::Itertools; use crate::array; -use crate::catalog::{RootCatalog, TableRefId, DEFAULT_SCHEMA_NAME}; +use crate::catalog::{RootCatalog, RootCatalogRef, TableRefId, DEFAULT_SCHEMA_NAME}; use crate::parser::*; use crate::planner::{Expr as Node, RecExpr, TypeError, TypeSchemaAnalysis}; use crate::types::{DataTypeKind, DataValue}; @@ -234,6 +234,10 @@ impl Binder { &self.egraph[id].nodes[0] } + fn catalog(&self) -> RootCatalogRef { + self.catalog.clone() + } + fn bind_explain(&mut self, query: Statement) -> Result { let id = self.bind_stmt(query)?; let id = self.egraph.add(Node::Explain(id)); diff --git a/src/catalog/function.rs b/src/catalog/function.rs index cb57256e..0ab6fe80 100644 --- a/src/catalog/function.rs +++ b/src/catalog/function.rs @@ -4,11 +4,11 @@ use crate::types::DataType; #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct FunctionCatalog { - name: String, - arg_types: Vec, - return_type: DataType, - language: String, - body: String, + pub name: String, + pub arg_types: Vec, + pub return_type: DataType, + pub language: String, + pub body: String, } impl FunctionCatalog { diff --git a/src/catalog/root.rs b/src/catalog/root.rs index 7f6c23aa..7af06e48 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use super::function::FunctionCatalog; use super::*; /// The root of all catalogs. @@ -99,6 +100,30 @@ impl RootCatalog { table_id: table.id(), }) } + + pub fn get_function_by_name( + &self, + schema_name: &str, + function_name: &str, + ) -> Option> { + let schema = self.get_schema_by_name(schema_name)?; + schema.get_function_by_name(function_name) + } + + pub fn create_function( + &self, + schema_name: String, + name: String, + arg_types: Vec, + return_type: DataType, + language: String, + body: String, + ) { + let schema_idx = self.get_schema_id_by_name(&schema_name).unwrap(); + let mut inner = self.inner.lock().unwrap(); + let schema = inner.schemas.get_mut(&schema_idx).unwrap(); + schema.create_function(name, arg_types, return_type, language, body); + } } impl Inner { diff --git a/src/catalog/schema.rs b/src/catalog/schema.rs index 615aabb4..af88ccf2 100644 --- a/src/catalog/schema.rs +++ b/src/catalog/schema.rs @@ -89,6 +89,26 @@ impl SchemaCatalog { pub fn get_function_by_name(&self, name: &str) -> Option> { self.functions.get(name).cloned() } + + pub fn create_function( + &mut self, + name: String, + arg_types: Vec, + return_type: DataType, + language: String, + body: String, + ) { + self.functions.insert( + name.clone(), + Arc::new(FunctionCatalog { + name: name.clone(), + arg_types, + return_type, + language, + body, + }), + ); + } } #[cfg(test)] diff --git a/src/executor/create_function.rs b/src/executor/create_function.rs new file mode 100644 index 00000000..8e59076c --- /dev/null +++ b/src/executor/create_function.rs @@ -0,0 +1,34 @@ +// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. + +use super::*; +use crate::binder::CreateFunction; +use crate::catalog::RootCatalogRef; + +/// The executor of `create function` statement. +pub struct CreateFunctionExecutor { + pub f: CreateFunction, + pub catalog: RootCatalogRef, +} + +impl CreateFunctionExecutor { + #[try_stream(boxed, ok = DataChunk, error = ExecutorError)] + pub async fn execute(self) { + let CreateFunction { + schema_name, + name, + arg_types, + return_type, + language, + body, + } = self.f; + + self.catalog.create_function( + schema_name.clone(), + name.clone(), + arg_types, + return_type, + language, + body, + ); + } +} diff --git a/src/executor/mod.rs b/src/executor/mod.rs index e945e7f2..472c0930 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -46,6 +46,7 @@ use self::top_n::TopNExecutor; use self::values::*; use self::window::*; use crate::array::DataChunk; +use crate::executor::create_function::CreateFunctionExecutor; use crate::planner::{Expr, ExprAnalysis, Optimizer, RecExpr, TypeSchemaAnalysis}; use crate::storage::{Storage, TracedStorageError}; use crate::types::{ColumnIndex, ConvertError, DataType}; @@ -53,6 +54,7 @@ use crate::types::{ColumnIndex, ConvertError, DataType}; mod copy_from_file; mod copy_to_file; mod create; +mod create_function; mod delete; mod drop; mod evaluator; @@ -302,6 +304,12 @@ impl Builder { } .execute(), + CreateFunction(f) => CreateFunctionExecutor { + f, + catalog: self.optimizer.catalog().clone(), + } + .execute(), + Drop(plan) => DropExecutor { plan, storage: self.storage.clone(),