Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(binder): support sql udf with unnamed parameters #835

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,19 @@ impl Binder {
/// Bind an expression.
pub fn bind_expr(&mut self, expr: Expr) -> Result {
let id = match expr {
Expr::Value(v) => Ok(self.egraph.add(Node::Constant(v.into()))),
Expr::Value(v) => {
// This is okay since only sql udf relies on
// parameter-like (i.e., `$1`) values at present
// TODO: consider formally `bind_parameter` in the future
// e.g., lambda function support, etc.
if let Value::Placeholder(key) = v {
self.udf_context
.get_expr(&key)
.map_or_else(|| Err(BindError::InvalidSQL), |&e| Ok(e))
} else {
Ok(self.egraph.add(Node::Constant(v.into())))
}
}
Expr::Identifier(ident) => self.bind_ident([ident]),
Expr::CompoundIdentifier(idents) => self.bind_ident(idents),
Expr::BinaryOp { left, op, right } => self.bind_binary_op(*left, op, *right),
Expand Down
2 changes: 1 addition & 1 deletion src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl UdfContext {
return Err(BindError::InvalidExpression("invalid syntax".to_string()));
};
if catalog.arg_names[i].is_empty() {
todo!("anonymous parameters not yet supported");
ret.insert(format!("${}", i + 1), e.clone());
} else {
// The index mapping here is accurate
// So that we could directly use the index
Expand Down
233 changes: 214 additions & 19 deletions tests/sql/sql_udf.slt
Original file line number Diff line number Diff line change
@@ -1,42 +1,237 @@
# Create a named sql udf with single quote definition
statement ok
create function add_named(a INT, b INT) returns int language sql as 'select a + b';
#############################################################
# Basic tests for sql udf with [unnamed / named] parameters #
#############################################################

# Create another named sql udf with double dollar definition
# Create a sql udf function with unnamed parameters with double dollar as clause
statement ok
create function sub_named(a INT, b INT) returns int language sql as $$select a - b$$;
create function add(INT, INT) returns int language sql as $$select $1 + $2$$;

# Named sql udf with corner case
query I
select add(1, -1);
----
0

# Create a sql udf function with unnamed parameters with single quote as clause
statement ok
create function corner_case(a INT) returns varchar language sql as $$select '$1 + a + $3'$$;
create function sub(INT, INT) returns int language sql as 'select $1 - $2';

# Call sql udf inside named sql udf
query I
select sub(1, 1);
----
0

# Create a sql udf function with unamed parameters that calls other pre-defined sql udfs
statement ok
create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add_named(a, b)';
create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)';

# TODO: add checks
# Named sql udf with invalid parameter in body definition
# Will be rejected at creation time
# statement error failed to find named parameter aa
# create function unknown_parameter(a INT) returns int language sql as 'select a + aa + a';
query I
select add_sub_binding();
----
2

# Use them all together
query III
select add(1, -1), sub(1, 1), add_sub_binding();
----
0 0 2

# Create a sql udf with named parameters with single quote as clause
statement ok
create function add_named(a INT, b INT) returns int language sql as 'select a + b';

# Call the defined named sql udfs
query I
select add_named(1, -1);
----
0

# Create another sql udf with named parameters with double dollar as clause
statement ok
create function sub_named(a INT, b INT) returns int language sql as $$select a - b$$;

query I
select sub_named(1, 1);
----
0

query T
select corner_case(1);
# Mixed with named / unnamed parameters
statement ok
create function add_sub_mix(INT, a INT, INT) returns int language sql as 'select $1 - a + $3';

query I
select add_sub_mix(1, 2, 3);
----
$1 + a + $3
2

# Call sql udf with unnamed parameters inside sql udf with named parameters
statement ok
create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add(a, b)';

query I
select add_named_wrapper(1, -1);
----
0
0

# Create a sql udf with unnamed parameters with return expression
statement ok
create function add_return(INT, INT) returns int language sql return $1 + $2;

query I
select add_return(1, 1);
----
2

statement ok
create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1);

query I
select add_return_binding();
----
4

statement ok
create function print(INT) returns int language sql as 'select $1';

query T
select print(114514);
----
114514

# Multiple type interleaving sql udf
statement ok
create function add_sub(INT, FLOAT, INT) returns float language sql as $$select -$1 + $2 - $3$$;

query I
select add_sub(1, 5.1415926, 1);
----
3.1415926

query III
select add(1, -1), sub(1, 1), add_sub(1, 5.1415926, 1);
----
0 0 3.1415926

# TODO: `Real` is not supported yet
# Complex types interleaving
# statement ok
# create function add_sub_types(INT, BIGINT, FLOAT, DECIMAL, REAL) returns double language sql as 'select $1 + $2 - $3 + $4 + $5';

# query I
# select add_sub_types(1, 1919810114514, 3.1415926, 1.123123, 101010.191919);
# ----
# 1919810215523.1734494

statement ok
create function add_sub_return(INT, FLOAT, INT) returns float language sql return -$1 + $2 - $3;

query I
select add_sub_return(1, 5.1415926, 1);
----
3.1415926

# Create a wrapper function for `add` & `sub`
statement ok
create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512';

query I
select add_sub_wrapper(1, 1);
----
114514

##########################################################
# Basic sql udfs integrated with the use of mock tables #
# P.S. This is also a simulation of real world use cases #
##########################################################

statement ok
create table t1 (c1 INT, c2 INT);

statement ok
create table t2 (c1 INT, c2 FLOAT, c3 INT);

# Special table for named sql udf
statement ok
create table t3 (a INT, b INT);

statement ok
insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

statement ok
insert into t2 values (1, 3.14, 2), (2, 4.44, 5), (20, 10.30, 02);

statement ok
insert into t3 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

query I
select c1, c2, add_return(c1, c2) from t1 order by c1 asc;
----
1 1 2
2 2 4
3 3 6
4 4 8
5 5 10

query III
select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc;
----
0 1 1 2
0 2 2 4
0 3 3 6
0 4 4 8
0 5 5 10

query IIIIII
select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub(c1, c2, c3) from t2 order by c1 asc;
----
1 3.14 2 3 -1 0.14000000000000012
2 4.44 5 7 -3 -2.5599999999999996
20 10.3 2 22 18 -11.7

query IIIIII
select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub_return(c1, c2, c3) from t2 order by c1 asc;
----
1 3.14 2 3 -1 0.14000000000000012
2 4.44 5 7 -3 -2.5599999999999996
20 10.3 2 22 18 -11.7

query I
select add_named(a, b) from t3 order by a asc;
----
2
4
6
8
10

################################
# Corner & Special cases tests #
################################

# Mixed parameter with calling inner sql udfs
statement ok
create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';

query I
select add_sub_mix_wrapper(1, 2, 3);
----
4

# Named sql udf with corner case
statement ok
create function corner_case(INT, a INT, INT) returns varchar language sql as $$select '$1 + a + $3'$$;

query T
select corner_case(1, 2, 3);
----
$1 + a + $3

# Adjust the input value of the calling function (i.e., `print` here) with the actual input parameter
statement ok
create function print_add_one(INT) returns int language sql as 'select print($1 + 1)';

statement ok
create function print_add_two(INT) returns int language sql as 'select print($1 + $1)';

query III
select print_add_one(1), print_add_one(114513), print_add_two(2);
----
2 114514 4
Loading