Skip to content

Commit

Permalink
Merge pull request #455 from splitgraph/create-or-replace-function
Browse files Browse the repository at this point in the history
Implement CREATE OR REPLACE FUNCTION statement path
  • Loading branch information
gruuya authored Jul 19, 2023
2 parents 1abf4b1 + 07e4142 commit 35e1468
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 177 deletions.
5 changes: 3 additions & 2 deletions src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ pub trait FunctionCatalog: Sync + Send {
&self,
database_id: DatabaseId,
function_name: &str,
or_replace: bool,
details: &CreateFunctionDetails,
) -> Result<FunctionId>;

Expand Down Expand Up @@ -635,17 +636,17 @@ impl FunctionCatalog for DefaultCatalog {
&self,
database_id: DatabaseId,
function_name: &str,
or_replace: bool,
details: &CreateFunctionDetails,
) -> Result<FunctionId> {
self.repository
.create_function(database_id, function_name, details)
.create_function(database_id, function_name, or_replace, details)
.await
.map_err(|e| match e {
RepositoryError::FKConstraintViolation(_) => {
Error::DatabaseDoesNotExist { id: database_id }
}
RepositoryError::UniqueConstraintViolation(_) => {
// TODO overwrite function defns instead?
Error::FunctionAlreadyExists {
name: function_name.to_string(),
}
Expand Down
67 changes: 54 additions & 13 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,7 @@ impl SeafowlContext for DefaultSeafowlContext {
Statement::CreateTable { .. } => state.statement_to_plan(statement).await,

Statement::CreateFunction {
or_replace,
temporary: false,
name,
params: CreateFunctionBody { as_: Some( FunctionDefinition::SingleQuotedDef(details) ), .. },
Expand All @@ -1186,6 +1187,7 @@ impl SeafowlContext for DefaultSeafowlContext {

Ok(LogicalPlan::Extension(Extension {
node: Arc::new(SeafowlExtensionNode::CreateFunction(CreateFunction {
or_replace,
name: name.to_string(),
details: function_details,
output_schema: Arc::new(DFSchema::empty())
Expand Down Expand Up @@ -1712,14 +1714,20 @@ impl SeafowlContext for DefaultSeafowlContext {
}
SeafowlExtensionNode::CreateFunction(CreateFunction {
name,
or_replace,
details,
output_schema: _,
}) => {
self.register_function(name, details)?;

// Persist the function in the metadata storage
self.function_catalog
.create_function(self.database_id, name, details)
.create_function(
self.database_id,
name,
*or_replace,
details,
)
.await?;

Ok(make_dummy_exec())
Expand Down Expand Up @@ -2489,22 +2497,55 @@ mod tests {
);
}

#[rstest]
#[case::regular_type_names("float", "float")]
#[case::legacy_type_names("f32", "f32")]
#[case::uppercase_type_names("FLOAT", "REAL")]
#[tokio::test]
async fn test_register_udf() -> Result<()> {
async fn test_register_udf(
#[case] input_type: &str,
#[case] return_type: &str,
) -> Result<()> {
let sf_context = in_memory_context().await;

// Source: https://gist.github.com/going-digital/02e46c44d89237c07bc99cd440ebfa43
sf_context.collect(sf_context.plan_query(
r#"CREATE FUNCTION sintau AS '
{
"entrypoint": "sintau",
"language": "wasm",
"input_types": ["float"],
"return_type": "float",
"data": "AGFzbQEAAAABDQJgAX0BfWADfX9/AX0DBQQAAAABBQQBAUREBxgDBnNpbnRhdQAABGV4cDIAAQRsb2cyAAIKjgEEKQECfUMAAAA/IgIgACAAjpMiACACk4siAZMgAZZBAEEYEAMgAiAAk5gLGQAgACAAjiIAk0EYQSwQA7wgAKhBF3RqvgslAQF/IAC8IgFBF3ZB/wBrsiABQQl0s0MAAIBPlUEsQcQAEAOSCyIBAX0DQCADIACUIAEqAgCSIQMgAUEEaiIBIAJrDQALIAMLC0oBAEEAC0Q/x2FC2eATQUuqKsJzsqY9QAHJQH6V0DZv+V88kPJTPSJndz6sZjE/HQCAP/clMD0D/T++F6bRPkzcNL/Tgrg//IiKNwBqBG5hbWUBHwQABnNpbnRhdQEEZXhwMgIEbG9nMgMIZXZhbHBvbHkCNwQAAwABeAECeDECBGhhbGYBAQABeAICAAF4AQJ4aQMEAAF4AQVzdGFydAIDZW5kAwZyZXN1bHQDCQEDAQAEbG9vcA=="
}';"#,
)
.await?).await?;
let create_function_stmt = r#"CREATE FUNCTION sintau AS '
{
"entrypoint": "sintau",
"language": "wasm",
"input_types": ["int"],
"return_type": "int",
"data": "AGFzbQEAAAABDQJgAX0BfWADfX9/AX0DBQQAAAABBQQBAUREBxgDBnNpbnRhdQAABGV4cDIAAQRsb2cyAAIKjgEEKQECfUMAAAA/IgIgACAAjpMiACACk4siAZMgAZZBAEEYEAMgAiAAk5gLGQAgACAAjiIAk0EYQSwQA7wgAKhBF3RqvgslAQF/IAC8IgFBF3ZB/wBrsiABQQl0s0MAAIBPlUEsQcQAEAOSCyIBAX0DQCADIACUIAEqAgCSIQMgAUEEaiIBIAJrDQALIAMLC0oBAEEAC0Q/x2FC2eATQUuqKsJzsqY9QAHJQH6V0DZv+V88kPJTPSJndz6sZjE/HQCAP/clMD0D/T++F6bRPkzcNL/Tgrg//IiKNwBqBG5hbWUBHwQABnNpbnRhdQEEZXhwMgIEbG9nMgMIZXZhbHBvbHkCNwQAAwABeAECeDECBGhhbGYBAQABeAICAAF4AQJ4aQMEAAF4AQVzdGFydAIDZW5kAwZyZXN1bHQDCQEDAQAEbG9vcA=="
}';"#;

sf_context.plan_query(create_function_stmt).await?;

// Run the same query again to make sure we raise an error if the function already exists
let err = sf_context
.plan_query(create_function_stmt)
.await
.unwrap_err();

assert_eq!(
err.to_string(),
"Error during planning: Function \"sintau\" already exists"
);

// Now replace the function using proper input/return types
let replace_function_stmt = format!(
r#"CREATE OR REPLACE FUNCTION sintau AS '
{{
"entrypoint": "sintau",
"language": "wasm",
"input_types": ["{input_type}"],
"return_type": "{return_type}",
"data": "AGFzbQEAAAABDQJgAX0BfWADfX9/AX0DBQQAAAABBQQBAUREBxgDBnNpbnRhdQAABGV4cDIAAQRsb2cyAAIKjgEEKQECfUMAAAA/IgIgACAAjpMiACACk4siAZMgAZZBAEEYEAMgAiAAk5gLGQAgACAAjiIAk0EYQSwQA7wgAKhBF3RqvgslAQF/IAC8IgFBF3ZB/wBrsiABQQl0s0MAAIBPlUEsQcQAEAOSCyIBAX0DQCADIACUIAEqAgCSIQMgAUEEaiIBIAJrDQALIAMLC0oBAEEAC0Q/x2FC2eATQUuqKsJzsqY9QAHJQH6V0DZv+V88kPJTPSJndz6sZjE/HQCAP/clMD0D/T++F6bRPkzcNL/Tgrg//IiKNwBqBG5hbWUBHwQABnNpbnRhdQEEZXhwMgIEbG9nMgMIZXZhbHBvbHkCNwQAAwABeAECeDECBGhhbGYBAQABeAICAAF4AQJ4aQMEAAF4AQVzdGFydAIDZW5kAwZyZXN1bHQDCQEDAQAEbG9vcA=="
}}';"#
);

sf_context
.plan_query(replace_function_stmt.as_str())
.await?;

let results = sf_context
.collect(
Expand Down
4 changes: 3 additions & 1 deletion src/datafusion/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ impl<'a> DFParser<'a> {

/// Parse a SQL CREATE statement
pub fn parse_create(&mut self) -> Result<Statement, ParserError> {
let or_replace = self.parser.parse_keywords(&[Keyword::OR, Keyword::REPLACE]);

if self.parser.parse_keyword(Keyword::EXTERNAL) {
self.parse_create_external_table(false)
} else if self.parser.parse_keyword(Keyword::UNBOUNDED) {
Expand All @@ -264,7 +266,7 @@ impl<'a> DFParser<'a> {
// XXX SEAFOWL: this is the change to get CREATE FUNCTION parsing working
else if self.parser.parse_keyword(Keyword::FUNCTION) {
// assume we don't have CREATE TEMPORARY FUNCTION (since we don't care about TEMPORARY)
self.parse_create_function(false, false)
self.parse_create_function(or_replace, false)
// XXX SEAFOWL: change ends here
} else {
Ok(Statement::Statement(Box::from(self.parser.parse_create()?)))
Expand Down
1 change: 1 addition & 0 deletions src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct CreateTable {
pub struct CreateFunction {
/// The function name
pub name: String,
pub or_replace: bool,
pub details: CreateFunctionDetails,
/// Dummy result schema for the plan (empty)
pub output_schema: DFSchemaRef,
Expand Down
20 changes: 17 additions & 3 deletions src/repository/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,15 +376,29 @@ impl Repository for $repo {
&self,
database_id: DatabaseId,
function_name: &str,
or_replace: bool,
details: &CreateFunctionDetails,
) -> Result<FunctionId, Error> {
let input_types = serde_json::to_string(&details.input_types).expect("Couldn't serialize input types!");

let new_function_id: i64 = sqlx::query(
let query = format!(
r#"
INSERT INTO "function" (database_id, name, entrypoint, language, input_types, return_type, data, volatility)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING (id);
"#)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8){} RETURNING (id);
"#,
if or_replace {
" ON CONFLICT (database_id, name) DO UPDATE SET entrypoint = EXCLUDED.entrypoint, \
language = EXCLUDED.language, \
input_types = EXCLUDED.input_types, \
return_type = EXCLUDED.return_type, \
data = EXCLUDED.data, \
volatility = EXCLUDED.volatility"
} else {
""
}
);

let new_function_id: i64 = sqlx::query(query.as_str())
.bind(database_id)
.bind(function_name)
.bind(details.entrypoint.clone())
Expand Down
44 changes: 44 additions & 0 deletions src/repository/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ pub trait Repository: Send + Sync + Debug {
&self,
database_id: DatabaseId,
function_name: &str,
or_replace: bool,
details: &CreateFunctionDetails,
) -> Result<FunctionId, Error>;

Expand Down Expand Up @@ -400,6 +401,7 @@ pub mod tests {
.create_function(
database_id,
"testfun",
false,
&CreateFunctionDetails {
entrypoint: "entrypoint".to_string(),
language: CreateFunctionLanguage::Wasm,
Expand Down Expand Up @@ -432,6 +434,48 @@ pub mod tests {
volatility: "Volatile".to_string(),
}];
assert_eq!(all_functions, expected_functions);

// Now try to replace the function, effectively only upserting the new function details
let new_function_id = repository
.create_function(
database_id,
"testfun",
true,
&CreateFunctionDetails {
entrypoint: "entrypoint".to_string(),
language: CreateFunctionLanguage::WasmMessagePack,
input_types: vec![
CreateFunctionDataType::VARCHAR,
CreateFunctionDataType::DOUBLE,
CreateFunctionDataType::DATE,
],
return_type: CreateFunctionDataType::BOOLEAN,
data: "replaced_data".to_string(),
volatility: CreateFunctionVolatility::Immutable,
},
)
.await
.unwrap();

assert_eq!(new_function_id, function_id);

// Load function and assert changes have been made to the original entry
let all_functions = repository
.get_all_functions_in_database(database_id)
.await
.unwrap();

let expected_functions = vec![AllDatabaseFunctionsResult {
name: "testfun".to_string(),
id: function_id,
entrypoint: "entrypoint".to_string(),
language: "WasmMessagePack".to_string(),
input_types: r#"["varchar","double","date"]"#.to_string(),
return_type: "BOOLEAN".to_string(),
data: "replaced_data".to_string(),
volatility: "Immutable".to_string(),
}];
assert_eq!(all_functions, expected_functions);
}

async fn test_rename_table(
Expand Down
Loading

0 comments on commit 35e1468

Please sign in to comment.