diff --git a/polars/polars-lazy/src/frame/mod.rs b/polars/polars-lazy/src/frame/mod.rs index a02148ce916b..8e81642782f7 100644 --- a/polars/polars-lazy/src/frame/mod.rs +++ b/polars/polars-lazy/src/frame/mod.rs @@ -1305,7 +1305,7 @@ impl LazyGroupBy { .flat_map(|k| expr_to_leaf_column_names(k).into_iter()) .collect::>(); - self.agg([col("*").exclude(&keys).head(n).list().keep_name()]) + self.agg([col("*").exclude(&keys).head(n).keep_name()]) .explode([col("*").exclude(&keys)]) } diff --git a/polars/polars-lazy/src/physical_plan/executors/groupby_dynamic.rs b/polars/polars-lazy/src/physical_plan/executors/groupby_dynamic.rs index c7d7d815eb06..01c3fa95d617 100644 --- a/polars/polars-lazy/src/physical_plan/executors/groupby_dynamic.rs +++ b/polars/polars-lazy/src/physical_plan/executors/groupby_dynamic.rs @@ -28,7 +28,7 @@ impl GroupByDynamicExec { // if the periods are larger than the intervals, // the groups overlap if self.options.every < self.options.period { - state.flags |= StateFlags::OVERLAPPING_GROUPS + state.set_has_overlapping_groups(); } let keys = self diff --git a/polars/polars-lazy/src/physical_plan/executors/groupby_rolling.rs b/polars/polars-lazy/src/physical_plan/executors/groupby_rolling.rs index 3b2f9be81f7b..728b096bf653 100644 --- a/polars/polars-lazy/src/physical_plan/executors/groupby_rolling.rs +++ b/polars/polars-lazy/src/physical_plan/executors/groupby_rolling.rs @@ -68,7 +68,7 @@ impl GroupByRollingExec { }; // a rolling groupby has overlapping windows - state.flags |= StateFlags::OVERLAPPING_GROUPS; + state.set_has_overlapping_groups(); let agg_columns = POOL.install(|| { self.aggs diff --git a/polars/polars-lazy/src/physical_plan/executors/mod.rs b/polars/polars-lazy/src/physical_plan/executors/mod.rs index be9e81bdd57a..fbc0b64c42f6 100644 --- a/polars/polars-lazy/src/physical_plan/executors/mod.rs +++ b/polars/polars-lazy/src/physical_plan/executors/mod.rs @@ -52,7 +52,6 @@ pub(super) use self::stack::*; pub(super) use self::udf::*; 
pub(super) use self::union::*; use super::*; -use crate::physical_plan::state::StateFlags; fn execute_projection_cached_window_fns( df: &DataFrame, @@ -117,9 +116,9 @@ fn execute_projection_cached_window_fns( // don't bother caching if we only have a single window function in this partition if partition.1.len() == 1 { - state.flags.remove(StateFlags::CACHE_WINDOW_EXPR) + state.remove_cache_window_flag(); } else { - state.flags.insert(StateFlags::CACHE_WINDOW_EXPR); + state.insert_cache_window_flag(); } partition.1.sort_unstable_by_key(|(_idx, explode, _)| { @@ -136,12 +135,12 @@ fn execute_projection_cached_window_fns( .count() == 1 { - state.flags.insert(StateFlags::CACHE_WINDOW_EXPR) + state.insert_cache_window_flag(); } // caching more than one window expression is a complicated topic for another day // see issue #2523 else { - state.flags.remove(StateFlags::CACHE_WINDOW_EXPR) + state.remove_cache_window_flag(); } let s = e.evaluate(df, &state)?; diff --git a/polars/polars-lazy/src/physical_plan/expressions/aggregation.rs b/polars/polars-lazy/src/physical_plan/expressions/aggregation.rs index 1bcd32c56ad2..35f3359e1075 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -169,7 +169,14 @@ impl PhysicalExpr for AggregationExpr { } GroupByMethod::List => { let agg = ac.aggregated(); - rename_series(agg, &keep_name) + + if state.unset_finalize_window_as_list() { + rename_series(agg, &keep_name) + } else { + let ca = agg.list().unwrap(); + let s = run_list_agg(ca); + rename_series(s, &keep_name) + } } GroupByMethod::Groups => { let mut column: ListChunked = ac.groups().as_list_chunked(); @@ -430,7 +437,7 @@ impl PartitionedAggregation for AggregationExpr { if can_fast_explode { ca.set_fast_explode() } - Ok(ca.into_series()) + Ok(run_list_agg(&ca)) } GroupByMethod::First => { let mut agg = unsafe { partitioned.agg_first(groups) }; @@ -531,3 +538,19 @@ impl PhysicalExpr 
for AggQuantileExpr { true } } + +fn run_list_agg(ca: &ListChunked) -> Series { + assert_eq!(ca.chunks().len(), 1); + let arr = ca.chunks()[0].clone(); + + let offsets = (0i64..(ca.len() as i64 + 1)).collect::>(); + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let new_arr = LargeListArray::new( + DataType::List(Box::new(ca.dtype().clone())).to_arrow(), + offsets.into(), + arr, + None, + ); + unsafe { ListChunked::from_chunks(ca.name(), vec![Box::new(new_arr)]).into_series() } +} diff --git a/polars/polars-lazy/src/physical_plan/expressions/apply.rs b/polars/polars-lazy/src/physical_plan/expressions/apply.rs index 5f43cf4449c6..6440cc851745 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/apply.rs @@ -124,7 +124,7 @@ impl PhysicalExpr for ApplyExpr { if self.inputs.len() == 1 { let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?; - match (state.overlapping_groups(), self.collect_groups) { + match (state.has_overlapping_groups(), self.collect_groups) { (_, ApplyOptions::ApplyList) => { let s = self.function.call_udf(&mut [ac.aggregated()])?; ac.with_series(s, true); @@ -212,7 +212,7 @@ impl PhysicalExpr for ApplyExpr { } else { let mut acs = self.prepare_multiple_inputs(df, groups, state)?; - match (state.overlapping_groups(), self.collect_groups) { + match (state.has_overlapping_groups(), self.collect_groups) { (_, ApplyOptions::ApplyList) => { let mut s = acs.iter_mut().map(|ac| ac.aggregated()).collect::>(); let s = self.function.call_udf(&mut s)?; @@ -235,7 +235,7 @@ impl PhysicalExpr for ApplyExpr { ApplyOptions::ApplyFlat, AggState::AggregatedFlat(_) | AggState::NotAggregated(_), ) = ( - state.overlapping_groups(), + state.has_overlapping_groups(), self.collect_groups, acs[0].agg_state(), ) { diff --git a/polars/polars-lazy/src/physical_plan/expressions/binary.rs b/polars/polars-lazy/src/physical_plan/expressions/binary.rs index 
8234e95ed471..ee2e24fc7d21 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/binary.rs @@ -7,7 +7,7 @@ use polars_core::series::unstable::UnstableSeries; use polars_core::POOL; use rayon::prelude::*; -use crate::physical_plan::state::{ExecutionState, StateFlags}; +use crate::physical_plan::state::ExecutionState; use crate::prelude::*; pub struct BinaryExpr { @@ -101,7 +101,7 @@ impl PhysicalExpr for BinaryExpr { fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { let mut state = state.split(); // don't cache window functions as they run in parallel - state.flags.remove(StateFlags::CACHE_WINDOW_EXPR); + state.remove_cache_window_flag(); let (lhs, rhs) = POOL.install(|| { rayon::join( || self.left.evaluate(df, &state), @@ -130,7 +130,7 @@ impl PhysicalExpr for BinaryExpr { match ( ac_l.agg_state(), ac_r.agg_state(), - state.overlapping_groups(), + state.has_overlapping_groups(), ) { // Some aggregations must return boolean masks that fit the group. That's why not all literals can take this path. 
// only literals that are used in arithmetic diff --git a/polars/polars-lazy/src/physical_plan/expressions/ternary.rs b/polars/polars-lazy/src/physical_plan/expressions/ternary.rs index 56b55d1f60ab..3c05fd7abb59 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/ternary.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/ternary.rs @@ -6,7 +6,7 @@ use polars_core::prelude::*; use polars_core::POOL; use crate::physical_plan::expression_err; -use crate::physical_plan::state::{ExecutionState, StateFlags}; +use crate::physical_plan::state::ExecutionState; use crate::prelude::*; pub struct TernaryExpr { @@ -93,7 +93,7 @@ impl PhysicalExpr for TernaryExpr { fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { let mut state = state.split(); // don't cache window functions as they run in parallel - state.flags.remove(StateFlags::CACHE_WINDOW_EXPR); + state.remove_cache_window_flag(); let mask_series = self.predicate.evaluate(df, &state)?; let mut mask = mask_series.bool()?.clone(); diff --git a/polars/polars-lazy/src/physical_plan/expressions/window.rs b/polars/polars-lazy/src/physical_plan/expressions/window.rs index 1161fd8a6366..65d760ca06c8 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/window.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/window.rs @@ -241,6 +241,7 @@ impl WindowExpr { explicit_list = finishes_list; } } + explicit_list } @@ -424,6 +425,16 @@ impl PhysicalExpr for WindowExpr { }); let explicit_list_agg = self.is_explicit_list_agg(); + // Assume we have column a : i32 + // A `sort()` in a window function is one level flatter + // than a sort in a groupby. 
A groupby sorts the groups and returns array: list[i32] + // whereas a window function returns array: i32 + // So a `sort().list()` in a groupby returns: list[list[i32]] + // whereas in a window function would return: list[i32] + if explicit_list_agg { + state.set_finalize_window_as_list(); + } + // if we flatten this column we need to make sure the groups are sorted. let mut sort_groups = self.options.explode || // if not diff --git a/polars/polars-lazy/src/physical_plan/planner/lp.rs b/polars/polars-lazy/src/physical_plan/planner/lp.rs index 50b5c5852566..9936d4d60861 100644 --- a/polars/polars-lazy/src/physical_plan/planner/lp.rs +++ b/polars/polars-lazy/src/physical_plan/planner/lp.rs @@ -37,6 +37,7 @@ fn partitionable_gb( let aexpr = expr_arena.get(*agg); let depth = (expr_arena).iter(*agg).count(); + // These single expressions are partitionable if matches!(aexpr, AExpr::Count) { continue; } diff --git a/polars/polars-lazy/src/physical_plan/state.rs b/polars/polars-lazy/src/physical_plan/state.rs index acd12d1f8171..dcee87e512f6 100644 --- a/polars/polars-lazy/src/physical_plan/state.rs +++ b/polars/polars-lazy/src/physical_plan/state.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::{Mutex, RwLock}; use bitflags::bitflags; @@ -16,6 +17,7 @@ pub type JoinTuplesCache = Arc>>; pub type GroupsProxyCache = Arc>>; bitflags! { + #[repr(transparent)] pub(super) struct StateFlags: u8 { /// More verbose logging const VERBOSE = 0x01; @@ -25,6 +27,13 @@ bitflags! { /// If this is the case, an `explode` will yield more values than rows in original `df`, /// this breaks some assumptions const OVERLAPPING_GROUPS = 0x04; + /// Assume we have column a : i32 + /// A `sort()` in a window function is one level flatter + /// than a sort in a groupby. 
A groupby sorts the groups and returns array: list[i32] + /// whereas a window function returns array: i32 + /// So a `sort().list()` in a groupby returns: list[list[i32]] + /// whereas in a window function would return: list[i32] + const FINALIZE_WINDOW_AS_LIST = 0x08; } } @@ -43,6 +52,15 @@ impl StateFlags { } flags } + fn as_u8(self) -> u8 { + unsafe { std::mem::transmute(self) } + } +} + +impl From for StateFlags { + fn from(value: u8) -> Self { + unsafe { std::mem::transmute(value) } + } } /// State/ cache that is maintained during the Execution of the physical plan. @@ -59,7 +77,7 @@ pub struct ExecutionState { pub(super) join_tuples: JoinTuplesCache, // every join/union split gets an increment to distinguish between schema state pub(super) branch_idx: usize, - pub(super) flags: StateFlags, + pub(super) flags: AtomicU8, pub(super) ext_contexts: Arc>, node_timer: Option, } @@ -101,7 +119,7 @@ impl ExecutionState { group_tuples: Default::default(), join_tuples: Default::default(), branch_idx: self.branch_idx, - flags: self.flags, + flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)), ext_contexts: self.ext_contexts.clone(), node_timer: self.node_timer.clone(), } @@ -117,7 +135,7 @@ impl ExecutionState { group_tuples: self.group_tuples.clone(), join_tuples: self.join_tuples.clone(), branch_idx: self.branch_idx, - flags: self.flags, + flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)), ext_contexts: self.ext_contexts.clone(), node_timer: self.node_timer.clone(), } @@ -137,7 +155,7 @@ impl ExecutionState { group_tuples: Arc::new(Mutex::new(PlHashMap::default())), join_tuples: Arc::new(Mutex::new(PlHashMap::default())), branch_idx: 0, - flags: StateFlags::init(), + flags: AtomicU8::new(StateFlags::init().as_u8()), ext_contexts: Default::default(), node_timer: None, } @@ -157,7 +175,7 @@ impl ExecutionState { group_tuples: Default::default(), join_tuples: Default::default(), branch_idx: 0, - flags: StateFlags::init(), + flags: 
AtomicU8::new(StateFlags::init().as_u8()), ext_contexts: Default::default(), node_timer: None, } @@ -201,21 +219,66 @@ impl ExecutionState { lock.clear(); } + fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) { + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + let flags = f(flags); + self.flags.store(flags.as_u8(), Ordering::Relaxed); + } + /// Indicates that window expression's [`GroupTuples`] may be cached. pub(super) fn cache_window(&self) -> bool { - self.flags.contains(StateFlags::CACHE_WINDOW_EXPR) + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + flags.contains(StateFlags::CACHE_WINDOW_EXPR) } /// Indicates that a groupby operations groups may overlap. /// If this is the case, an `explode` will yield more values than rows in original `df`, /// this breaks some assumptions - pub(super) fn overlapping_groups(&self) -> bool { - self.flags.contains(StateFlags::OVERLAPPING_GROUPS) + pub(super) fn has_overlapping_groups(&self) -> bool { + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + flags.contains(StateFlags::OVERLAPPING_GROUPS) + } + #[cfg(feature = "dynamic_groupby")] + pub(super) fn set_has_overlapping_groups(&self) { + self.set_flags(&|mut flags| { + flags |= StateFlags::OVERLAPPING_GROUPS; + flags + }) } /// More verbose logging pub(super) fn verbose(&self) -> bool { - self.flags.contains(StateFlags::VERBOSE) + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + flags.contains(StateFlags::VERBOSE) + } + + pub(super) fn set_finalize_window_as_list(&self) { + self.set_flags(&|mut flags| { + flags |= StateFlags::FINALIZE_WINDOW_AS_LIST; + flags + }) + } + + pub(super) fn unset_finalize_window_as_list(&self) -> bool { + let mut flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + let is_set = flags.contains(StateFlags::FINALIZE_WINDOW_AS_LIST); + flags.remove(StateFlags::FINALIZE_WINDOW_AS_LIST); + self.flags.store(flags.as_u8(), Ordering::Relaxed); + is_set + } 
+ + pub(super) fn remove_cache_window_flag(&mut self) { + self.set_flags(&|mut flags| { + flags.remove(StateFlags::CACHE_WINDOW_EXPR); + flags + }); + } + + pub(super) fn insert_cache_window_flag(&mut self) { + self.set_flags(&|mut flags| { + flags.insert(StateFlags::CACHE_WINDOW_EXPR); + flags + }); } } diff --git a/polars/polars-lazy/src/tests/aggregations.rs b/polars/polars-lazy/src/tests/aggregations.rs index 01c1c0c77f3a..4fbacad3de45 100644 --- a/polars/polars-lazy/src/tests/aggregations.rs +++ b/polars/polars-lazy/src/tests/aggregations.rs @@ -12,7 +12,6 @@ fn test_agg_exprs() -> PolarsResult<()> { .groupby_stable([col("cars")]) .agg([(lit(1) - col("A")) .map(|s| Ok(&s * 2), GetOutput::same_type()) - .list() .alias("foo")]) .collect()?; let ca = out.column("foo")?.list()?; @@ -331,7 +330,7 @@ fn test_binary_agg_context_2() -> PolarsResult<()> { .clone() .lazy() .groupby_stable([col("groups")]) - .agg([((col("vals").first() - col("vals")).list()).alias("vals")]) + .agg([(col("vals").first() - col("vals")).alias("vals")]) .collect()?; // 0 - [1, 2] = [0, -1] @@ -349,7 +348,7 @@ fn test_binary_agg_context_2() -> PolarsResult<()> { let out = df .lazy() .groupby_stable([col("groups")]) - .agg([((col("vals")) - col("vals").first()).list().alias("vals")]) + .agg([((col("vals")) - col("vals").first()).alias("vals")]) .collect()?; // [1, 2] - 1 = [0, 1] @@ -393,7 +392,7 @@ fn test_shift_elementwise_issue_2509() -> PolarsResult<()> { .lazy() // Don't use maintain order here! 
That hides the bug .groupby([col("x")]) - .agg(&[(col("y").shift(-1) + col("x")).list().alias("sum")]) + .agg(&[(col("y").shift(-1) + col("x")).alias("sum")]) .sort("x", Default::default()) .collect()?; diff --git a/polars/polars-lazy/src/tests/logical.rs b/polars/polars-lazy/src/tests/logical.rs index 8a632733baaf..bda536574ba1 100644 --- a/polars/polars-lazy/src/tests/logical.rs +++ b/polars/polars-lazy/src/tests/logical.rs @@ -23,10 +23,8 @@ fn test_duration() -> PolarsResult<()> { ) .groupby([col("groups")]) .agg([ - (col("date") - col("date").first()).list().alias("date"), - (col("datetime") - col("datetime").first()) - .list() - .alias("datetime"), + (col("date") - col("date").first()).alias("date"), + (col("datetime") - col("datetime").first()).alias("datetime"), ]) .explode([col("date"), col("datetime")]) .collect()?; diff --git a/polars/polars-lazy/src/tests/queries.rs b/polars/polars-lazy/src/tests/queries.rs index 972a2b93b089..a490925b0ed2 100644 --- a/polars/polars-lazy/src/tests/queries.rs +++ b/polars/polars-lazy/src/tests/queries.rs @@ -253,7 +253,7 @@ fn test_lazy_query_4() { .clone() .groupby([col("uid")]) .agg([ - col("day").list().alias("day"), + col("day").alias("day"), col("cumcases") .apply(|s: Series| Ok(&s - &(s.shift(1))), GetOutput::same_type()) .alias("diff_cases"), @@ -656,7 +656,7 @@ fn test_lazy_partition_agg() { let out = scan_foods_csv() .groupby([col("category")]) - .agg([col("calories").list()]) + .agg([col("calories")]) .sort("category", Default::default()) .collect() .unwrap(); @@ -1049,10 +1049,7 @@ fn test_multiple_explode() -> PolarsResult<()> { let out = df .lazy() .groupby([col("a")]) - .agg([ - col("b").list().alias("b_list"), - col("c").list().alias("c_list"), - ]) + .agg([col("b").alias("b_list"), col("c").alias("c_list")]) .explode([col("c_list"), col("b_list")]) .collect()?; assert_eq!(out.shape(), (5, 3)); @@ -1322,7 +1319,7 @@ fn test_sort_by() -> PolarsResult<()> { let out = df .lazy() .groupby_stable([col("b")]) - 
.agg([col("a").sort_by([col("b"), col("c")], [false]).list()]) + .agg([col("a").sort_by([col("b"), col("c")], [false])]) .collect()?; let a = out.column("a")?.explode()?; @@ -1703,45 +1700,6 @@ fn test_drop_and_select() -> PolarsResult<()> { Ok(()) } -#[test] -fn test_groupby_on_lists() -> PolarsResult<()> { - let s0 = Series::new("", [1i32, 2, 3]); - let s1 = Series::new("groups", [4i32, 5]); - - let mut builder = - ListPrimitiveChunkedBuilder::::new("arrays", 10, 10, DataType::Int32); - builder.append_series(&s0); - builder.append_series(&s1); - let s2 = builder.finish().into_series(); - - let df = DataFrame::new(vec![s1, s2])?; - let out = df - .clone() - .lazy() - .groupby([col("groups")]) - .agg([col("arrays").first()]) - .collect()?; - - assert_eq!( - out.column("arrays")?.dtype(), - &DataType::List(Box::new(DataType::Int32)) - ); - - let out = df - .clone() - .lazy() - .groupby([col("groups")]) - .agg([col("arrays").list()]) - .collect()?; - - assert_eq!( - out.column("arrays")?.dtype(), - &DataType::List(Box::new(DataType::List(Box::new(DataType::Int32)))) - ); - - Ok(()) -} - #[test] fn test_single_group_result() -> PolarsResult<()> { // the argsort should not auto explode @@ -2011,26 +1969,3 @@ fn test_partitioned_gb_ternary() -> PolarsResult<()> { Ok(()) } - -#[test] -fn test_foo() -> PolarsResult<()> { - let q1 = df![ - "x" => [1] - ]? - .lazy(); - - let q2 = df![ - "x" => [1], - "y" => [1] - ]? 
- .lazy(); - - let out = q1 - .clone() - .join(q2.clone(), [col("x")], [col("y")], JoinType::Semi) - .join(q2.clone(), [col("x")], [col("y")], JoinType::Semi) - .select([col("x")]) - .collect()?; - dbg!(out); - Ok(()) -} diff --git a/polars/tests/it/lazy/groupby.rs b/polars/tests/it/lazy/groupby.rs index 021c2c1583fe..cbfd5f133cd9 100644 --- a/polars/tests/it/lazy/groupby.rs +++ b/polars/tests/it/lazy/groupby.rs @@ -53,42 +53,6 @@ fn test_filter_after_tail() -> PolarsResult<()> { Ok(()) } -#[test] -#[cfg(feature = "unique_counts")] -fn test_list_arithmetic_in_groupby() -> PolarsResult<()> { - // specifically make the amount of groups equal to df height. - let df = df![ - "a" => ["foo", "ham", "bar"], - "b" => [1, 2, 3] - ]?; - - let out = df - .lazy() - .groupby_stable([col("a")]) - .agg([ - col("b").list().alias("original"), - (col("b").list() * lit(2)).alias("mult_lit"), - (col("b").list() / lit(2)).alias("div_lit"), - (col("b").list() - lit(2)).alias("min_lit"), - (col("b").list() + lit(2)).alias("plus_lit"), - (col("b").list() % lit(2)).alias("mod_lit"), - (lit(1) + col("b").list()).alias("lit_plus"), - (col("b").unique_counts() + count()).alias("plus_count"), - ]) - .collect()?; - - let cols = ["mult_lit", "div_lit", "plus_count"]; - let out = out.explode(&cols)?.select(&cols)?; - - assert!(out.frame_equal(&df![ - "mult_lit" => [2, 4, 6], - "div_lit"=> [0, 1, 1], - "plus_count" => [2 as IdxSize, 2, 2] - ]?)); - - Ok(()) -} - #[test] fn test_filter_diff_arithmetic() -> PolarsResult<()> { let df = df![ diff --git a/polars/tests/it/lazy/queries.rs b/polars/tests/it/lazy/queries.rs index 8463254cb01a..7c9528c07ef1 100644 --- a/polars/tests/it/lazy/queries.rs +++ b/polars/tests/it/lazy/queries.rs @@ -228,3 +228,45 @@ fn test_apply_multiple_columns() -> PolarsResult<()> { assert_eq!(Vec::from(out), &[Some(16)]); Ok(()) } + +#[test] +fn test_groupby_on_lists() -> PolarsResult<()> { + let s0 = Series::new("", [1i32, 2, 3]); + let s1 = Series::new("groups", [4i32, 5]); 
+ + let mut builder = + ListPrimitiveChunkedBuilder::::new("arrays", 10, 10, DataType::Int32); + builder.append_series(&s0); + builder.append_series(&s1); + let s2 = builder.finish().into_series(); + + let df = DataFrame::new(vec![s1, s2])?; + let out = df + .clone() + .lazy() + .groupby([col("groups")]) + .agg([col("arrays").first()]) + .collect()?; + + assert_eq!( + out.column("arrays")?.dtype(), + &DataType::List(Box::new(DataType::Int32)) + ); + + let out = df + .clone() + .lazy() + .groupby([col("groups")]) + .agg([col("arrays").list()]) + .collect()?; + + // a list of lists + assert_eq!( + out.column("arrays")?.dtype(), + &DataType::List(Box::new(DataType::List(Box::new(DataType::List( + Box::new(DataType::Int32) + ))))) + ); + + Ok(()) +} diff --git a/py-polars/docs/source/reference/dataframe/groupby.rst b/py-polars/docs/source/reference/dataframe/groupby.rst index 62c9d34ff315..8c07dd27451b 100644 --- a/py-polars/docs/source/reference/dataframe/groupby.rst +++ b/py-polars/docs/source/reference/dataframe/groupby.rst @@ -10,7 +10,7 @@ This namespace is available after calling :code:`DataFrame.groupby(...)`. GroupBy.__iter__ GroupBy.agg - GroupBy.agg_list + GroupBy.all GroupBy.apply GroupBy.count GroupBy.first diff --git a/py-polars/polars/internals/dataframe/frame.py b/py-polars/polars/internals/dataframe/frame.py index 4c9d90685ab3..6dca4b4f0d49 100644 --- a/py-polars/polars/internals/dataframe/frame.py +++ b/py-polars/polars/internals/dataframe/frame.py @@ -3751,7 +3751,7 @@ def groupby_dynamic( ... df.groupby_dynamic("time", every="1h", closed="left").agg( ... [ ... pl.col("time").count().alias("time_count"), - ... pl.col("time").list().alias("time_agg_list"), + ... pl.col("time").alias("time_agg_list"), ... ] ... ) ... ) @@ -3853,7 +3853,7 @@ def groupby_dynamic( ... period="3i", ... include_boundaries=True, ... closed="right", - ... ).agg(pl.col("A").list().alias("A_agg_list")) + ... ).agg(pl.col("A").alias("A_agg_list")) ... 
) shape: (3, 4) ┌─────────────────┬─────────────────┬─────┬─────────────────┐ diff --git a/py-polars/polars/internals/dataframe/groupby.py b/py-polars/polars/internals/dataframe/groupby.py index d24c969a36eb..ce59292371ab 100644 --- a/py-polars/polars/internals/dataframe/groupby.py +++ b/py-polars/polars/internals/dataframe/groupby.py @@ -94,7 +94,7 @@ def __iter__(self) -> GroupBy[DF]: .lazy() .with_row_count(name=temp_col) .groupby(self.by, maintain_order=self.maintain_order) - .agg(pli.col(temp_col).list()) + .agg(pli.col(temp_col)) .collect(no_optimization=True) ) @@ -663,6 +663,9 @@ def agg_list(self) -> pli.DataFrame: """ Aggregate the groups into Series. + .. deprecated:: 0.16.0 + Use ``all()`` instead. + Examples -------- >>> df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]}) @@ -678,7 +681,28 @@ def agg_list(self) -> pli.DataFrame: └─────┴───────────┘ """ - return self.agg(pli.all().list()) + return self.agg(pli.all()) + + def all(self) -> pli.DataFrame: + """ + Aggregate the groups into Series. 
+ + Examples + -------- + >>> df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]}) + >>> df.groupby("a", maintain_order=True).all() + shape: (2, 2) + ┌─────┬───────────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ str ┆ list[i64] │ + ╞═════╪═══════════╡ + │ one ┆ [1, 3] │ + │ two ┆ [2, 4] │ + └─────┴───────────┘ + + """ + return self.agg(pli.all()) class RollingGroupBy(Generic[DF]): @@ -720,7 +744,7 @@ def __iter__(self) -> RollingGroupBy[DF]: closed=self.closed, by=self.by, ) - .agg(pli.col(temp_col).list()) + .agg(pli.col(temp_col)) .collect(no_optimization=True) ) @@ -816,7 +840,7 @@ def __iter__(self) -> DynamicGroupBy[DF]: by=self.by, start_by=self.start_by, ) - .agg(pli.col(temp_col).list()) + .agg(pli.col(temp_col)) .collect(no_optimization=True) ) diff --git a/py-polars/polars/internals/lazyframe/frame.py b/py-polars/polars/internals/lazyframe/frame.py index 820e2f5ddad4..802fa10b0671 100644 --- a/py-polars/polars/internals/lazyframe/frame.py +++ b/py-polars/polars/internals/lazyframe/frame.py @@ -1981,7 +1981,7 @@ def groupby_dynamic( ... .agg( ... [ ... pl.col("time").count().alias("time_count"), - ... pl.col("time").list().alias("time_agg_list"), + ... pl.col("time").alias("time_agg_list"), ... ] ... ) ... ).collect() @@ -2087,7 +2087,7 @@ def groupby_dynamic( ... include_boundaries=True, ... closed="right", ... ) - ... .agg(pl.col("A").list().alias("A_agg_list")) + ... .agg(pl.col("A").alias("A_agg_list")) ... 
).collect() shape: (3, 4) ┌─────────────────┬─────────────────┬─────┬─────────────────┐ diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 2c021ac178fe..7234bc64f8ee 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -216,7 +216,7 @@ def test_recursive_logical_type() -> None: df = pl.DataFrame({"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}) df = df.with_columns(pl.col("str").cast(pl.Categorical)) - df_groups = df.groupby("group").agg([pl.col("str").list().alias("cat_list")]) + df_groups = df.groupby("group").agg([pl.col("str").alias("cat_list")]) f = io.BytesIO() df_groups.write_parquet(f, use_pyarrow=True) f.seek(0) @@ -231,7 +231,7 @@ def test_nested_dictionary() -> None: pl.DataFrame({"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}) .with_columns(pl.col("str").cast(pl.Categorical)) .groupby("group") - .agg([pl.col("str").list().alias("cat_list")]) + .agg([pl.col("str").alias("cat_list")]) ) f = io.BytesIO() df.write_parquet(f) diff --git a/py-polars/tests/unit/test_datelike.py b/py-polars/tests/unit/test_datelike.py index 59b2d2469357..b97ddf104473 100644 --- a/py-polars/tests/unit/test_datelike.py +++ b/py-polars/tests/unit/test_datelike.py @@ -712,7 +712,7 @@ def test_truncate_negative_offset() -> None: out = df.groupby_dynamic( "idx", every="2i", period="3i", include_boundaries=True - ).agg(pl.col("A").list()) + ).agg(pl.col("A")) assert out.shape == (3, 4) assert out["A"].to_list() == [["A", "A", "B"], ["B", "B", "B"], ["B", "C"]] @@ -1161,7 +1161,7 @@ def test_rolling_groupby_by_argument() -> None: df = pl.DataFrame({"times": range(10), "groups": [1] * 4 + [2] * 6}) out = df.groupby_rolling("times", period="5i", by=["groups"]).agg( - pl.col("times").list().alias("agg_list") + pl.col("times").alias("agg_list") ) expected = pl.DataFrame( @@ -1549,7 +1549,7 @@ def test_duration_aggregations() -> None: 
pl.col("duration").max().alias("max"), pl.col("duration").quantile(0.1).alias("quantile"), pl.col("duration").median().alias("median"), - pl.col("duration").list().alias("list"), + pl.col("duration").alias("list"), ] ).to_dict(False) == { "group": ["A", "B"], diff --git a/py-polars/tests/unit/test_df.py b/py-polars/tests/unit/test_df.py index fa93b8d1640b..44afc14901fb 100644 --- a/py-polars/tests/unit/test_df.py +++ b/py-polars/tests/unit/test_df.py @@ -1712,7 +1712,7 @@ def __repr__(self) -> str: df = pl.DataFrame({"groups": [1, 1, 2], "a": foos}) assert sys.getrefcount(foos[0]) == base_count + 1 - out = df.groupby("groups", maintain_order=True).agg(pl.col("a").list().alias("a")) + out = df.groupby("groups", maintain_order=True).agg(pl.col("a").alias("a")) assert sys.getrefcount(foos[0]) == base_count + 2 s = out["a"].arr.explode() assert sys.getrefcount(foos[0]) == base_count + 3 diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 09da991a39a1..363ffb1b7a29 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -36,7 +36,7 @@ def test_lazy() -> None: # test if pl.list is available, this is `to_list` re-exported as list eager = df.groupby("a").agg(pl.list("b")) - assert sorted(eager.rows()) == [(1, [1.0]), (2, [2.0]), (3, [3.0])] + assert sorted(eager.rows()) == [(1, [[1.0]]), (2, [[2.0]]), (3, [[3.0]])] # profile lazyframe operation/plan lazy = df.lazy().groupby("a").agg(pl.list("b")) @@ -455,7 +455,7 @@ def test_head_groupby() -> None: out = ( df.sort(by="price", reverse=True) .groupby(keys, maintain_order=True) - .agg([col("*").exclude(keys).head(2).list().keep_name()]) + .agg([col("*").exclude(keys).head(2).keep_name()]) .explode(col("*").exclude(keys)) ) diff --git a/py-polars/tests/unit/test_lists.py b/py-polars/tests/unit/test_lists.py index e2449e090380..2f6e468063ff 100644 --- a/py-polars/tests/unit/test_lists.py +++ b/py-polars/tests/unit/test_lists.py @@ -427,7 +427,7 @@ def 
test_arr_contains_categorical() -> None: {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]} ).lazy() df = df.with_columns(pl.col("str").cast(pl.Categorical)) - df_groups = df.groupby("group").agg([pl.col("str").list().alias("str_list")]) + df_groups = df.groupby("group").agg([pl.col("str").alias("str_list")]) assert df_groups.filter(pl.col("str_list").arr.contains("C")).collect().to_dict( False ) == {"group": [2], "str_list": [["A", "C"]]} @@ -491,7 +491,7 @@ def test_groupby_list_column() -> None: pl.DataFrame({"a": ["a", "b", "a"]}) .with_columns(pl.col("a").cast(pl.Categorical)) .groupby("a", maintain_order=True) - .agg(pl.col("a").list().alias("a_list")) + .agg(pl.col("a").alias("a_list")) ) assert df.groupby("a_list", maintain_order=True).first().to_dict(False) == { diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index c431b7a29c69..371060b44597 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -12,7 +12,7 @@ def test_projection_on_semi_join_4789() -> None: ab = lfa.join(lfb, on="p", how="semi").inspect() - intermediate_agg = (ab.groupby("a").agg([pl.col("a").list().alias("seq")])).select( + intermediate_agg = (ab.groupby("a").agg([pl.col("a").alias("seq")])).select( ["a", "seq"] ) diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index a9579e28e46f..aa263ea846f5 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -247,7 +247,7 @@ def test_opaque_filter_on_lists_3784() -> None: ).lazy() df = df.with_columns(pl.col("str").cast(pl.Categorical)) - df_groups = df.groupby("group").agg([pl.col("str").list().alias("str_list")]) + df_groups = df.groupby("group").agg([pl.col("str").alias("str_list")]) pre = "A" succ = "B" diff --git a/py-polars/tests/unit/test_sort.py b/py-polars/tests/unit/test_sort.py index 17e42030eb08..97314436541a 100644 --- 
a/py-polars/tests/unit/test_sort.py +++ b/py-polars/tests/unit/test_sort.py @@ -296,7 +296,8 @@ def test_explicit_list_agg_sort_in_groupby() -> None: df = pl.DataFrame({"A": ["a", "a", "a", "b", "b", "a"], "B": [1, 2, 3, 4, 5, 6]}) assert ( df.groupby("A") - .agg(pl.col("B").list().sort(reverse=True)) + # this was col().list().sort() before we changed the logic + .agg(pl.col("B").sort(reverse=True)) .sort("A") .frame_equal(df.groupby("A").agg(pl.col("B").sort(reverse=True)).sort("A")) )