Skip to content

Commit

Permalink
Replace macro in array_array to remove duplicate codes
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <code@tanweime.com>
  • Loading branch information
Veeupup committed Nov 17, 2023
1 parent 9fd0f4e commit a3bbf16
Showing 1 changed file with 37 additions and 143 deletions.
180 changes: 37 additions & 143 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,70 +67,6 @@ macro_rules! downcast_vec {
}};
}

/// Expands to a constructor call for the array builder type named by the
/// first argument.
///
/// * `StringBuilder` / `LargeStringBuilder` expand to `<Builder>::new()`.
///   The `$len` expression is accepted so every call site has the same shape,
///   but it is dropped from the expansion and therefore never evaluated.
/// * Any other identifier expands to `<Builder>::with_capacity($len)`.
///   `BooleanBuilder` needs no arm of its own: this catch-all arm produces
///   exactly the same call for it.
macro_rules! new_builder {
    (StringBuilder, $len:expr) => {
        StringBuilder::new()
    };
    (LargeStringBuilder, $len:expr) => {
        LargeStringBuilder::new()
    };
    // Must stay last: the literal arms above only win because they are
    // tried first.
    ($el:ident, $len:expr) => {{
        <$el>::with_capacity($len)
    }};
}

/// Combines multiple column arrays into a single `ListArray`, one list per row.
///
/// * `$ARGS`: slice of arrays; each is expected to be either a `$ARRAY_TYPE`
///   or a `NullArray`. The expression is evaluated once.
/// * `$ARRAY_TYPE`: the concrete array type of the list elements.
/// * `$BUILDER_TYPE`: the `ArrayBuilder` type used for the list elements.
///
/// Expands to an `Arc` around the finished `ListBuilder<$BUILDER_TYPE>`.
/// Row `i` of the result holds exactly `$ARGS.len()` elements: the `i`-th
/// value of every argument, in argument order. A null slot, or any slot of a
/// `NullArray` argument, becomes a single null element.
///
/// # Errors
///
/// Executes `return internal_err!("failed to downcast")` in the *enclosing
/// function* when an argument is neither a `$ARRAY_TYPE` nor a `NullArray`.
///
/// # Panics
///
/// Panics if `$ARGS` is empty (it is indexed at `0`) or if the arguments do
/// not all have the same number of rows.
macro_rules! array {
    ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{
        // Bind once so the argument expression is not re-expanded at every use.
        let columns = $ARGS;

        let num_rows = columns[0].len();
        assert!(
            columns.iter().all(|a| a.len() == num_rows),
            "all arguments must have the same number of rows"
        );

        let values = new_builder!($BUILDER_TYPE, num_rows);
        let mut builder =
            ListBuilder::<$BUILDER_TYPE>::with_capacity(values, columns.len());

        // One output list per input row ...
        for index in 0..num_rows {
            // ... containing exactly one element per input column.
            for arg in columns {
                if let Some(arr) = arg.as_any().downcast_ref::<$ARRAY_TYPE>() {
                    // Copy the source value (or its nullness) into the list.
                    if arr.is_valid(index) {
                        builder.values().append_value(arr.value(index));
                    } else {
                        builder.values().append_null();
                    }
                } else if arg.as_any().is::<NullArray>() {
                    // A NullArray column contributes a single null for this
                    // row. Appending `arr.len()` nulls here (as was done
                    // previously) made every row too long whenever the input
                    // had more than one row.
                    builder.values().append_null();
                } else {
                    return internal_err!("failed to downcast");
                }
            }
            builder.append(true);
        }
        Arc::new(builder.finish())
    }};
}

/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
///
/// # Arguments
Expand Down Expand Up @@ -389,88 +325,46 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
return plan_err!("Array requires at least one argument");
}

