Skip to content

Commit

Permalink
add allow null type coercion parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Feb 2, 2024
1 parent 1ce1ffd commit 2556f04
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 36 deletions.
44 changes: 29 additions & 15 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,51 +924,63 @@ impl BuiltinScalarFunction {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayAppend => {
Signature::array_and_element(self.volatility())
Signature::array_and_element(false, self.volatility())
}
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments of arbitrary type
Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility())
}
BuiltinScalarFunction::ArrayPopFront => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPopFront => {
Signature::array(false, self.volatility())
}
BuiltinScalarFunction::ArrayPopBack => {
Signature::array(false, self.volatility())
}
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayDims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayDims => {
Signature::array(false, self.volatility())
}
BuiltinScalarFunction::ArrayEmpty => {
Signature::array(false, self.volatility())
}
BuiltinScalarFunction::ArrayElement => {
Signature::array_and_index(self.volatility())
Signature::array_and_index(false, self.volatility())
}
BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::array(self.volatility()),
BuiltinScalarFunction::Flatten => Signature::array(false, self.volatility()),
BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny => {
Signature::any(2, self.volatility())
}
BuiltinScalarFunction::ArrayHas => {
Signature::array_and_element(self.volatility())
Signature::array_and_element(false, self.volatility())
}
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayNdims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayNdims => {
Signature::array(false, self.volatility())
}
BuiltinScalarFunction::ArrayDistinct => {
Signature::array(true, self.volatility())
}
BuiltinScalarFunction::ArrayPosition => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayPositions => {
Signature::array_and_element(self.volatility())
Signature::array_and_element(false, self.volatility())
}
BuiltinScalarFunction::ArrayPrepend => {
Signature::element_and_array(self.volatility())
Signature::element_and_array(false, self.volatility())
}
BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRemove => {
Signature::array_and_element(self.volatility())
Signature::array_and_element(false, self.volatility())
}
BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayRemoveAll => {
Signature::array_and_element(self.volatility())
Signature::array_and_element(false, self.volatility())
}
BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()),
Expand All @@ -985,7 +997,9 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::array(self.volatility()),
BuiltinScalarFunction::Cardinality => {
Signature::array(false, self.volatility())
}
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}
Expand Down
53 changes: 39 additions & 14 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ pub enum TypeSignature {
/// is `OneOf(vec![Any(0), VariadicAny])`.
OneOf(Vec<TypeSignature>),
/// Specifies Signatures for array functions
ArraySignature(ArrayFunctionSignature),
/// The `bool` specifies whether null type coercion is allowed
ArraySignature(ArrayFunctionSignature, bool),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand All @@ -144,13 +145,19 @@ pub enum ArrayFunctionSignature {
}

impl ArrayFunctionSignature {
/// Arguments to ArrayFunctionSignature
/// `current_types` - The data types of the arguments
/// `allow_null_coercion` - Whether null type coercion is allowed
/// Returns the valid types for the function signature
pub fn get_type_signature(
&self,
current_types: &[DataType],
allow_null_coercion: bool,
) -> Result<Vec<Vec<DataType>>> {
fn array_append_or_prepend_valid_types(
current_types: &[DataType],
is_append: bool,
allow_null_coercion: bool,
) -> Result<Vec<Vec<DataType>>> {
if current_types.len() != 2 {
return Ok(vec![vec![]]);
Expand All @@ -163,7 +170,7 @@ impl ArrayFunctionSignature {
};

// We follow Postgres, where `array_append(Null, T)` is not valid,
// unless null coercion is explicitly allowed for this signature.
if array_type.eq(&DataType::Null) {
if array_type.eq(&DataType::Null) && !allow_null_coercion {
return Ok(vec![vec![]]);
}

Expand Down Expand Up @@ -215,8 +222,13 @@ impl ArrayFunctionSignature {
_ => Ok(vec![vec![]]),
}
}
fn array(current_types: &[DataType]) -> Result<Vec<Vec<DataType>>> {
if current_types.len() != 1 {
fn array(
current_types: &[DataType],
allow_null_coercion: bool,
) -> Result<Vec<Vec<DataType>>> {
if current_types.len() != 1
|| (current_types[0].is_null() && !allow_null_coercion)
{
return Ok(vec![vec![]]);
}

Expand All @@ -229,7 +241,6 @@ impl ArrayFunctionSignature {
let array_type = coerced_fixed_size_list_to_list(array_type);
Ok(vec![vec![array_type]])
}
DataType::Null => Ok(vec![vec![array_type.to_owned()]]),
_ => Ok(vec![vec![DataType::List(Arc::new(Field::new(
"item",
array_type.to_owned(),
Expand All @@ -239,13 +250,21 @@ impl ArrayFunctionSignature {
}
match self {
ArrayFunctionSignature::ArrayAndElement => {
array_append_or_prepend_valid_types(current_types, true)
array_append_or_prepend_valid_types(
current_types,
true,
allow_null_coercion,
)
}
ArrayFunctionSignature::ElementAndArray => {
array_append_or_prepend_valid_types(current_types, false)
array_append_or_prepend_valid_types(
current_types,
false,
allow_null_coercion,
)
}
ArrayFunctionSignature::ArrayAndIndex => array_and_index(current_types),
ArrayFunctionSignature::Array => array(current_types),
ArrayFunctionSignature::Array => array(current_types, allow_null_coercion),
}
}
}
Expand Down Expand Up @@ -297,7 +316,7 @@ impl TypeSignature {
TypeSignature::OneOf(sigs) => {
sigs.iter().flat_map(|s| s.to_string_repr()).collect()
}
TypeSignature::ArraySignature(array_signature) => {
TypeSignature::ArraySignature(array_signature, _) => {
vec![array_signature.to_string()]
}
}
Expand Down Expand Up @@ -402,36 +421,42 @@ impl Signature {
}
}
/// Specialized Signature for ArrayAppend and similar functions
pub fn array_and_element(volatility: Volatility) -> Self {
pub fn array_and_element(allow_null_coercion: bool, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndElement,
allow_null_coercion,
),
volatility,
}
}
/// Specialized Signature for ArrayPrepend and similar functions
pub fn element_and_array(volatility: Volatility) -> Self {
pub fn element_and_array(allow_null_coercion: bool, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::ElementAndArray,
allow_null_coercion,
),
volatility,
}
}
/// Specialized Signature for ArrayElement and similar functions
pub fn array_and_index(volatility: Volatility) -> Self {
pub fn array_and_index(allow_null_coercion: bool, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndex,
allow_null_coercion,
),
volatility,
}
}
/// Specialized Signature for ArrayEmpty and similar functions
pub fn array(volatility: Volatility) -> Self {
pub fn array(allow_null_coercion: bool, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array,
allow_null_coercion,
),
volatility,
}
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| allow_null_coercion(lhs_type, rhs_type))
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
.or_else(|| string_temporal_coercion(lhs_type, rhs_type))
.or_else(|| binary_coercion(lhs_type, rhs_type))
Expand Down Expand Up @@ -756,7 +756,7 @@ pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTyp
string_coercion(lhs_type, rhs_type)
.or_else(|| binary_to_string_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, false))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| allow_null_coercion(lhs_type, rhs_type))
}

/// coercion rules for regular expression comparison operations.
Expand Down Expand Up @@ -844,7 +844,7 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTyp

/// Coercion rules from the NULL type. Since NULL can be cast to any other type in arrow,
/// if either lhs or rhs is NULL and NULL can be cast to the type of the other side, the coercion is valid.
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
fn allow_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
match (lhs_type, rhs_type) {
(DataType::Null, other_type) | (other_type, DataType::Null) => {
if can_cast_types(&DataType::Null, other_type) {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ fn get_valid_types(
}

TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArraySignature(ref function_signature) => {
function_signature.get_type_signature(current_types)?
TypeSignature::ArraySignature(ref function_signature, allow_null_coercion) => {
function_signature.get_type_signature(current_types, *allow_null_coercion)?
}

TypeSignature::Any(number) => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2659,7 +2659,7 @@ pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
}

// handle for list & largelist
match args[0].data_type() {
match dbg!(args[0].data_type()) {
DataType::List(field) => {
let array = as_list_array(&args[0])?;
general_array_distinct(array, field)
Expand Down
5 changes: 4 additions & 1 deletion datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4253,7 +4253,10 @@ NULL [3] [4]
# array_ndims scalar function #1

query error
selrct array_ndims(1), array_ndims(null)
select array_ndims(1)

query error
select array_ndims(null)

query I
select
Expand Down

0 comments on commit 2556f04

Please sign in to comment.