diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 4f262b54fb20..def9fcb4c61b 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -26,7 +26,7 @@ use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array, cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue, }; -use datafusion_common::{exec_err, internal_err, DataFusionError}; +use datafusion_common::{assert_contains, exec_err, internal_err, DataFusionError}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ @@ -205,6 +205,44 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF { } } +#[tokio::test] +async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2]))], + )?; + + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + // UDF that always returns 1 row + let buggy_udf = Arc::new(|_: &[ColumnarValue]| { + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0])))) + }); + + ctx.register_udf(create_udf( + "buggy_func", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + buggy_udf, + )); + assert_contains!( + ctx.sql("select buggy_func(a) from t") + .await? 
+ .show() + .await + .err() + .unwrap() + .to_string(), + "UDF returned a different number of rows than expected" + ); + Ok(()) +} + #[tokio::test] async fn scalar_udf_zero_params() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index ee064335c1cc..e5e8add95fbe 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -288,36 +288,40 @@ fn general_array_has_dispatch( } else { array }; - for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = if comparison_type != ComparisonType::Single { - converter.convert_columns(&[sub_arr])? - } else { - converter.convert_columns(&[element.clone()])? - }; - - let mut res = match comparison_type { - ComparisonType::All => sub_arr_values - .iter() - .dedup() - .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Any => sub_arr_values - .iter() - .dedup() - .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Single => arr_values - .iter() - .dedup() - .any(|x| x == sub_arr_values.row(row_idx)), - }; - - if comparison_type == ComparisonType::Any { - res |= res; + match (arr, sub_arr) { + (Some(arr), Some(sub_arr)) => { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? 
+ }; + + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), + }; + + if comparison_type == ComparisonType::Any { + res |= res; + } + boolean_builder.append_value(res); + } + // respect null input + (_, _) => { + boolean_builder.append_null(); } - - boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index b00b8ea553f2..a092aac159bb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Scalar, StringArray}; +use arrow::array::{ + make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, +}; use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue}; @@ -107,29 +109,55 @@ impl ScalarUDFImpl for GetFieldFunc { ); } }; + match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let map_array = as_map_array(array.as_ref())?; - let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; - let entries = arrow::compute::filter(map_array.entries(), &keys)?; - let entries_struct_array = as_struct_array(entries.as_ref())?; - Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) - } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match 
as_struct_array.column_by_name(k) { - None => exec_err!( - "get indexed field {k} not found in struct"), - Some(col) => Ok(ColumnarValue::Array(col.clone())) + (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { + let map_array = as_map_array(array.as_ref())?; + let key_scalar: Scalar<StringArray> = Scalar::new(StringArray::from(vec![k.clone()])); + let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + + // note that this array has more entries than the expected output/input size + // because the MapArray is flattened + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, + capacity); + + for entry in 0..map_array.len(){ + let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; + + let maybe_matched = + keys.slice(start, end-start). + iter().enumerate(). + find(|(_, t)| t.unwrap()); + if maybe_matched.is_none(){ + mutable.extend_nulls(1); + continue } + let (match_offset,_) = maybe_matched.unwrap(); + mutable.extend(0, start + match_offset, start + match_offset + 1); + } + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) + } + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(k) { + None => exec_err!("get indexed field {k} not found in struct"), + Some(col) => Ok(ColumnarValue::Array(col.clone())), } - (DataType::Struct(_), name) => exec_err!( - "get indexed field is only possible on struct with utf8 indexes. \ - Tried with {name:?} index"), - (dt, name) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {name:?} index"), } + (DataType::Struct(_), name) => exec_err!( + "get indexed field is only possible on struct with utf8 indexes. 
\ + Tried with {name:?} index" + ), + (dt, name) => exec_err!( + "get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {name:?} index" + ), + } } } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 3b360fc20c39..b9c6ff3cfefc 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -146,11 +146,18 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the function match self.fun { ScalarFunctionDefinition::UDF(ref fun) => { - if self.args.is_empty() { - fun.invoke_no_args(batch.num_rows()) - } else { - fun.invoke(&inputs) + let output = match self.args.is_empty() { + true => fun.invoke_no_args(batch.num_rows()), + false => fun.invoke(&inputs), + }?; + + if let ColumnarValue::Array(array) = &output { + if array.len() != batch.num_rows() { + return internal_err!("UDF returned a different number of rows than expected. 
Expected: {}, Got: {}", + batch.num_rows(), array.len()); + } } + Ok(output) } ScalarFunctionDefinition::Name(_) => { internal_err!( diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c3c5603dafc6..b33419ecd47c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5169,8 +5169,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), @@ -5183,8 +5184,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has(column1, make_array(5, 6)), @@ -5197,8 +5199,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 415fabf224d7..8ff7d119c454 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -44,6 +44,7 @@ DELETE 24 query T SELECT strings['not_found'] FROM data LIMIT 1; ---- +NULL statement ok drop table data;