Skip to content

Commit

Permalink
Merge branch 'alamb/hash_agg_spike' of github.com:alamb/arrow-datafusion into alamb/hash_agg_spike
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 5, 2023
2 parents 917c050 + eb919a9 commit 9eb6822
Show file tree
Hide file tree
Showing 20 changed files with 471 additions and 165 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ members = [
"test-utils",
"benchmarks",
]
resolver = "2"

[workspace.package]
version = "27.0.0"
Expand Down
8 changes: 2 additions & 6 deletions datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,8 @@ impl DFSchema {
let self_fields = self.fields().iter();
let other_fields = other.fields().iter();
self_fields.zip(other_fields).all(|(f1, f2)| {
// TODO: resolve field when exist alias
// f1.qualifier() == f2.qualifier()
// && f1.name() == f2.name()
// column(t1.a) field is "t1"."a"
// column(x) as t1.a field is ""."t1.a"
f1.qualified_name() == f2.qualified_name()
f1.qualifier() == f2.qualifier()
&& f1.name() == f2.name()
&& Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type())
})
}
Expand Down
104 changes: 102 additions & 2 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ pub enum ScalarValue {
/// Months and days are encoded as 32-bit signed integers.
/// Nanoseconds is encoded as a 64-bit signed integer (no leap seconds).
IntervalMonthDayNano(Option<i128>),
/// Duration in seconds
DurationSecond(Option<i64>),
/// Duration in milliseconds
DurationMillisecond(Option<i64>),
/// Duration in microseconds
DurationMicrosecond(Option<i64>),
/// Duration in nanoseconds
DurationNanosecond(Option<i64>),
/// struct of nested ScalarValue
Struct(Option<Vec<ScalarValue>>, Fields),
/// Dictionary type: index type and value
Expand Down Expand Up @@ -210,6 +218,14 @@ impl PartialEq for ScalarValue {
(TimestampMicrosecond(_, _), _) => false,
(TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2),
(TimestampNanosecond(_, _), _) => false,
(DurationSecond(v1), DurationSecond(v2)) => v1.eq(v2),
(DurationSecond(_), _) => false,
(DurationMillisecond(v1), DurationMillisecond(v2)) => v1.eq(v2),
(DurationMillisecond(_), _) => false,
(DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.eq(v2),
(DurationMicrosecond(_), _) => false,
(DurationNanosecond(v1), DurationNanosecond(v2)) => v1.eq(v2),
(DurationNanosecond(_), _) => false,
(IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2),
(IntervalYearMonth(v1), IntervalDayTime(v2)) => {
ym_to_milli(v1).eq(&dt_to_milli(v2))
Expand Down Expand Up @@ -357,6 +373,14 @@ impl PartialOrd for ScalarValue {
mdn_to_nano(v1).partial_cmp(&dt_to_nano(v2))
}
(IntervalMonthDayNano(_), _) => None,
(DurationSecond(v1), DurationSecond(v2)) => v1.partial_cmp(v2),
(DurationSecond(_), _) => None,
(DurationMillisecond(v1), DurationMillisecond(v2)) => v1.partial_cmp(v2),
(DurationMillisecond(_), _) => None,
(DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.partial_cmp(v2),
(DurationMicrosecond(_), _) => None,
(DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2),
(DurationNanosecond(_), _) => None,
(Struct(v1, t1), Struct(v2, t2)) => {
if t1.eq(t2) {
v1.partial_cmp(v2)
Expand Down Expand Up @@ -1508,6 +1532,10 @@ impl std::hash::Hash for ScalarValue {
TimestampMillisecond(v, _) => v.hash(state),
TimestampMicrosecond(v, _) => v.hash(state),
TimestampNanosecond(v, _) => v.hash(state),
DurationSecond(v) => v.hash(state),
DurationMillisecond(v) => v.hash(state),
DurationMicrosecond(v) => v.hash(state),
DurationNanosecond(v) => v.hash(state),
IntervalYearMonth(v) => v.hash(state),
IntervalDayTime(v) => v.hash(state),
IntervalMonthDayNano(v) => v.hash(state),
Expand Down Expand Up @@ -1984,6 +2012,16 @@ impl ScalarValue {
ScalarValue::IntervalMonthDayNano(_) => {
DataType::Interval(IntervalUnit::MonthDayNano)
}
ScalarValue::DurationSecond(_) => DataType::Duration(TimeUnit::Second),
ScalarValue::DurationMillisecond(_) => {
DataType::Duration(TimeUnit::Millisecond)
}
ScalarValue::DurationMicrosecond(_) => {
DataType::Duration(TimeUnit::Microsecond)
}
ScalarValue::DurationNanosecond(_) => {
DataType::Duration(TimeUnit::Nanosecond)
}
ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()),
ScalarValue::Dictionary(k, v) => {
DataType::Dictionary(k.clone(), Box::new(v.get_datatype()))
Expand Down Expand Up @@ -2118,6 +2156,10 @@ impl ScalarValue {
ScalarValue::IntervalYearMonth(v) => v.is_none(),
ScalarValue::IntervalDayTime(v) => v.is_none(),
ScalarValue::IntervalMonthDayNano(v) => v.is_none(),
ScalarValue::DurationSecond(v) => v.is_none(),
ScalarValue::DurationMillisecond(v) => v.is_none(),
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
ScalarValue::Struct(v, _) => v.is_none(),
ScalarValue::Dictionary(_, v) => v.is_null(),
}
Expand Down Expand Up @@ -2897,6 +2939,34 @@ impl ScalarValue {
e,
size
),
ScalarValue::DurationSecond(e) => build_array_from_option!(
Duration,
TimeUnit::Second,
DurationSecondArray,
e,
size
),
ScalarValue::DurationMillisecond(e) => build_array_from_option!(
Duration,
TimeUnit::Millisecond,
DurationMillisecondArray,
e,
size
),
ScalarValue::DurationMicrosecond(e) => build_array_from_option!(
Duration,
TimeUnit::Microsecond,
DurationMicrosecondArray,
e,
size
),
ScalarValue::DurationNanosecond(e) => build_array_from_option!(
Duration,
TimeUnit::Nanosecond,
DurationNanosecondArray,
e,
size
),
ScalarValue::Struct(values, fields) => match values {
Some(values) => {
let field_values: Vec<_> = fields
Expand Down Expand Up @@ -3264,6 +3334,18 @@ impl ScalarValue {
ScalarValue::IntervalMonthDayNano(val) => {
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
}
ScalarValue::DurationSecond(val) => {
eq_array_primitive!(array, index, DurationSecondArray, val)
}
ScalarValue::DurationMillisecond(val) => {
eq_array_primitive!(array, index, DurationMillisecondArray, val)
}
ScalarValue::DurationMicrosecond(val) => {
eq_array_primitive!(array, index, DurationMicrosecondArray, val)
}
ScalarValue::DurationNanosecond(val) => {
eq_array_primitive!(array, index, DurationNanosecondArray, val)
}
ScalarValue::Struct(_, _) => unimplemented!(),
ScalarValue::Dictionary(key_type, v) => {
let (values_array, values_index) = match key_type.as_ref() {
Expand Down Expand Up @@ -3313,7 +3395,11 @@ impl ScalarValue {
| ScalarValue::Time64Nanosecond(_)
| ScalarValue::IntervalYearMonth(_)
| ScalarValue::IntervalDayTime(_)
| ScalarValue::IntervalMonthDayNano(_) => 0,
| ScalarValue::IntervalMonthDayNano(_)
| ScalarValue::DurationSecond(_)
| ScalarValue::DurationMillisecond(_)
| ScalarValue::DurationMicrosecond(_)
| ScalarValue::DurationNanosecond(_) => 0,
ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => {
s.as_ref().map(|s| s.capacity()).unwrap_or_default()
}
Expand Down Expand Up @@ -3699,6 +3785,10 @@ impl fmt::Display for ScalarValue {
ScalarValue::IntervalDayTime(e) => format_option!(f, e)?,
ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?,
ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?,
ScalarValue::DurationSecond(e) => format_option!(f, e)?,
ScalarValue::DurationMillisecond(e) => format_option!(f, e)?,
ScalarValue::DurationMicrosecond(e) => format_option!(f, e)?,
ScalarValue::DurationNanosecond(e) => format_option!(f, e)?,
ScalarValue::Struct(e, fields) => match e {
Some(l) => write!(
f,
Expand Down Expand Up @@ -3781,6 +3871,16 @@ impl fmt::Debug for ScalarValue {
ScalarValue::IntervalMonthDayNano(_) => {
write!(f, "IntervalMonthDayNano(\"{self}\")")
}
ScalarValue::DurationSecond(_) => write!(f, "DurationSecond(\"{self}\")"),
ScalarValue::DurationMillisecond(_) => {
write!(f, "DurationMillisecond(\"{self}\")")
}
ScalarValue::DurationMicrosecond(_) => {
write!(f, "DurationMicrosecond(\"{self}\")")
}
ScalarValue::DurationNanosecond(_) => {
write!(f, "DurationNanosecond(\"{self}\")")
}
ScalarValue::Struct(e, fields) => {
// Use Debug representation of field values
match e {
Expand All @@ -3802,7 +3902,7 @@ impl fmt::Debug for ScalarValue {
}
}

/// Trait used to map a NativeTime to a ScalarType.
/// Trait used to map a NativeType to a ScalarValue
pub trait ScalarType<T: ArrowNativeType> {
/// returns a scalar from an optional T
fn scalar(r: Option<T>) -> ScalarValue;
Expand Down
39 changes: 22 additions & 17 deletions datafusion/core/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ use async_trait::async_trait;
use datafusion_common::SchemaExt;
use datafusion_execution::TaskContext;
use tokio::sync::RwLock;
use tokio::task::JoinSet;

use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::insert::{DataSink, InsertExec};
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::{common, SendableRecordBatchStream};
Expand Down Expand Up @@ -89,26 +89,31 @@ impl MemTable {
let exec = t.scan(state, None, &[], None).await?;
let partition_count = exec.output_partitioning().partition_count();

let tasks = (0..partition_count)
.map(|part_i| {
let task = state.task_ctx();
let exec = exec.clone();
let task = tokio::spawn(async move {
let stream = exec.execute(part_i, task)?;
common::collect(stream).await
});

AbortOnDropSingle::new(task)
})
// this collect *is needed* so that the join below can
// switch between tasks
.collect::<Vec<_>>();
let mut join_set = JoinSet::new();

for part_idx in 0..partition_count {
let task = state.task_ctx();
let exec = exec.clone();
join_set.spawn(async move {
let stream = exec.execute(part_idx, task)?;
common::collect(stream).await
});
}

let mut data: Vec<Vec<RecordBatch>> =
Vec::with_capacity(exec.output_partitioning().partition_count());

for result in futures::future::join_all(tasks).await {
data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??)
while let Some(result) = join_set.join_next().await {
match result {
Ok(res) => data.push(res?),
Err(e) => {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
} else {
unreachable!();
}
}
}
}

let exec = MemoryExec::try_new(&data, schema.clone(), None)?;
Expand Down
32 changes: 19 additions & 13 deletions datafusion/core/src/datasource/physical_plan/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use crate::datasource::physical_plan::file_stream::{
};
use crate::datasource::physical_plan::FileMeta;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::physical_plan::{
Expand All @@ -46,7 +45,7 @@ use std::fs;
use std::path::Path;
use std::sync::Arc;
use std::task::Poll;
use tokio::task::{self, JoinHandle};
use tokio::task::JoinSet;

/// Execution plan for scanning a CSV file
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -331,7 +330,7 @@ pub async fn plan_to_csv(
)));
}

let mut tasks = vec![];
let mut join_set = JoinSet::new();
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.csv");
Expand All @@ -340,22 +339,29 @@ pub async fn plan_to_csv(
let mut writer = csv::Writer::new(file);
let stream = plan.execute(i, task_ctx.clone())?;

let handle: JoinHandle<Result<()>> = task::spawn(async move {
stream
join_set.spawn(async move {
let result: Result<()> = stream
.map(|batch| writer.write(&batch?))
.try_collect()
.await
.map_err(DataFusionError::from)
.map_err(DataFusionError::from);
result
});
tasks.push(AbortOnDropSingle::new(handle));
}

futures::future::join_all(tasks)
.await
.into_iter()
.try_for_each(|result| {
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
})?;
while let Some(result) = join_set.join_next().await {
match result {
Ok(res) => res?, // propagate DataFusion error
Err(e) => {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
} else {
unreachable!();
}
}
}
}

Ok(())
}

Expand Down
Loading

0 comments on commit 9eb6822

Please sign in to comment.