Skip to content

Commit

Permalink
Fix array function signatures and behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Feb 19, 2024
1 parent e9346fe commit f0f9f7a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 16 deletions.
8 changes: 4 additions & 4 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ impl BuiltinScalarFunction {
Signature::any(2, self.volatility())
}
BuiltinScalarFunction::ArrayHas => {
Signature::array_and_element(false, self.volatility())
Signature::array_and_element(true, self.volatility())
}
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
Expand All @@ -977,18 +977,18 @@ impl BuiltinScalarFunction {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayPositions => {
Signature::array_and_element(false, self.volatility())
Signature::array_and_element(true, self.volatility())
}
BuiltinScalarFunction::ArrayPrepend => {
Signature::element_and_array(false, self.volatility())
}
BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRemove => {
Signature::array_and_element(false, self.volatility())
Signature::array_and_element(true, self.volatility())
}
BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayRemoveAll => {
Signature::array_and_element(false, self.volatility())
Signature::array_and_element(true, self.volatility())
}
BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()),
Expand Down
25 changes: 15 additions & 10 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,12 @@ impl ArrayFunctionSignature {
};

// We follow Postgres on `array_append(Null, T)`, which is not valid.
if array_type.eq(&DataType::Null) && !allow_null_coercion {
return Ok(vec![vec![]]);
if array_type.eq(&DataType::Null) {
if allow_null_coercion {
return Ok(vec![vec![array_type.clone(), elem_type.clone()]]);
} else {
return Ok(vec![vec![]]);
}
}

// We need to find the coerced base type, mainly for cases like:
Expand All @@ -189,20 +193,21 @@ impl ArrayFunctionSignature {
)
})?;

let array_type = datafusion_common::utils::coerced_type_with_base_type_only(
array_type,
&new_base_type,
);
let new_array_type =
datafusion_common::utils::coerced_type_with_base_type_only(
array_type,
&new_base_type,
);

match array_type {
match new_array_type {
DataType::List(ref field)
| DataType::LargeList(ref field)
| DataType::FixedSizeList(ref field, _) => {
let elem_type = field.data_type();
let new_elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.clone()]])
Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]])
}
}
_ => Ok(vec![vec![]]),
Expand Down
13 changes: 12 additions & 1 deletion datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ fn compare_element_to_list(
row_index: usize,
eq: bool,
) -> Result<BooleanArray> {
if list_array_row.data_type() != element_array.data_type() {
if list_array_row.data_type() != element_array.data_type()
&& !element_array.data_type().is_null()
{
return exec_err!(
"compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
list_array_row.data_type(),
Expand Down Expand Up @@ -1481,6 +1483,10 @@ pub fn array_positions(args: &[ArrayRef]) -> Result<ArrayRef> {
check_datatypes("array_positions", &[arr.values(), element])?;
general_positions::<i64>(arr, element)
}
DataType::Null => Ok(new_null_array(
&DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
1,
)),
array_type => {
exec_err!("array_positions does not support type '{array_type:?}'.")
}
Expand Down Expand Up @@ -1613,6 +1619,10 @@ fn array_remove_internal(
element_array: &ArrayRef,
arr_n: Vec<i64>,
) -> Result<ArrayRef> {
if array.data_type().is_null() {
return Ok(array.clone());
}

match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
Expand Down Expand Up @@ -2288,6 +2298,7 @@ pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Single)
}
DataType::Null => Ok(new_null_array(&DataType::Boolean, 1)),
_ => exec_err!("array_has does not support type '{array_type:?}'."),
}
}
Expand Down
19 changes: 18 additions & 1 deletion datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2745,12 +2745,17 @@ NULL 1 NULL

## array_positions (aliases: `list_positions`)

# array_positions with NULL element (follow PostgreSQL)
query ?
select array_positions([1, 2, 3, 4, 5], null);
----
[]

# array_positions with NULL (follow PostgreSQL)
query ?
select array_positions(null, 1);
----
NULL

# array_positions scalar function #1
query ???
select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1);
Expand Down Expand Up @@ -3874,6 +3879,13 @@ select
----
[1, , 3] [, 2.2, 3.3] [, bc]

# follow PostgreSQL behavior
query ?
select
array_remove(NULL, 1)
----
NULL

query ??
select
array_remove(make_array(1, null, 2), null),
Expand Down Expand Up @@ -4034,6 +4046,11 @@ select array_remove_n(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12],
## array_remove_all (aliases: `list_removes`)

# array_remove_all with NULL elements
query ?
select array_remove_all(NULL, 1);
----
NULL

query ?
select array_remove_all(make_array(1, 2, 2, 1, 1), NULL);
----
Expand Down

0 comments on commit f0f9f7a

Please sign in to comment.