Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(function): add CreateFunctionExecutor & bind_create_function #828

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 68 additions & 10 deletions src/binder/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ use pretty_xmlish::Pretty;
use serde::{Deserialize, Serialize};

use super::*;
use crate::types::DataType as RlDataType;

/// Payload of a bound `CREATE FUNCTION` statement, stored in `Node::CreateFunction`.
///
/// NOTE(review): this listing is a diff view without +/- markers. The five private
/// fields directly below (`name` through `body`, including `return_types`) appear to be
/// the removed lines and the `pub` fields after them their replacements — confirm
/// against the merged file before relying on the field list.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct CreateFunction {
name: String,
arg_types: Vec<DataType>,
return_types: DataType,
language: String,
body: String,
/// Name of the schema the function is created in.
pub schema_name: String,
/// Function name (without the schema part).
pub name: String,
/// Types of the declared arguments, in declaration order.
pub arg_types: Vec<RlDataType>,
/// Declared return type.
pub return_type: RlDataType,
/// Implementation language as written in the statement.
pub language: String,
/// Function body text; for a `RETURN` expression this is `select <expr>`.
pub body: String,
}

impl fmt::Display for CreateFunction {
Expand Down Expand Up @@ -46,11 +48,67 @@ impl CreateFunction {
impl Binder {
/// Binds a `CREATE FUNCTION` statement into a `Node::CreateFunction` e-graph node.
///
/// On success returns the id produced by `self.egraph.add`. Fails with
/// `BindError::BindFunctionError` when the name cannot be split into schema and
/// function parts, when the return type or the language is missing, or when neither
/// an `AS` definition nor a `RETURN` expression is given.
///
/// NOTE(review): this listing is a diff view without +/- markers. The
/// underscore-prefixed parameters and the `todo!()` below appear to be the removed
/// side of the hunk, replaced by the un-prefixed parameters and the body that
/// follows — confirm against the merged file.
pub(super) fn bind_create_function(
&mut self,
_name: ObjectName,
_args: Option<Vec<OperateFunctionArg>>,
_return_type: Option<DataType>,
_params: CreateFunctionBody,
name: ObjectName,
args: Option<Vec<OperateFunctionArg>>,
return_type: Option<DataType>,
params: CreateFunctionBody,
) -> Result {
todo!()
// Split the (possibly qualified) object name into schema and function parts.
let Ok((schema_name, function_name)) = split_name(&name) else {
return Err(BindError::BindFunctionError(
"failed to parse the input function name".to_string(),
));
};

let schema_name = schema_name.to_string();
let name = function_name.to_string();

// A return type is mandatory.
let Some(return_type) = return_type else {
return Err(BindError::BindFunctionError(
"`return type` must be specified".to_string(),
));
};
// NOTE(review): the `false` argument is presumably the nullability flag — confirm
// against `DataType::new`.
let return_type = RlDataType::new(DataTypeKind::from(&return_type), false);

// TODO: language check (e.g., currently only support sql)
let Some(language) = params.language.clone() else {
return Err(BindError::BindFunctionError(
"`language` must be specified".to_string(),
));
};
let language = language.to_string();

// SQL udf function supports both single quote (i.e., as 'select $1 + $2')
// and double dollar (i.e., as $$select $1 + $2$$) for as clause
let body = match &params.as_ {
Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(),
Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(),
None => {
// NOTE(review): the `is_none()` check followed by `unwrap()` below could be a
// single `let Some(..) = .. else` guard.
if params.return_.is_none() {
return Err(BindError::BindFunctionError(
"AS or RETURN must be specified".to_string(),
));
}
// Otherwise this is a RETURN expression.
// Note: this is a current workaround. We assume a RETURN-style sql udf
// does NOT involve complex syntax, so the expression is wrapped in a
// `select` and handled like an AS definition.
format!("select {}", &params.return_.unwrap().to_string())
}
};

// Convert each declared argument type; a missing argument list means no arguments.
let mut arg_types = vec![];
for arg in args.unwrap_or_default() {
arg_types.push(RlDataType::new(DataTypeKind::from(&arg.data_type), false));
}

// Record the bound statement in the e-graph and return the new node's id.
let f = self.egraph.add(Node::CreateFunction(CreateFunction {
schema_name,
name,
arg_types,
return_type,
language,
body,
}));

Ok(f)
}
}
10 changes: 10 additions & 0 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,16 @@ impl Binder {
FunctionArgExpr::QualifiedWildcard(_) => todo!("support qualified wildcard"),
}
}

// TODO: sql udf inlining
let _catalog = self.catalog();
let Ok((_schema_name, _function_name)) = split_name(&func.name) else {
return Err(BindError::BindFunctionError(format!(
"failed to parse the function name {}",
func.name
)));
};

let node = match func.name.to_string().to_lowercase().as_str() {
"count" if args.is_empty() => Node::RowCount,
"count" => Node::Count(args[0]),
Expand Down
6 changes: 5 additions & 1 deletion src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use egg::{Id, Language};
use itertools::Itertools;

use crate::array;
use crate::catalog::{RootCatalog, TableRefId, DEFAULT_SCHEMA_NAME};
use crate::catalog::{RootCatalog, RootCatalogRef, TableRefId, DEFAULT_SCHEMA_NAME};
use crate::parser::*;
use crate::planner::{Expr as Node, RecExpr, TypeError, TypeSchemaAnalysis};
use crate::types::{DataTypeKind, DataValue};
Expand Down Expand Up @@ -234,6 +234,10 @@ impl Binder {
&self.egraph[id].nodes[0]
}

/// Returns a clone of the catalog handle held by this binder.
fn catalog(&self) -> RootCatalogRef {
    Clone::clone(&self.catalog)
}

fn bind_explain(&mut self, query: Statement) -> Result {
let id = self.bind_stmt(query)?;
let id = self.egraph.add(Node::Explain(id));
Expand Down
10 changes: 5 additions & 5 deletions src/catalog/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use crate::types::DataType;

/// Catalog entry describing one registered function.
///
/// NOTE(review): this listing is a diff view without +/- markers. The five private
/// fields directly below appear to be the removed lines and the `pub` fields after
/// them their replacements — confirm against the merged file.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct FunctionCatalog {
name: String,
arg_types: Vec<DataType>,
return_type: DataType,
language: String,
body: String,
/// Function name; also used as the key under which the schema stores this entry.
pub name: String,
/// Types of the declared arguments.
pub arg_types: Vec<DataType>,
/// Declared return type.
pub return_type: DataType,
/// Implementation language.
pub language: String,
/// Function body text.
pub body: String,
}

impl FunctionCatalog {
Expand Down
25 changes: 25 additions & 0 deletions src/catalog/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use super::function::FunctionCatalog;
use super::*;

/// The root of all catalogs.
Expand Down Expand Up @@ -99,6 +100,30 @@ impl RootCatalog {
table_id: table.id(),
})
}

/// Looks up a function by schema name and function name.
///
/// Returns `None` when the schema cannot be found or when the schema has no
/// function registered under `function_name`.
pub fn get_function_by_name(
    &self,
    schema_name: &str,
    function_name: &str,
) -> Option<Arc<FunctionCatalog>> {
    self.get_schema_by_name(schema_name)
        .and_then(|found| found.get_function_by_name(function_name))
}

/// Registers a function in the schema named `schema_name`.
///
/// The remaining arguments are forwarded unchanged to the schema's own
/// `create_function`.
///
/// # Panics
///
/// Panics if `schema_name` does not resolve to a schema id, if that id has no entry
/// in `inner.schemas`, or if locking `inner` fails.
pub fn create_function(
    &self,
    schema_name: String,
    name: String,
    arg_types: Vec<DataType>,
    return_type: DataType,
    language: String,
    body: String,
) {
    // Resolve the id before taking the lock below.
    // NOTE(review): presumably `get_schema_id_by_name` locks `inner` itself, in which
    // case calling it while holding the guard would deadlock — confirm.
    let schema_idx = self
        .get_schema_id_by_name(&schema_name)
        .expect("schema must exist before a function can be created in it");
    let mut inner = self.inner.lock().unwrap();
    let schema = inner
        .schemas
        .get_mut(&schema_idx)
        .expect("a resolved schema id must have an entry in the catalog");
    schema.create_function(name, arg_types, return_type, language, body);
}
}