let res = match data_type {
DataType::List(..) => {
let row_count = args[0].len();
let column_count = args.len();
let mut list_arrays = vec![];
let mut list_array_lengths = vec![];
let mut list_valid = BooleanBufferBuilder::new(row_count);
// Construct ListArray per row
for index in 0..row_count {
let mut arrays = vec![];
let mut array_lengths = vec![];
let mut valid = BooleanBufferBuilder::new(column_count);
for arg in args {
if arg.as_any().downcast_ref::<NullArray>().is_some() {
array_lengths.push(0);
valid.append(false);
} else {
let list_arr = as_list_array(arg)?;
let arr = list_arr.value(index);
array_lengths.push(arr.len());
arrays.push(arr);
valid.append(true);
}
}
if arrays.is_empty() {
list_valid.append(false);
list_array_lengths.push(0);
} else {
let buffer = valid.finish();
// Assume all list arrays have the same data type
let data_type = arrays[0].data_type();
let field = Arc::new(Field::new("item", data_type.to_owned(), true));
let elements = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
let values = compute::concat(elements.as_slice())?;
let list_arr = ListArray::new(
field,
OffsetBuffer::from_lengths(array_lengths),
values,
Some(NullBuffer::new(buffer)),
);
list_valid.append(true);
list_array_lengths.push(list_arr.len());
list_arrays.push(list_arr);
}
let mut data = vec![];
let mut total_len = 0;
for arg in args {
let arg_data = if arg.as_any().is::<NullArray>() {
ArrayData::new_empty(&data_type)
} else {
arg.to_data()
};
total_len += arg_data.len();
data.push(arg_data);
}
let mut offsets = Vec::with_capacity(total_len);
offsets.push(0);

let capacity = Capacities::Array(total_len);
let data_ref = data.iter().map(|d| d).collect::<Vec<_>>();
let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity);

let num_rows = args[0].len();
for row_idx in 0..num_rows {
for (arr_idx, arg) in args.iter().enumerate() {
if !arg.as_any().is::<NullArray>()
&& !arg.is_null(row_idx)
&& arg.is_valid(row_idx)
{
mutable.extend(arr_idx, row_idx, row_idx + 1);
} else {
mutable.extend_nulls(1);
}
// Construct ListArray for all rows
let buffer = list_valid.finish();
// Assume all list arrays have the same data type
let data_type = list_arrays[0].data_type();
let field = Arc::new(Field::new("item", data_type.to_owned(), true));
let elements = list_arrays
.iter()
.map(|x| x as &dyn Array)
.collect::<Vec<_>>();
let values = compute::concat(elements.as_slice())?;
let list_arr = ListArray::new(
field,
OffsetBuffer::from_lengths(list_array_lengths),
values,
Some(NullBuffer::new(buffer)),
);
Arc::new(list_arr)
}
DataType::Utf8 => array!(args, StringArray, StringBuilder),
DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder),
DataType::Boolean => array!(args, BooleanArray, BooleanBuilder),
DataType::Float32 => array!(args, Float32Array, Float32Builder),
DataType::Float64 => array!(args, Float64Array, Float64Builder),
DataType::Int8 => array!(args, Int8Array, Int8Builder),
DataType::Int16 => array!(args, Int16Array, Int16Builder),
DataType::Int32 => array!(args, Int32Array, Int32Builder),
DataType::Int64 => array!(args, Int64Array, Int64Builder),
DataType::UInt8 => array!(args, UInt8Array, UInt8Builder),
DataType::UInt16 => array!(args, UInt16Array, UInt16Builder),
DataType::UInt32 => array!(args, UInt32Array, UInt32Builder),
DataType::UInt64 => array!(args, UInt64Array, UInt64Builder),
data_type => {
return not_impl_err!("Array is not implemented for type '{data_type:?}'.")
}
};
offsets.push(mutable.len() as i32);
}

Ok(res)
let data = mutable.freeze();
Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::new(offsets.into()),
arrow_array::make_array(data),
None,
)?))
}

/// `make_array` SQL function
Expand Down

0 comments on commit a3bbf16

Please sign in to comment.