Skip to content

Commit

Permalink
Expose array sort (apache#764)
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer authored Jul 20, 2024
1 parent aa8aa9c commit f00b8ee
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,28 @@ def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr:
return array_replace_all(array, from_val, to_val)


def array_sort(array: Expr, descending: bool = False, null_first: bool = False) -> Expr:
"""Sort an array.
Args:
array: The input array to sort.
descending: If True, sorts in descending order.
null_first: If True, nulls will be returned at the beginning of the array.
"""
desc = "DESC" if descending else "ASC"
nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
return Expr(
f.array_sort(
array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
)
)


def list_sort(array: Expr, descending: bool = False, null_first: bool = False) -> Expr:
"""This is an alias for ``array_sort``."""
return array_sort(array, descending=descending, null_first=null_first)


def array_slice(
array: Expr, begin: Expr, end: Expr, stride: Expr | None = None
) -> Expr:
Expand Down
8 changes: 8 additions & 0 deletions python/datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,14 @@ def py_flatten(arr):
lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
],
[
lambda col: f.array_sort(col, descending=True, null_first=True),
lambda data: [np.sort(arr)[::-1] for arr in data],
],
[
lambda col: f.list_sort(col, descending=False, null_first=False),
lambda data: [np.sort(arr) for arr in data],
],
[
lambda col: f.array_slice(col, literal(2), literal(4)),
lambda data: [arr[1:4] for arr in data],
Expand Down
4 changes: 4 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,8 @@ array_fn!(array_replace_n, array from to max);
array_fn!(list_replace_n, array_replace_n, array from to max);
array_fn!(array_replace_all, array from to);
array_fn!(list_replace_all, array_replace_all, array from to);
array_fn!(array_sort, array desc null_first);
array_fn!(list_sort, array_sort, array desc null_first);
array_fn!(array_intersect, first_array second_array);
array_fn!(list_intersect, array_intersect, first_array second_array);
array_fn!(array_union, array1 array2);
Expand Down Expand Up @@ -936,6 +938,8 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(list_replace_n))?;
m.add_wrapped(wrap_pyfunction!(array_replace_all))?;
m.add_wrapped(wrap_pyfunction!(list_replace_all))?;
m.add_wrapped(wrap_pyfunction!(array_sort))?;
m.add_wrapped(wrap_pyfunction!(list_sort))?;
m.add_wrapped(wrap_pyfunction!(array_slice))?;
m.add_wrapped(wrap_pyfunction!(list_slice))?;
m.add_wrapped(wrap_pyfunction!(flatten))?;
Expand Down

0 comments on commit f00b8ee

Please sign in to comment.