impl Inner {
Expand Down
20 changes: 20 additions & 0 deletions src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ impl SchemaCatalog {
/// Returns a shared handle to the function registered under `name`, if any.
pub fn get_function_by_name(&self, name: &str) -> Option<Arc<FunctionCatalog>> {
    let entry = self.functions.get(name);
    entry.map(Arc::clone)
}

/// Registers a function under `name` in this schema.
///
/// The entry is stored in `functions` keyed by `name`.
/// NOTE(review): if `functions` is a std map, an existing entry with the same name is
/// replaced silently — confirm that this is intended.
pub fn create_function(
    &mut self,
    name: String,
    arg_types: Vec<DataType>,
    return_type: DataType,
    language: String,
    body: String,
) {
    // The entry keeps its own copy of the name; the original is then moved into the
    // map key, so only one clone is needed.
    let function = Arc::new(FunctionCatalog {
        name: name.clone(),
        arg_types,
        return_type,
        language,
        body,
    });
    self.functions.insert(name, function);
}
}

#[cfg(test)]
Expand Down
34 changes: 34 additions & 0 deletions src/executor/create_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0.

use super::*;
use crate::binder::CreateFunction;
use crate::catalog::RootCatalogRef;

/// The executor of the `create function` statement.
///
/// Holds the bound statement and the catalog handle it is registered into.
pub struct CreateFunctionExecutor {
/// The bound `CREATE FUNCTION` statement to execute.
pub f: CreateFunction,
/// Catalog handle on which `create_function` is called.
pub catalog: RootCatalogRef,
}

impl CreateFunctionExecutor {
    /// Registers the bound function in the catalog.
    ///
    /// Destructures the bound [`CreateFunction`] and passes its fields to the catalog's
    /// `create_function`. The body never yields a `DataChunk`.
    #[try_stream(boxed, ok = DataChunk, error = ExecutorError)]
    pub async fn execute(self) {
        let CreateFunction {
            schema_name,
            name,
            arg_types,
            return_type,
            language,
            body,
        } = self.f;

        // Every field is already owned after destructuring, so the values are moved
        // into the catalog call rather than cloned.
        self.catalog.create_function(
            schema_name,
            name,
            arg_types,
            return_type,
            language,
            body,
        );
    }
}
8 changes: 8 additions & 0 deletions src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ use self::top_n::TopNExecutor;
use self::values::*;
use self::window::*;
use crate::array::DataChunk;
use crate::executor::create_function::CreateFunctionExecutor;
use crate::planner::{Expr, ExprAnalysis, Optimizer, RecExpr, TypeSchemaAnalysis};
use crate::storage::{Storage, TracedStorageError};
use crate::types::{ColumnIndex, ConvertError, DataType};

mod copy_from_file;
mod copy_to_file;
mod create;
mod create_function;
mod delete;
mod drop;
mod evaluator;
Expand Down Expand Up @@ -302,6 +304,12 @@ impl<S: Storage> Builder<S> {
}
.execute(),

CreateFunction(f) => CreateFunctionExecutor {
f,
catalog: self.optimizer.catalog().clone(),
}
.execute(),

Drop(plan) => DropExecutor {
plan,
storage: self.storage.clone(),
Expand Down
Loading