Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Jan 4, 2024
1 parent e18f402 commit addf685
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 87 deletions.
218 changes: 132 additions & 86 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,7 @@ impl ScalarValue {
.map(|a| a as &dyn Array)
.collect::<Vec<_>>();
arrow::compute::concat(arrays.as_slice())
.map_err(DataFusionError::ArrowError)?
.map_err(|e| arrow_datafusion_err!(e))?
}
ScalarValue::Date32(e) => {
build_array_from_option!(Date32, Date32Array, e, size)
Expand Down Expand Up @@ -3011,17 +3011,33 @@ impl fmt::Display for ScalarValue {

let columns = struct_arr.columns();
let fields = struct_arr.fields();
let nulls = struct_arr.nulls();

write!(
f,
"{{{}}}",
columns
.iter()
.zip(fields.iter())
.map(|(column, field)| {
let sv = ScalarValue::try_from_array(column, 0).unwrap();
let name = field.name();
format!("{name}:{sv}")
.enumerate()
.map(|(index, (column, field))| {
if nulls.is_some_and(|b| b.is_null(index)) {
format!("{}:NULL", field.name())
} else {
if let DataType::Struct(_) = field.data_type() {
let sv = ScalarValue::Struct(Arc::new(
column.as_struct().to_owned(),
));

let name = field.name();
format!("{name}:{sv}")
} else {
let sv =
ScalarValue::try_from_array(column, 0).unwrap();
let name = field.name();
format!("{name}:{sv}")
}
}
})
.collect::<Vec<_>>()
.join(",")
Expand Down Expand Up @@ -3181,13 +3197,14 @@ mod tests {
use std::cmp::Ordering;
use std::sync::Arc;

use arrow::util::pretty::pretty_format_columns;
use arrow_buffer::Buffer;
use chrono::NaiveDate;
use rand::Rng;

use arrow::buffer::OffsetBuffer;
use arrow::compute::kernels;
use arrow::datatypes::ArrowPrimitiveType;
use arrow::{buffer::OffsetBuffer, compute::is_null};
use arrow_array::ArrowNumericType;

use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
Expand Down Expand Up @@ -5407,86 +5424,115 @@ mod tests {
}
}

// #[test]
// fn test_struct_nulls() {
// let fields_b = Fields::from(vec![
// Field::new("ba", DataType::UInt64, true),
// Field::new("bb", DataType::UInt64, true),
// ]);
// let fields = Fields::from(vec![
// Field::new("a", DataType::UInt64, true),
// Field::new("b", DataType::Struct(fields_b.clone()), true),
// ]);
// let scalars = vec![
// ScalarValue::Struct(None, fields.clone()),
// ScalarValue::Struct(
// Some(vec![
// ScalarValue::UInt64(None),
// ScalarValue::Struct(None, fields_b.clone()),
// ]),
// fields.clone(),
// ),
// ScalarValue::Struct(
// Some(vec![
// ScalarValue::UInt64(None),
// ScalarValue::Struct(
// Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]),
// fields_b.clone(),
// ),
// ]),
// fields.clone(),
// ),
// ScalarValue::Struct(
// Some(vec![
// ScalarValue::UInt64(Some(1)),
// ScalarValue::Struct(
// Some(vec![
// ScalarValue::UInt64(Some(2)),
// ScalarValue::UInt64(Some(3)),
// ]),
// fields_b,
// ),
// ]),
// fields,
// ),
// ];

// let check_array = |array| {
// let is_null = is_null(&array).unwrap();
// assert_eq!(is_null, BooleanArray::from(vec![true, false, false, false]));

// let formatted = pretty_format_columns("col", &[array]).unwrap().to_string();
// let formatted = formatted.split('\n').collect::<Vec<_>>();
// let expected = vec![
// "+---------------------------+",
// "| col |",
// "+---------------------------+",
// "| |",
// "| {a: , b: } |",
// "| {a: , b: {ba: , bb: }} |",
// "| {a: 1, b: {ba: 2, bb: 3}} |",
// "+---------------------------+",
// ];
// assert_eq!(
// formatted, expected,
// "Actual:\n{formatted:#?}\n\nExpected:\n{expected:#?}"
// );
// };

// // test `ScalarValue::iter_to_array`
// let array = ScalarValue::iter_to_array(scalars.clone()).unwrap();
// check_array(array);

// // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size`
// let arrays = scalars
// .iter()
// .map(ScalarValue::to_array)
// .collect::<Result<Vec<_>>>()
// .expect("Failed to convert to array");
// let arrays = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
// let array = arrow::compute::concat(&arrays).unwrap();
// check_array(array);
// }
/// Verifies that struct scalars keep their null information when they are
/// turned into arrays.
///
/// Three single-row struct scalars are built:
/// 1. a struct whose only row is null,
/// 2. a valid struct whose nested struct field `b` is null,
/// 3. a fully valid struct.
///
/// They are converted to one array through `ScalarValue::iter_to_array` and
/// through `ScalarValue::to_array` followed by `concat`; both results must
/// report the same null mask and render the same pretty-printed table.
#[test]
fn test_struct_nulls() {
    let fields_b = Fields::from(vec![
        Field::new("ba", DataType::UInt64, true),
        Field::new("bb", DataType::UInt64, true),
    ]);
    let fields = Fields::from(vec![
        Field::new("a", DataType::UInt64, true),
        Field::new("b", DataType::Struct(fields_b.clone()), true),
    ]);

    // Columns for `{a: 1, b: {ba: 2, bb: 3}}` with a fully valid nested struct.
    let struct_value = vec![
        (
            fields[0].clone(),
            Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef,
        ),
        (
            fields[1].clone(),
            Arc::new(StructArray::from(vec![
                (
                    fields_b[0].clone(),
                    Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef,
                ),
                (
                    fields_b[1].clone(),
                    Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef,
                ),
            ])) as ArrayRef,
        ),
    ];

    // Same columns, but the nested struct `b` carries a validity bitmap whose
    // bit 0 is cleared, so its single row is null.
    let struct_value_with_nulls = vec![
        (
            fields[0].clone(),
            Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef,
        ),
        (
            fields[1].clone(),
            Arc::new(StructArray::from((
                vec![
                    (
                        fields_b[0].clone(),
                        Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef,
                    ),
                    (
                        fields_b[1].clone(),
                        Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef,
                    ),
                ],
                Buffer::from(&[0]),
            ))) as ArrayRef,
        ),
    ];

    // The `Buffer` paired with the columns is the struct's validity bitmap:
    // `[0]` marks row 0 as null, `[1]` marks it as valid.
    let scalars = vec![
        // all null
        ScalarValue::Struct(Arc::new(StructArray::from((
            struct_value.clone(),
            Buffer::from(&[0]),
        )))),
        // field 1 valid, field 2 null
        ScalarValue::Struct(Arc::new(StructArray::from((
            struct_value_with_nulls,
            Buffer::from(&[1]),
        )))),
        // all valid
        ScalarValue::Struct(Arc::new(StructArray::from((
            struct_value,
            Buffer::from(&[1]),
        )))),
    ];

    let check_array = |array| {
        let is_null = is_null(&array).unwrap();
        assert_eq!(is_null, BooleanArray::from(vec![true, false, false]));

        let formatted = pretty_format_columns("col", &[array]).unwrap().to_string();
        let formatted = formatted.split('\n').collect::<Vec<_>>();
        // Every row is padded to the width of the widest cell, so the
        // whitespace inside these literals is significant.
        let expected = vec![
            "+---------------------------+",
            "| col                       |",
            "+---------------------------+",
            "|                           |",
            "| {a: 1, b: }               |",
            "| {a: 1, b: {ba: 2, bb: 3}} |",
            "+---------------------------+",
        ];
        assert_eq!(
            formatted, expected,
            "Actual:\n{formatted:#?}\n\nExpected:\n{expected:#?}"
        );
    };

    // test `ScalarValue::iter_to_array`
    let array = ScalarValue::iter_to_array(scalars.clone()).unwrap();
    check_array(array);

    // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size`
    let arrays = scalars
        .iter()
        .map(ScalarValue::to_array)
        .collect::<Result<Vec<_>>>()
        .expect("Failed to convert to array");
    let arrays = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
    let array = arrow::compute::concat(&arrays).unwrap();
    check_array(array);
}

#[test]
fn test_build_timestamp_millisecond_list() {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

use arrow::array::{ArrayRef, FixedSizeListArray};
use arrow::csv::WriterBuilder;
use arrow::array::{BooleanArray, Int32Array};
use arrow::csv::WriterBuilder;
use arrow::datatypes::{
DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
Expand Down

0 comments on commit addf685

Please sign in to comment.