fix(rust, python)!: Formalize list aggregation difference between groupbys, selection and window functions #6487

Merged · 5 commits · Jan 27, 2023
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/frame/mod.rs
@@ -1305,7 +1305,7 @@ impl LazyGroupBy {
.flat_map(|k| expr_to_leaf_column_names(k).into_iter())
.collect::<Vec<_>>();

self.agg([col("*").exclude(&keys).head(n).list().keep_name()])
self.agg([col("*").exclude(&keys).head(n).keep_name()])
.explode([col("*").exclude(&keys)])
}

@@ -28,7 +28,7 @@ impl GroupByDynamicExec {
// if the periods are larger than the intervals,
// the groups overlap
if self.options.every < self.options.period {
state.flags |= StateFlags::OVERLAPPING_GROUPS
state.set_has_overlapping_groups();
}

let keys = self
@@ -68,7 +68,7 @@ impl GroupByRollingExec {
};

// a rolling groupby has overlapping windows
state.flags |= StateFlags::OVERLAPPING_GROUPS;
state.set_has_overlapping_groups();

let agg_columns = POOL.install(|| {
self.aggs
9 changes: 4 additions & 5 deletions polars/polars-lazy/src/physical_plan/executors/mod.rs
@@ -52,7 +52,6 @@ pub(super) use self::stack::*;
pub(super) use self::udf::*;
pub(super) use self::union::*;
use super::*;
use crate::physical_plan::state::StateFlags;

fn execute_projection_cached_window_fns(
df: &DataFrame,
@@ -117,9 +116,9 @@ fn execute_projection_cached_window_fns(

// don't bother caching if we only have a single window function in this partition
if partition.1.len() == 1 {
state.flags.remove(StateFlags::CACHE_WINDOW_EXPR)
state.remove_cache_window_flag();
} else {
state.flags.insert(StateFlags::CACHE_WINDOW_EXPR);
state.insert_cache_window_flag();
}

partition.1.sort_unstable_by_key(|(_idx, explode, _)| {
@@ -136,12 +135,12 @@
.count()
== 1
{
state.flags.insert(StateFlags::CACHE_WINDOW_EXPR)
state.insert_cache_window_flag();
}
// caching more than one window expression is a complicated topic for another day
// see issue #2523
else {
state.flags.remove(StateFlags::CACHE_WINDOW_EXPR)
state.remove_cache_window_flag();
}

let s = e.evaluate(df, &state)?;
27 changes: 25 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
@@ -169,7 +169,14 @@ impl PhysicalExpr for AggregationExpr {
}
GroupByMethod::List => {
let agg = ac.aggregated();
rename_series(agg, &keep_name)

if state.unset_finalize_window_as_list() {
rename_series(agg, &keep_name)
} else {
let ca = agg.list().unwrap();
let s = run_list_agg(ca);
rename_series(s, &keep_name)
}
}
GroupByMethod::Groups => {
let mut column: ListChunked = ac.groups().as_list_chunked();
@@ -430,7 +437,7 @@ impl PartitionedAggregation for AggregationExpr {
if can_fast_explode {
ca.set_fast_explode()
}
Ok(ca.into_series())
Ok(run_list_agg(&ca))
}
GroupByMethod::First => {
let mut agg = unsafe { partitioned.agg_first(groups) };
Expand Down Expand Up @@ -531,3 +538,19 @@ impl PhysicalExpr for AggQuantileExpr {
true
}
}

fn run_list_agg(ca: &ListChunked) -> Series {
assert_eq!(ca.chunks().len(), 1);
let arr = ca.chunks()[0].clone();

let offsets = (0i64..(ca.len() as i64 + 1)).collect::<Vec<_>>();
let offsets = unsafe { Offsets::new_unchecked(offsets) };

let new_arr = LargeListArray::new(
DataType::List(Box::new(ca.dtype().clone())).to_arrow(),
offsets.into(),
arr,
None,
);
unsafe { ListChunked::from_chunks(ca.name(), vec![Box::new(new_arr)]).into_series() }
}
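
To make the effect of `run_list_agg` concrete, here is a dependency-free sketch of the same idea: every row of the aggregated list column is wrapped in its own one-element outer list, so a `list[T]` column gains one nesting level and becomes `list[list[T]]`. The function name and the use of plain `Vec`s instead of Arrow arrays are illustrative only.

// Minimal sketch of the nesting step performed by `run_list_agg`, using plain
// Vecs instead of Arrow arrays. The offsets 0, 1, 2, ..., n built in the real
// function mean that outer row i spans child rows [i, i + 1): exactly one
// inner list per outer row.
fn wrap_rows_in_singleton_lists(rows: Vec<Vec<i32>>) -> Vec<Vec<Vec<i32>>> {
    rows.into_iter().map(|inner| vec![inner]).collect()
}

fn main() {
    // e.g. the per-group result of an aggregation such as `head(2)`
    let grouped = vec![vec![1, 2], vec![3], vec![4, 5]];
    let nested = wrap_rows_in_singleton_lists(grouped);
    assert_eq!(nested, vec![vec![vec![1, 2]], vec![vec![3]], vec![vec![4, 5]]]);
}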
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
@@ -124,7 +124,7 @@ impl PhysicalExpr for ApplyExpr {
if self.inputs.len() == 1 {
let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

match (state.overlapping_groups(), self.collect_groups) {
match (state.has_overlapping_groups(), self.collect_groups) {
(_, ApplyOptions::ApplyList) => {
let s = self.function.call_udf(&mut [ac.aggregated()])?;
ac.with_series(s, true);
@@ -212,7 +212,7 @@ impl PhysicalExpr for ApplyExpr {
} else {
let mut acs = self.prepare_multiple_inputs(df, groups, state)?;

match (state.overlapping_groups(), self.collect_groups) {
match (state.has_overlapping_groups(), self.collect_groups) {
(_, ApplyOptions::ApplyList) => {
let mut s = acs.iter_mut().map(|ac| ac.aggregated()).collect::<Vec<_>>();
let s = self.function.call_udf(&mut s)?;
@@ -235,7 +235,7 @@
ApplyOptions::ApplyFlat,
AggState::AggregatedFlat(_) | AggState::NotAggregated(_),
) = (
state.overlapping_groups(),
state.has_overlapping_groups(),
self.collect_groups,
acs[0].agg_state(),
) {
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
@@ -7,7 +7,7 @@ use polars_core::series::unstable::UnstableSeries;
use polars_core::POOL;
use rayon::prelude::*;

use crate::physical_plan::state::{ExecutionState, StateFlags};
use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;

pub struct BinaryExpr {
@@ -101,7 +101,7 @@ impl PhysicalExpr for BinaryExpr {
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Series> {
let mut state = state.split();
// don't cache window functions as they run in parallel
state.flags.remove(StateFlags::CACHE_WINDOW_EXPR);
state.remove_cache_window_flag();
let (lhs, rhs) = POOL.install(|| {
rayon::join(
|| self.left.evaluate(df, &state),
@@ -130,7 +130,7 @@
match (
ac_l.agg_state(),
ac_r.agg_state(),
state.overlapping_groups(),
state.has_overlapping_groups(),
) {
// Some aggregations must return boolean masks that fit the group. That's why not all literals can take this path.
// only literals that are used in arithmetic
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
@@ -6,7 +6,7 @@ use polars_core::prelude::*;
use polars_core::POOL;

use crate::physical_plan::expression_err;
use crate::physical_plan::state::{ExecutionState, StateFlags};
use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;

pub struct TernaryExpr {
@@ -93,7 +93,7 @@ impl PhysicalExpr for TernaryExpr {
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Series> {
let mut state = state.split();
// don't cache window functions as they run in parallel
state.flags.remove(StateFlags::CACHE_WINDOW_EXPR);
state.remove_cache_window_flag();
let mask_series = self.predicate.evaluate(df, &state)?;
let mut mask = mask_series.bool()?.clone();

11 changes: 11 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/window.rs
@@ -241,6 +241,7 @@ impl WindowExpr {
explicit_list = finishes_list;
}
}

explicit_list
}

@@ -424,6 +425,16 @@ impl PhysicalExpr for WindowExpr {
});
let explicit_list_agg = self.is_explicit_list_agg();

// A `sort()` in a window function is one level flatter than a sort in a groupby.
// Assume we have a column a: i32.
// A groupby sorts the groups and returns array: list[i32],
// whereas a window function returns array: i32.
// So a `sort().list()` in a groupby returns list[list[i32]],
// whereas in a window function it returns list[i32].
if explicit_list_agg {
state.set_finalize_window_as_list();
}

// if we flatten this column we need to make sure the groups are sorted.
let mut sort_groups = self.options.explode ||
// if not
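To illustrate the behaviour this flag formalizes, the sketch below contrasts an explicit list aggregation in a groupby with the same expression used as a window function. It assumes the Rust lazy API of this era (`df!`, `lazy`, `groupby`, `agg`, `Expr::list`, `Expr::over`); it is not part of the PR, and exact signatures may differ between releases, so treat it as an illustration of the intended dtypes only.

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df!["g" => ["a", "a", "b"], "x" => [3i32, 1, 2]]?;

    // In a groupby, `x` is already aggregated to one list per group, so an
    // explicit `.list()` adds another level: the column becomes list[list[i32]].
    let by_group = df
        .clone()
        .lazy()
        .groupby([col("g")])
        .agg([col("x").list()])
        .collect()?;

    // As a window function, the aggregated list is mapped back onto the
    // original rows and stays one level flatter: list[i32].
    let by_window = df
        .lazy()
        .select([col("x").list().over([col("g")])])
        .collect()?;

    println!("{by_group}\n{by_window}");
    Ok(())
}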
1 change: 1 addition & 0 deletions polars/polars-lazy/src/physical_plan/planner/lp.rs
@@ -37,6 +37,7 @@ fn partitionable_gb(
let aexpr = expr_arena.get(*agg);
let depth = (expr_arena).iter(*agg).count();

// These single expressions are partitionable
if matches!(aexpr, AExpr::Count) {
continue;
}
81 changes: 72 additions & 9 deletions polars/polars-lazy/src/physical_plan/state.rs
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Mutex, RwLock};

use bitflags::bitflags;
@@ -16,6 +17,7 @@ pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, JoinOptIds>>>;
pub type GroupsProxyCache = Arc<Mutex<PlHashMap<String, GroupsProxy>>>;

bitflags! {
#[repr(transparent)]
pub(super) struct StateFlags: u8 {
/// More verbose logging
const VERBOSE = 0x01;
@@ -25,6 +27,13 @@ bitflags! {
/// If this is the case, an `explode` will yield more values than rows in the original `df`,
/// which breaks some assumptions.
const OVERLAPPING_GROUPS = 0x04;
/// A `sort()` in a window function is one level flatter than a sort in a groupby.
/// Assume we have a column a: i32.
/// A groupby sorts the groups and returns array: list[i32],
/// whereas a window function returns array: i32.
/// So a `sort().list()` in a groupby returns list[list[i32]],
/// whereas in a window function it returns list[i32].
const FINALIZE_WINDOW_AS_LIST = 0x08;
}
}

@@ -43,6 +52,15 @@ impl StateFlags {
}
flags
}
fn as_u8(self) -> u8 {
unsafe { std::mem::transmute(self) }
}
}

impl From<u8> for StateFlags {
fn from(value: u8) -> Self {
unsafe { std::mem::transmute(value) }
}
}

/// State / cache that is maintained during the execution of the physical plan.
@@ -59,7 +77,7 @@ pub struct ExecutionState {
pub(super) join_tuples: JoinTuplesCache,
// every join/union split gets an increment to distinguish between schema state
pub(super) branch_idx: usize,
pub(super) flags: StateFlags,
pub(super) flags: AtomicU8,
pub(super) ext_contexts: Arc<Vec<DataFrame>>,
node_timer: Option<NodeTimer>,
}
@@ -101,7 +119,7 @@ impl ExecutionState {
group_tuples: Default::default(),
join_tuples: Default::default(),
branch_idx: self.branch_idx,
flags: self.flags,
flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
ext_contexts: self.ext_contexts.clone(),
node_timer: self.node_timer.clone(),
}
@@ -117,7 +135,7 @@
group_tuples: self.group_tuples.clone(),
join_tuples: self.join_tuples.clone(),
branch_idx: self.branch_idx,
flags: self.flags,
flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
ext_contexts: self.ext_contexts.clone(),
node_timer: self.node_timer.clone(),
}
@@ -137,7 +155,7 @@
group_tuples: Arc::new(Mutex::new(PlHashMap::default())),
join_tuples: Arc::new(Mutex::new(PlHashMap::default())),
branch_idx: 0,
flags: StateFlags::init(),
flags: AtomicU8::new(StateFlags::init().as_u8()),
ext_contexts: Default::default(),
node_timer: None,
}
@@ -157,7 +175,7 @@
group_tuples: Default::default(),
join_tuples: Default::default(),
branch_idx: 0,
flags: StateFlags::init(),
flags: AtomicU8::new(StateFlags::init().as_u8()),
ext_contexts: Default::default(),
node_timer: None,
}
@@ -201,21 +219,66 @@ impl ExecutionState {
lock.clear();
}

fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
let flags = f(flags);
self.flags.store(flags.as_u8(), Ordering::Relaxed);
}

/// Indicates that a window expression's [`GroupTuples`] may be cached.
pub(super) fn cache_window(&self) -> bool {
self.flags.contains(StateFlags::CACHE_WINDOW_EXPR)
let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
flags.contains(StateFlags::CACHE_WINDOW_EXPR)
}

/// Indicates that a groupby operation's groups may overlap.
/// If this is the case, an `explode` will yield more values than rows in the original `df`,
/// which breaks some assumptions.
pub(super) fn overlapping_groups(&self) -> bool {
self.flags.contains(StateFlags::OVERLAPPING_GROUPS)
pub(super) fn has_overlapping_groups(&self) -> bool {
let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
flags.contains(StateFlags::OVERLAPPING_GROUPS)
}
#[cfg(feature = "dynamic_groupby")]
pub(super) fn set_has_overlapping_groups(&self) {
self.set_flags(&|mut flags| {
flags |= StateFlags::OVERLAPPING_GROUPS;
flags
})
}

/// More verbose logging
pub(super) fn verbose(&self) -> bool {
self.flags.contains(StateFlags::VERBOSE)
let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
flags.contains(StateFlags::VERBOSE)
}

pub(super) fn set_finalize_window_as_list(&self) {
self.set_flags(&|mut flags| {
flags |= StateFlags::FINALIZE_WINDOW_AS_LIST;
flags
})
}

pub(super) fn unset_finalize_window_as_list(&self) -> bool {
let mut flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
let is_set = flags.contains(StateFlags::FINALIZE_WINDOW_AS_LIST);
flags.remove(StateFlags::FINALIZE_WINDOW_AS_LIST);
self.flags.store(flags.as_u8(), Ordering::Relaxed);
is_set
}

pub(super) fn remove_cache_window_flag(&mut self) {
self.set_flags(&|mut flags| {
flags.remove(StateFlags::CACHE_WINDOW_EXPR);
flags
});
}

pub(super) fn insert_cache_window_flag(&mut self) {
self.set_flags(&|mut flags| {
flags.insert(StateFlags::CACHE_WINDOW_EXPR);
flags
});
}
}

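The state.rs change above replaces the plain `StateFlags` field with an `AtomicU8` so the flags can be toggled through `&self` while the state is shared between threads. A standalone sketch of that pattern follows; it uses `bits()` and `from_bits_truncate` from the `bitflags` crate instead of the `unsafe` transmute helpers in the diff, and the `State` type and flag names are illustrative, not the PR's code.

use std::sync::atomic::{AtomicU8, Ordering};

use bitflags::bitflags;

bitflags! {
    struct Flags: u8 {
        const VERBOSE = 0x01;
        const CACHE_WINDOW_EXPR = 0x02;
    }
}

struct State {
    flags: AtomicU8,
}

impl State {
    fn new() -> Self {
        Self {
            flags: AtomicU8::new(Flags::empty().bits()),
        }
    }

    // Load the current bits, apply a pure update, and store the result back.
    // Like the PR, this is a plain load/store rather than a compare-and-swap,
    // which is fine as long as writers don't race on the same state split.
    fn set_flags(&self, f: impl Fn(Flags) -> Flags) {
        let flags = Flags::from_bits_truncate(self.flags.load(Ordering::Relaxed));
        self.flags.store(f(flags).bits(), Ordering::Relaxed);
    }

    fn insert_cache_window_flag(&self) {
        self.set_flags(|mut flags| {
            flags.insert(Flags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    fn cache_window(&self) -> bool {
        Flags::from_bits_truncate(self.flags.load(Ordering::Relaxed))
            .contains(Flags::CACHE_WINDOW_EXPR)
    }
}

fn main() {
    let state = State::new();
    assert!(!state.cache_window());
    state.insert_cache_window_flag();
    assert!(state.cache_window());
}

Using `from_bits_truncate` silently drops any unknown bits, which is the trade-off against the transmute-based round trip the PR adds.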