feat(rust, python): out of core sort on multiple columns (#7244)
ritchie46 authored Feb 28, 2023
1 parent 5941afd commit 86f8be0
Showing 31 changed files with 400 additions and 160 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/mod.rs
@@ -40,7 +40,7 @@ mod reverse;
pub(crate) mod rolling_window;
mod set;
mod shift;
pub(crate) mod sort;
pub mod sort;
pub(crate) mod take;
pub(crate) mod unique;
#[cfg(feature = "zip_with")]
66 changes: 38 additions & 28 deletions polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
@@ -1,4 +1,5 @@
use polars_arrow::data_types::IsFloat;
use polars_row::{convert_columns, RowsEncoded, SortField};
use polars_utils::iter::EnumerateIdxTrait;

use super::*;
@@ -67,45 +68,54 @@ pub(crate) fn arg_sort_multiple_impl<T: PartialOrd + Send + IsFloat + Copy>(
Ok(ca.into_inner())
}

pub(crate) fn argsort_multiple_row_fmt(
pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult<ArrayRef> {
let by = convert_sort_column_multi_sort(by, true)?;
let by = by.rechunk();

let out = match by.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => {
let ca = by.categorical().unwrap();
if ca.use_lexical_sort() {
by.to_arrow(0)
} else {
ca.logical().chunks[0].clone()
}
}
_ => by.to_arrow(0),
};
Ok(out)
}

pub fn _get_rows_encoded(
by: &[Series],
mut descending: Vec<bool>,
descending: &[bool],
nulls_last: bool,
parallel: bool,
) -> PolarsResult<IdxCa> {
use polars_row::{convert_columns, SortField};
broadcast_descending(by.len(), &mut descending);

) -> PolarsResult<RowsEncoded> {
debug_assert_eq!(by.len(), descending.len());
let mut cols = Vec::with_capacity(by.len());
let mut fields = Vec::with_capacity(by.len());

debug_assert_eq!(by.len(), descending.len());
for (by, descending) in by.iter().zip(descending) {
let by = convert_sort_column_multi_sort(by, true)?;
let by = by.rechunk();

let arr = match by.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => {
let ca = by.categorical().unwrap();
if ca.use_lexical_sort() {
by.to_arrow(0)
} else {
ca.logical().chunks[0].clone()
}
}
_ => by.to_arrow(0),
};
let data_type = arr.data_type().clone();
let arr = _get_rows_encoded_compat_array(by)?;

cols.push(arr);
fields.push(SortField {
descending,
descending: *descending,
nulls_last,
data_type,
})
}
let rows_encoded = convert_columns(&cols, fields);
Ok(convert_columns(&cols, &fields))
}

pub(crate) fn argsort_multiple_row_fmt(
by: &[Series],
mut descending: Vec<bool>,
nulls_last: bool,
parallel: bool,
) -> PolarsResult<IdxCa> {
_broadcast_descending(by.len(), &mut descending);

let rows_encoded = _get_rows_encoded(by, &descending, nulls_last)?;
let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect();

if parallel {
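
The refactor above splits the old argsort_multiple_row_fmt into _get_rows_encoded_compat_array and _get_rows_encoded, so the row-encoded sort key can be built outside of the arg-sort path (the streaming engine pulls in polars-row further down). The idea behind row encoding is that all sort columns are packed into one order-preserving byte key per row, so comparing keys lexicographically is equivalent to comparing the rows column by column. Below is a minimal standalone sketch of that idea — fixed-width i32 columns, no nulls, and a simplified key layout; the helper names are illustrative and this is not the real polars-row format:

// Illustrative sketch of order-preserving row encoding for a multi-column sort.
// The byte layout is simplified (fixed-width i32 columns, no nulls) and is not
// the actual `polars-row` format.
fn encode_i32(v: i32, descending: bool) -> [u8; 4] {
    // Flip the sign bit so big-endian byte order matches numeric order.
    let mut b = ((v as u32) ^ 0x8000_0000).to_be_bytes();
    if descending {
        // Inverting the bytes reverses the ordering for this column.
        for x in &mut b {
            *x = !*x;
        }
    }
    b
}

fn arg_sort_multiple(columns: &[Vec<i32>], descending: &[bool]) -> Vec<usize> {
    let height = columns[0].len();
    // Build one concatenated key per row; lexicographic byte order of the key
    // equals the multi-column sort order.
    let mut keys: Vec<(Vec<u8>, usize)> = (0..height)
        .map(|row| {
            let mut key = Vec::with_capacity(columns.len() * 4);
            for (col, &desc) in columns.iter().zip(descending) {
                key.extend_from_slice(&encode_i32(col[row], desc));
            }
            (key, row)
        })
        .collect();
    keys.sort_unstable();
    keys.into_iter().map(|(_, row)| row).collect()
}

fn main() {
    let a = vec![1, 1, 2, 2];
    let b = vec![3, 4, 1, 2];
    // Sort by `a` ascending, then `b` descending.
    let idx = arg_sort_multiple(&[a, b], &[false, true]);
    assert_eq!(idx, vec![1, 0, 3, 2]);
    println!("{idx:?}");
}

Because the whole sort key lives in a single binary column, the same machinery works whether the sort runs in memory or spills to disk.
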
6 changes: 3 additions & 3 deletions polars/polars-core/src/chunked_array/ops/sort/mod.rs
@@ -1,6 +1,6 @@
mod arg_sort;
#[cfg(feature = "sort_multiple")]
mod arg_sort_multiple;
pub mod arg_sort_multiple;
#[cfg(feature = "dtype-categorical")]
mod categorical;

@@ -742,7 +742,7 @@ pub(crate) fn convert_sort_column_multi_sort(
Ok(out)
}

pub(super) fn broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
pub fn _broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
if n_cols > descending.len() && descending.len() == 1 {
while n_cols != descending.len() {
descending.push(descending[0]);
@@ -765,7 +765,7 @@ pub(crate) fn prepare_arg_sort(
let first = columns.remove(0);

// broadcast ordering
broadcast_descending(n_cols, &mut descending);
_broadcast_descending(n_cols, &mut descending);
Ok((first, columns, descending))
}

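
broadcast_descending becomes the public _broadcast_descending so the pipe engine can reuse it. Its behaviour is unchanged: a single descending flag is repeated until there is one flag per sort column, and any other length is left alone. A small usage sketch, with the function body copied from the hunk above:

// Behaviour of `_broadcast_descending` as shown in the hunk above: a single
// flag is broadcast to all sort columns; anything else is left untouched.
pub fn _broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
    if n_cols > descending.len() && descending.len() == 1 {
        while n_cols != descending.len() {
            descending.push(descending[0]);
        }
    }
}

fn main() {
    let mut descending = vec![true];
    _broadcast_descending(3, &mut descending);
    assert_eq!(descending, vec![true, true, true]);

    // Already one flag per column: nothing changes.
    let mut descending = vec![false, true];
    _broadcast_descending(2, &mut descending);
    assert_eq!(descending, vec![false, true]);
}
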
10 changes: 10 additions & 0 deletions polars/polars-core/src/frame/mod.rs
@@ -1160,6 +1160,16 @@ impl DataFrame {
inner(self, series)
}

/// Adds a column to the `DataFrame` without doing any checks
/// on length or duplicates.
///
/// # Safety
/// The caller must ensure `column.len() == self.height()` .
pub unsafe fn with_column_unchecked(&mut self, column: Series) -> &mut Self {
self.get_columns_mut().push(column);
self
}

fn add_column_by_schema(&mut self, s: Series, schema: &Schema) -> PolarsResult<()> {
let name = s.name();
if let Some((idx, _, _)) = schema.get_full(name) {
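
The new with_column_unchecked pushes a Series straight into the frame, skipping the length and duplicate-name checks that with_column performs; that is what makes it unsafe. A hedged usage sketch (it assumes polars-core as a dependency and the data is made up):

use polars_core::prelude::*;

fn main() -> PolarsResult<()> {
    let mut df = DataFrame::new(vec![Series::new("a", &[1i32, 2, 3])])?;
    let extra = Series::new("b", &[10i32, 20, 30]);

    // SAFETY: `extra.len() == df.height()` and the name "b" is not already
    // taken, so skipping the checks done by `with_column` is sound here.
    unsafe {
        df.with_column_unchecked(extra);
    }
    assert_eq!(df.shape(), (3, 2));
    Ok(())
}
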
1 change: 1 addition & 0 deletions polars/polars-lazy/polars-pipe/Cargo.toml
@@ -18,6 +18,7 @@ polars-core = { version = "0.27.2", path = "../../polars-core", features = ["laz
polars-io = { version = "0.27.2", path = "../../polars-io", default-features = false, features = ["ipc"] }
polars-ops = { version = "0.27.2", path = "../../polars-ops", features = ["search_sorted"] }
polars-plan = { version = "0.27.2", path = "../polars-plan", default-features = false, features = ["compile"] }
polars-row = { version = "0.27.2", path = "../../polars-row" }
polars-utils = { version = "0.27.2", path = "../../polars-utils", features = ["sysinfo"] }
rayon.workspace = true
smartstring = { version = "1" }
@@ -180,7 +180,7 @@ impl Sink for FilesSink {
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, _other: Box<dyn Sink>) {
fn combine(&mut self, _other: &mut dyn Sink) {
// already synchronized
}

@@ -309,7 +309,7 @@ impl Sink for GenericGroupbySink {
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, mut other: Box<dyn Sink>) {
fn combine(&mut self, other: &mut dyn Sink) {
// don't parallel this as this is already done in parallel.

let other = other.as_any().downcast_ref::<Self>().unwrap();
@@ -293,7 +293,7 @@ where
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, mut other: Box<dyn Sink>) {
fn combine(&mut self, other: &mut dyn Sink) {
// don't parallel this as this is already done in parallel.
let other = other.as_any().downcast_ref::<Self>().unwrap();

@@ -252,7 +252,7 @@ impl Sink for Utf8GroupbySink {
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, mut other: Box<dyn Sink>) {
fn combine(&mut self, other: &mut dyn Sink) {
// don't parallel this as this is already done in parallel.

let other = other.as_any().downcast_ref::<Self>().unwrap();
@@ -34,7 +34,7 @@ impl Sink for CrossJoin {
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, mut other: Box<dyn Sink>) {
fn combine(&mut self, other: &mut dyn Sink) {
let other = other.as_any().downcast_mut::<Self>().unwrap();
let other_chunks = std::mem::take(&mut other.chunks);
self.chunks.extend(other_chunks.into_iter());
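
Every sink's combine now takes &mut dyn Sink instead of Box<dyn Sink>: the caller keeps ownership, and the sink borrows the other trait object, downcasts it to its own concrete type, and moves data out with std::mem::take. A minimal standalone illustration of that pattern follows — the trait and the ChunkSink type are cut down stand-ins for this sketch, not the real Sink trait:

use std::any::Any;

// Reduced stand-in for the real `Sink` trait: just enough to show the
// `&mut dyn Sink` + downcast pattern used by `combine` in this commit.
trait Sink {
    fn combine(&mut self, other: &mut dyn Sink);
    fn as_any(&mut self) -> &mut dyn Any;
}

struct ChunkSink {
    chunks: Vec<Vec<u8>>,
}

impl Sink for ChunkSink {
    fn combine(&mut self, other: &mut dyn Sink) {
        // Borrow the other sink as its concrete type and move its chunks over,
        // leaving an empty Vec behind instead of consuming a Box.
        let other = other.as_any().downcast_mut::<Self>().unwrap();
        let other_chunks = std::mem::take(&mut other.chunks);
        self.chunks.extend(other_chunks.into_iter());
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

fn main() {
    let mut a = ChunkSink { chunks: vec![vec![1, 2]] };
    let mut b = ChunkSink { chunks: vec![vec![3]] };
    a.combine(&mut b);
    assert_eq!(a.chunks.len(), 2);
    assert!(b.chunks.is_empty());
}
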
@@ -244,7 +244,7 @@ impl Sink for GenericBuild {
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, mut other: Box<dyn Sink>) {
fn combine(&mut self, other: &mut dyn Sink) {
if self.is_empty() {
let other = other.as_any().downcast_mut::<Self>().unwrap();
if !other.is_empty() {
@@ -31,7 +31,7 @@ impl Sink for OrderedSink {
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, mut other: Box<dyn Sink>) {
fn combine(&mut self, other: &mut dyn Sink) {
let other = other.as_any().downcast_ref::<OrderedSink>().unwrap();
self.chunks.extend_from_slice(&other.chunks);
self.sort();
@@ -75,7 +75,7 @@ impl Sink for SliceSink {
}
}

fn combine(&mut self, _other: Box<dyn Sink>) {
fn combine(&mut self, _other: &mut dyn Sink) {
// no-op
}

@@ -1,6 +1,8 @@
mod io;
mod ooc;
mod sink;
mod sink_multiple;
mod source;

pub(crate) use sink::SortSink;
pub(crate) use sink_multiple::SortSinkMultiple;
42 changes: 25 additions & 17 deletions polars/polars-lazy/polars-pipe/src/executors/sinks/sort/sink.rs
@@ -5,8 +5,9 @@ use std::sync::{Arc, Mutex};

use polars_core::error::PolarsResult;
use polars_core::frame::DataFrame;
use polars_core::prelude::{AnyValue, SchemaRef, Series};
use polars_core::prelude::{AnyValue, SchemaRef, Series, SortOptions};
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_plan::prelude::SortArguments;
use polars_utils::atomic::SyncCounter;
use polars_utils::sys::MEMINFO;

@@ -26,19 +27,13 @@ pub struct SortSink {
io_thread: Arc<Mutex<Option<IOThread>>>,
// location in the dataframe of the columns to sort by
sort_idx: usize,
descending: bool,
slice: Option<(i64, usize)>,
sort_args: SortArguments,
// sampled values so we can find the distribution.
dist_sample: Vec<AnyValue<'static>>,
}

impl SortSink {
pub(crate) fn new(
sort_idx: usize,
descending: bool,
schema: SchemaRef,
slice: Option<(i64, usize)>,
) -> Self {
pub(crate) fn new(sort_idx: usize, sort_args: SortArguments, schema: SchemaRef) -> Self {
// for testing purposes
let ooc = std::env::var("POLARS_FORCE_OOC_SORT").is_ok();

@@ -50,8 +45,7 @@ impl SortSink {
ooc,
io_thread: Default::default(),
sort_idx,
descending,
slice,
sort_args,
dist_sample: vec![],
};
if ooc {
@@ -129,7 +123,7 @@ impl Sink for SortSink {
Ok(SinkResult::CanHaveMoreInput)
}

fn combine(&mut self, mut other: Box<dyn Sink>) {
fn combine(&mut self, other: &mut dyn Sink) {
let other = other.as_any().downcast_mut::<Self>().unwrap();
self.chunks.extend(std::mem::take(&mut other.chunks));
self.ooc |= other.ooc;
@@ -150,9 +144,8 @@ impl Sink for SortSink {
ooc: self.ooc,
io_thread: self.io_thread.clone(),
sort_idx: self.sort_idx,
descending: self.descending,
sort_args: self.sort_args.clone(),
dist_sample: vec![],
slice: self.slice,
})
}

@@ -168,15 +161,30 @@ impl Sink for SortSink {
let io_thread = lock.as_ref().unwrap();

let dist = Series::from_any_values("", &self.dist_sample).unwrap();
let dist = dist.sort(self.descending);
let dist = dist.sort_with(SortOptions {
descending: self.sort_args.descending[0],
nulls_last: self.sort_args.nulls_last,
multithreaded: true,
});

block_thread_until_io_thread_done(io_thread);

sort_ooc(io_thread, dist, self.sort_idx, self.descending, self.slice)
sort_ooc(
io_thread,
dist,
self.sort_idx,
self.sort_args.descending[0],
self.sort_args.slice,
)
} else {
let chunks = std::mem::take(&mut self.chunks);
let df = accumulate_dataframes_vertical_unchecked(chunks);
let df = sort_accumulated(df, self.sort_idx, self.descending, self.slice)?;
let df = sort_accumulated(
df,
self.sort_idx,
self.sort_args.descending[0],
self.sort_args.slice,
)?;
Ok(FinalizedSink::Finished(df))
}
}
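
When the sink goes out of core, dist_sample holds sampled sort-key values; finalize sorts that sample and passes it to sort_ooc as an approximate distribution of the data. The sketch below shows the partition-by-sample idea in isolation — plain i64 keys instead of AnyValues and no IO thread, so it illustrates the principle rather than the actual sort_ooc implementation, and the function name is made up:

// Simplified version of the partition-by-sample step: a sorted sample of keys
// becomes the partition boundaries, every value is routed to a bucket by
// binary search, buckets are sorted independently and concatenated in order.
fn sort_out_of_core(values: &[i64], mut sample: Vec<i64>) -> Vec<i64> {
    sample.sort_unstable();

    let mut buckets: Vec<Vec<i64>> = vec![Vec::new(); sample.len() + 1];
    for &v in values {
        // `partition_point` gives the first boundary greater than `v`,
        // i.e. the index of the bucket this value belongs to.
        let bucket = sample.partition_point(|&b| b <= v);
        buckets[bucket].push(v);
    }

    // In the real sink each bucket would be spilled to disk and sorted on its
    // own; here we just sort in memory and stitch the buckets back together.
    let mut out = Vec::with_capacity(values.len());
    for mut bucket in buckets {
        bucket.sort_unstable();
        out.extend(bucket);
    }
    out
}

fn main() {
    let values = vec![5, 1, 9, 3, 7, 2, 8];
    let sample = vec![3, 7];
    assert_eq!(sort_out_of_core(&values, sample), vec![1, 2, 3, 5, 7, 8, 9]);
}
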
(Diff truncated; the remaining changed files are not shown.)
