From addf685a9b339593e746e32e121b9fa303cdaca0 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Thu, 4 Jan 2024 09:31:28 +0800 Subject: [PATCH] fix test Signed-off-by: jayzhan211 --- datafusion/common/src/scalar.rs | 218 +++++++++++------- .../tests/cases/roundtrip_logical_plan.rs | 2 +- 2 files changed, 133 insertions(+), 87 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index ba4b77610860..6c7e42167245 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1941,7 +1941,7 @@ impl ScalarValue { .map(|a| a as &dyn Array) .collect::>(); arrow::compute::concat(arrays.as_slice()) - .map_err(DataFusionError::ArrowError)? + .map_err(|e| arrow_datafusion_err!(e))? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) @@ -3011,6 +3011,7 @@ impl fmt::Display for ScalarValue { let columns = struct_arr.columns(); let fields = struct_arr.fields(); + let nulls = struct_arr.nulls(); write!( f, @@ -3018,10 +3019,25 @@ impl fmt::Display for ScalarValue { columns .iter() .zip(fields.iter()) - .map(|(column, field)| { - let sv = ScalarValue::try_from_array(column, 0).unwrap(); - let name = field.name(); - format!("{name}:{sv}") + .enumerate() + .map(|(index, (column, field))| { + if nulls.is_some_and(|b| b.is_null(index)) { + format!("{}:NULL", field.name()) + } else { + if let DataType::Struct(_) = field.data_type() { + let sv = ScalarValue::Struct(Arc::new( + column.as_struct().to_owned(), + )); + + let name = field.name(); + format!("{name}:{sv}") + } else { + let sv = + ScalarValue::try_from_array(column, 0).unwrap(); + let name = field.name(); + format!("{name}:{sv}") + } + } }) .collect::>() .join(",") @@ -3181,13 +3197,14 @@ mod tests { use std::cmp::Ordering; use std::sync::Arc; + use arrow::util::pretty::pretty_format_columns; use arrow_buffer::Buffer; use chrono::NaiveDate; use rand::Rng; - use arrow::buffer::OffsetBuffer; use arrow::compute::kernels; use arrow::datatypes::ArrowPrimitiveType; + use arrow::{buffer::OffsetBuffer, compute::is_null}; use arrow_array::ArrowNumericType; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; @@ -5407,86 +5424,115 @@ mod tests { } } - // #[test] - // fn test_struct_nulls() { - // let fields_b = Fields::from(vec![ - // Field::new("ba", DataType::UInt64, true), - // Field::new("bb", DataType::UInt64, true), - // ]); - // let fields = Fields::from(vec![ - // Field::new("a", DataType::UInt64, true), - // Field::new("b", DataType::Struct(fields_b.clone()), true), - // ]); - // let scalars = vec![ - // ScalarValue::Struct(None, fields.clone()), - // ScalarValue::Struct( - // Some(vec![ - // ScalarValue::UInt64(None), - // ScalarValue::Struct(None, fields_b.clone()), - // ]), - // fields.clone(), - // ), - // ScalarValue::Struct( - // Some(vec![ - // ScalarValue::UInt64(None), - // ScalarValue::Struct( - // Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]), - // fields_b.clone(), - // ), - // ]), - // fields.clone(), - // ), - // ScalarValue::Struct( - // Some(vec![ - // ScalarValue::UInt64(Some(1)), - // ScalarValue::Struct( - // Some(vec![ - // ScalarValue::UInt64(Some(2)), - // ScalarValue::UInt64(Some(3)), - // ]), - // fields_b, - // ), - // ]), - // fields, - // ), - // ]; - - // let check_array = |array| { - // let is_null = is_null(&array).unwrap(); - // assert_eq!(is_null, BooleanArray::from(vec![true, false, false, false])); - - // let formatted = pretty_format_columns("col", &[array]).unwrap().to_string(); - // let formatted = formatted.split('\n').collect::>(); - // let expected = vec![ - // "+---------------------------+", - // "| col |", - // "+---------------------------+", - // "| |", - // "| {a: , b: } |", - // "| {a: , b: {ba: , bb: }} |", - // "| {a: 1, b: {ba: 2, bb: 3}} |", - // "+---------------------------+", - // ]; - // assert_eq!( - // formatted, expected, - // "Actual:\n{formatted:#?}\n\nExpected:\n{expected:#?}" - // ); - // }; - - // // test `ScalarValue::iter_to_array` - // let array = ScalarValue::iter_to_array(scalars.clone()).unwrap(); - // check_array(array); - - // // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size` - // let arrays = scalars - // .iter() - // .map(ScalarValue::to_array) - // .collect::>>() - // .expect("Failed to convert to array"); - // let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); - // let array = arrow::compute::concat(&arrays).unwrap(); - // check_array(array); - // } + #[test] + fn test_struct_nulls() { + let fields_b = Fields::from(vec![ + Field::new("ba", DataType::UInt64, true), + Field::new("bb", DataType::UInt64, true), + ]); + let fields = Fields::from(vec![ + Field::new("a", DataType::UInt64, true), + Field::new("b", DataType::Struct(fields_b.clone()), true), + ]); + + let struct_value = vec![ + ( + fields[0].clone(), + Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, + ), + ( + fields[1].clone(), + Arc::new(StructArray::from(vec![ + ( + fields_b[0].clone(), + Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, + ), + ( + fields_b[1].clone(), + Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, + ), + ])) as ArrayRef, + ), + ]; + + let struct_value_with_nulls = vec![ + ( + fields[0].clone(), + Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, + ), + ( + fields[1].clone(), + Arc::new(StructArray::from(( + vec![ + ( + fields_b[0].clone(), + Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, + ), + ( + fields_b[1].clone(), + Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, + ), + ], + Buffer::from(&[0]), + ))) as ArrayRef, + ), + ]; + + let scalars = vec![ + // all null + ScalarValue::Struct(Arc::new(StructArray::from(( + struct_value.clone(), + Buffer::from(&[0]), + )))), + // field 1 valid, field 2 null + ScalarValue::Struct(Arc::new(StructArray::from(( + struct_value_with_nulls.clone(), + Buffer::from(&[1]), + )))), + // all valid + ScalarValue::Struct(Arc::new(StructArray::from(( + struct_value.clone(), + Buffer::from(&[1]), + )))), + ]; + + println!("scalars: {scalars:#?}"); + + let check_array = |array| { + let is_null = is_null(&array).unwrap(); + assert_eq!(is_null, BooleanArray::from(vec![true, false, false])); + + let formatted = pretty_format_columns("col", &[array]).unwrap().to_string(); + let formatted = formatted.split('\n').collect::>(); + let expected = vec![ + "+---------------------------+", + "| col |", + "+---------------------------+", + "| |", + "| {a: 1, b: } |", + "| {a: 1, b: {ba: 2, bb: 3}} |", + "+---------------------------+", + ]; + assert_eq!( + formatted, expected, + "Actual:\n{formatted:#?}\n\nExpected:\n{expected:#?}" + ); + }; + + // test `ScalarValue::iter_to_array` + let array = ScalarValue::iter_to_array(scalars.clone()).unwrap(); + check_array(array); + + // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size` + let arrays = scalars + .iter() + .map(ScalarValue::to_array) + .collect::>>() + .expect("Failed to convert to array"); + let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); + let array = arrow::compute::concat(&arrays).unwrap(); + check_array(array); + } #[test] fn test_build_timestamp_millisecond_list() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 71983d5ff556..d32b973b36b2 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -21,8 +21,8 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use arrow::array::{ArrayRef, FixedSizeListArray}; -use arrow::csv::WriterBuilder; use arrow::array::{BooleanArray, Int32Array}; +use arrow::csv::WriterBuilder; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,