Skip to content

Commit

Permalink
feat: list.join's separator can be expression (pola-rs#11167)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Sep 18, 2023
1 parent 89c1643 commit 040c53b
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 49 deletions.
8 changes: 8 additions & 0 deletions crates/polars-core/src/chunked_array/list/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ impl ListChunked {
unsafe { self.amortized_iter().map(f).collect_ca(self.name()) }
}

pub fn for_each_amortized<'a, F>(&'a self, f: F)
where
F: FnMut(Option<UnstableSeries<'a>>),
{
// SAFETY: unstable series never lives longer than the iterator.
unsafe { self.amortized_iter().for_each(f) }
}

/// Apply a closure `F` elementwise.
#[must_use]
pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self
Expand Down
73 changes: 56 additions & 17 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,62 @@ fn cast_rhs(
pub trait ListNameSpaceImpl: AsList {
/// In case the inner dtype [`DataType::Utf8`], the individual items will be joined into a
/// single string separated by `separator`.
fn lst_join(&self, separator: &str) -> PolarsResult<Utf8Chunked> {
fn lst_join(&self, separator: &Utf8Chunked) -> PolarsResult<Utf8Chunked> {
let ca = self.as_list();
match ca.inner_dtype() {
DataType::Utf8 => {
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
DataType::Utf8 => match separator.len() {
1 => match separator.get(0) {
Some(separator) => self.join_literal(separator),
_ => Ok(Utf8Chunked::full_null(ca.name(), ca.len())),
},
_ => self.join_many(separator),
},
dt => polars_bail!(op = "`lst.join`", got = dt, expected = "Utf8"),
}
}

let mut builder = Utf8ChunkedBuilder::new(
ca.name(),
ca.len(),
ca.get_values_size() + separator.len() * ca.len(),
);
fn join_literal(&self, separator: &str) -> PolarsResult<Utf8Chunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
let mut builder = Utf8ChunkedBuilder::new(
ca.name(),
ca.len(),
ca.get_values_size() + separator.len() * ca.len(),
);

ca.for_each_amortized(|opt_s| {
let opt_val = opt_s.map(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().utf8().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
});
builder.append_option(opt_val)
});
Ok(builder.finish())
}

// SAFETY: unstable series never lives longer than the iterator.
unsafe {
ca.amortized_iter().for_each(|opt_s| {
fn join_many(&self, separator: &Utf8Chunked) -> PolarsResult<Utf8Chunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
let mut builder =
Utf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size() + ca.len());
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
ca.amortized_iter()
.zip(separator)
.for_each(|(opt_s, opt_sep)| match opt_sep {
Some(separator) => {
let opt_val = opt_s.map(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
Expand All @@ -107,12 +147,11 @@ pub trait ListNameSpaceImpl: AsList {
&buf[..buf.len().saturating_sub(separator.len())]
});
builder.append_option(opt_val)
})
};
Ok(builder.finish())
},
dt => polars_bail!(op = "`lst.join`", got = dt, expected = "Utf8"),
},
_ => builder.append_null(),
})
}
Ok(builder.finish())
}

fn lst_max(&self) -> Series {
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub enum ListFunction {
Any,
#[cfg(feature = "list_any_all")]
All,
Join,
}

impl Display for ListFunction {
Expand All @@ -45,6 +46,7 @@ impl Display for ListFunction {
Any => "any",
#[cfg(feature = "list_any_all")]
All => "all",
Join => "join",
};
write!(f, "{name}")
}
Expand Down Expand Up @@ -279,3 +281,9 @@ pub(super) fn lst_any(s: &Series) -> PolarsResult<Series> {
pub(super) fn lst_all(s: &Series) -> PolarsResult<Series> {
s.list()?.lst_all()
}

pub(super) fn join(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].list()?;
let separator = s[1].utf8()?;
Ok(ca.lst_join(separator)?.into_series())
}
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Any => map!(list::lst_any),
#[cfg(feature = "list_any_all")]
All => map!(list::lst_all),
Join => map_as_slice!(list::join),
}
},
#[cfg(feature = "dtype-array")]
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ impl FunctionExpr {
Any => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "list_any_all")]
All => mapper.with_dtype(DataType::Boolean),
Join => mapper.with_dtype(DataType::Utf8),
}
},
#[cfg(feature = "dtype-array")]
Expand Down
18 changes: 6 additions & 12 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,12 @@ impl ListNameSpace {
/// Join all string items in a sublist and place a separator between them.
/// # Error
/// This errors if inner type of list `!= DataType::Utf8`.
pub fn join(self, separator: &str) -> Expr {
let separator = separator.to_string();
self.0
.map(
move |s| {
s.list()?
.lst_join(&separator)
.map(|ca| Some(ca.into_series()))
},
GetOutput::from_type(DataType::Utf8),
)
.with_fmt("list.join")
pub fn join(self, separator: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Join),
&[separator],
false,
)
}

/// Return the index of the minimal value of every sublist
Expand Down
9 changes: 1 addition & 8 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,14 +682,7 @@ impl SqlFunctionVisitor<'_> {
ArrayReverse => self.visit_unary(|e| e.list().reverse()),
ArraySum => self.visit_unary(|e| e.list().sum()),
ArrayToString => self.try_visit_binary(|e, s| {
let sep = match s {
Expr::Literal(LiteralValue::Utf8(ref sep)) => sep,
_ => {
polars_bail!(InvalidOperation: "Invalid 'separator' for ArrayToString: {}", function.args[1]);
}
};

Ok(e.list().join(sep))
Ok(e.list().join(s))
}),
ArrayUnique => self.visit_unary(|e| e.list().unique()),
Explode => self.visit_unary(|e| e.explode()),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/tests/functions_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn array_to_string() {
.lazy()
.group_by([col("b")])
.agg([col("a")])
.select(&[col("b"), col("a").list().join(", ").alias("as")])
.select(&[col("b"), col("a").list().join(lit(", ")).alias("as")])
.sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true)
.collect()
.unwrap();
Expand Down
16 changes: 15 additions & 1 deletion py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def contains(
item = parse_as_expression(item, str_as_lit=True)
return wrap_expr(self._pyexpr.list_contains(item))

def join(self, separator: str) -> Expr:
def join(self, separator: IntoExpr) -> Expr:
"""
Join all string items in a sublist and place a separator between them.
Expand Down Expand Up @@ -489,7 +489,21 @@ def join(self, separator: str) -> Expr:
│ x y │
└───────┘
>>> df = pl.DataFrame(
... {"s": [["a", "b", "c"], ["x", "y"]], "separator": ["*", "_"]}
... )
>>> df.select(pl.col("s").list.join(pl.col("separator")))
shape: (2, 1)
┌───────┐
│ s │
│ --- │
│ str │
╞═══════╡
│ a*b*c │
│ x_y │
└───────┘
"""
separator = parse_as_expression(separator, str_as_lit=True)
return wrap_expr(self._pyexpr.list_join(separator))

def arg_min(self) -> Expr:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from polars import Expr, Series
from polars.polars import PySeries
from polars.type_aliases import NullBehavior, ToStructStrategy
from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy


@expr_dispatch
Expand Down Expand Up @@ -198,7 +198,7 @@ def take(
def __getitem__(self, item: int) -> Series:
return self.get(item)

def join(self, separator: str) -> Series:
def join(self, separator: IntoExpr) -> Series:
"""
Join all string items in a sublist and place a separator between them.
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ impl PyExpr {
self.inner.clone().list().get(index.inner).into()
}

fn list_join(&self, separator: &str) -> Self {
self.inner.clone().list().join(separator).into()
fn list_join(&self, separator: PyExpr) -> Self {
self.inner.clone().list().join(separator.inner).into()
}

fn list_lengths(&self) -> Self {
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ def test_list_concat() -> None:
assert out_s[0].to_list() == [1, 2, 4, 1]


def test_list_join() -> None:
df = pl.DataFrame(
{
"a": [["ab", "c", "d"], ["e", "f"], ["g"], [], None],
"separator": ["&", None, "*", "_", "*"],
}
)
out = df.select(pl.col("a").list.join("-"))
assert out.to_dict(False) == {"a": ["ab-c-d", "e-f", "g", "", None]}
out = df.select(pl.col("a").list.join(pl.col("separator")))
assert out.to_dict(False) == {"a": ["ab&c&d", None, "g", "", None]}


def test_list_arr_empty() -> None:
df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]})

Expand Down
6 changes: 0 additions & 6 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,6 @@ def test_filter_where() -> None:
assert_frame_equal(result_filter, expected)


def test_list_join_strings() -> None:
s = pl.Series("a", [["ab", "c", "d"], ["e", "f"], ["g"], []])
expected = pl.Series("a", ["ab-c-d", "e-f", "g", ""])
assert_series_equal(s.list.join("-"), expected)


def test_count_expr() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 3, 3], "b": ["a", "a", "b", "a", "a"]})

Expand Down

0 comments on commit 040c53b

Please sign in to comment.