Fix incorrect results in COUNT(*) queries with LIMIT (#8049)
Co-authored-by: Mark Sirek <sirek@cockroachlabs.com>
msirek and Mark Sirek authored Nov 7, 2023
1 parent 07c08a3 commit 06fd26b
Showing 4 changed files with 232 additions and 26 deletions.
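In short: before this patch, `GlobalLimitExec::statistics` could report `num_rows` as `Precision::Exact(fetch + skip)`, counting the rows skipped by OFFSET, and it could treat inexact input estimates as if they were exact, so the planner could fold a `COUNT(*)` over a limited subquery into a wrong constant. The patch computes the emitted row count as `min(nr - skip, fetch)` and downgrades the estimate to `Precision::Inexact` whenever the input estimate is inexact. A minimal sketch of the corrected arithmetic (`limit_output_rows` is a hypothetical helper for illustration, not code from this commit):

    /// Rows emitted by a limit with `skip` (OFFSET) and optional `fetch` (LIMIT),
    /// given an input row count `nr`. Hypothetical helper, not part of this commit.
    fn limit_output_rows(nr: usize, skip: usize, fetch: Option<usize>) -> usize {
        // Skip first, then cap at the fetch count; saturating_sub covers nr <= skip.
        nr.saturating_sub(skip).min(fetch.unwrap_or(usize::MAX))
    }

    fn main() {
        // 400 input rows, skip 5, fetch 10 -> 10 rows (the pre-fix code reported 15).
        assert_eq!(limit_output_rows(400, 5, Some(10)), 10);
    }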
159 changes: 138 additions & 21 deletions datafusion/physical-plan/src/limit.rs
@@ -188,21 +188,11 @@ impl ExecutionPlan for GlobalLimitExec {
     fn statistics(&self) -> Result<Statistics> {
         let input_stats = self.input.statistics()?;
         let skip = self.skip;
-        // the maximum row number needs to be fetched
-        let max_row_num = self
-            .fetch
-            .map(|fetch| {
-                if fetch >= usize::MAX - skip {
-                    usize::MAX
-                } else {
-                    fetch + skip
-                }
-            })
-            .unwrap_or(usize::MAX);
         let col_stats = Statistics::unknown_column(&self.schema());
+        let fetch = self.fetch.unwrap_or(usize::MAX);
 
-        let fetched_row_number_stats = Statistics {
-            num_rows: Precision::Exact(max_row_num),
+        let mut fetched_row_number_stats = Statistics {
+            num_rows: Precision::Exact(fetch),
             column_statistics: col_stats.clone(),
             total_byte_size: Precision::Absent,
         };
@@ -218,23 +208,55 @@ impl ExecutionPlan for GlobalLimitExec {
             } => {
                 if nr <= skip {
                     // if all input data will be skipped, return 0
-                    Statistics {
+                    let mut skip_all_rows_stats = Statistics {
                         num_rows: Precision::Exact(0),
                         column_statistics: col_stats,
                         total_byte_size: Precision::Absent,
+                    };
+                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                        // The input stats are inexact, so the output stats must be too.
+                        skip_all_rows_stats = skip_all_rows_stats.into_inexact();
                     }
-                } else if nr <= max_row_num {
-                    // if the input does not reach the "fetch" globally, return input stats
+                    skip_all_rows_stats
+                } else if nr <= fetch && self.skip == 0 {
+                    // if the input does not reach the "fetch" globally, and "skip" is zero
+                    // (meaning the input and output are identical), return input stats.
+                    // Can input_stats still be used, but adjusted, in the "skip != 0" case?
                     input_stats
+                } else if nr - skip <= fetch {
+                    // after "skip" input rows are skipped, the remaining rows are less than or equal to the
+                    // "fetch" values, so `num_rows` must equal the remaining rows
+                    let remaining_rows: usize = nr - skip;
+                    let mut skip_some_rows_stats = Statistics {
+                        num_rows: Precision::Exact(remaining_rows),
+                        column_statistics: col_stats.clone(),
+                        total_byte_size: Precision::Absent,
+                    };
+                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                        // The input stats are inexact, so the output stats must be too.
+                        skip_some_rows_stats = skip_some_rows_stats.into_inexact();
+                    }
+                    skip_some_rows_stats
                 } else {
-                    // if the input is greater than the "fetch", the num_row will be the "fetch",
+                    // if the input is greater than "fetch+skip", the num_rows will be the "fetch",
                     // but we won't be able to predict the other statistics
+                    if !input_stats.num_rows.is_exact().unwrap_or(false)
+                        || self.fetch.is_none()
+                    {
+                        // If the input stats are inexact, the output stats must be too.
+                        // If the fetch value is `usize::MAX` because no LIMIT was specified,
+                        // we also can't represent it as an exact value.
+                        fetched_row_number_stats =
+                            fetched_row_number_stats.into_inexact();
+                    }
                     fetched_row_number_stats
                 }
             }
             _ => {
-                // the result output row number will always be no greater than the limit number
-                fetched_row_number_stats
+                // The result output `num_rows` will always be no greater than the limit number.
+                // Should `num_rows` be marked as `Absent` here when the `fetch` value is large,
+                // as the actual `num_rows` may be far away from the `fetch` value?
+                fetched_row_number_stats.into_inexact()
             }
         };
         Ok(stats)
@@ -552,7 +574,10 @@ mod tests {
     use crate::common::collect;
     use crate::{common, test};
 
+    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
     use arrow_schema::Schema;
+    use datafusion_physical_expr::expressions::col;
+    use datafusion_physical_expr::PhysicalExpr;
 
     #[tokio::test]
     async fn limit() -> Result<()> {
@@ -712,7 +737,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn skip_3_fetch_10() -> Result<()> {
+    async fn skip_3_fetch_10_stats() -> Result<()> {
         // there are total of 100 rows, we skipped 3 rows (offset = 3)
         let row_count = skip_and_fetch(3, Some(10)).await?;
         assert_eq!(row_count, 10);
@@ -748,7 +773,58 @@ mod tests {
         assert_eq!(row_count, Precision::Exact(10));
 
         let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
-        assert_eq!(row_count, Precision::Exact(15));
+        assert_eq!(row_count, Precision::Exact(10));
+
+        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
+        assert_eq!(row_count, Precision::Exact(0));
+
+        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
+        assert_eq!(row_count, Precision::Exact(1));
+
+        let row_count = row_number_statistics_for_global_limit(398, None).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count =
+            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Exact(400));
+
+        let row_count =
+            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(10));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(10));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(0));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
+        assert_eq!(row_count, Precision::Inexact(1));
+
+        let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Inexact(400));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
 
         Ok(())
     }
@@ -776,6 +852,47 @@ mod tests {
         Ok(offset.statistics()?.num_rows)
     }
 
+    pub fn build_group_by(
+        input_schema: &SchemaRef,
+        columns: Vec<String>,
+    ) -> PhysicalGroupBy {
+        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+        for column in columns.iter() {
+            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
+        }
+        PhysicalGroupBy::new_single(group_by_expr.clone())
+    }
+
+    async fn row_number_inexact_statistics_for_global_limit(
+        skip: usize,
+        fetch: Option<usize>,
+    ) -> Result<Precision<usize>> {
+        let num_partitions = 4;
+        let csv = test::scan_partitioned(num_partitions);
+
+        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);
+
+        // Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
+        let agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
+            vec![],
+            vec![None],
+            vec![None],
+            csv.clone(),
+            csv.schema().clone(),
+        )?;
+        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);
+
+        let offset = GlobalLimitExec::new(
+            Arc::new(CoalescePartitionsExec::new(agg_exec)),
+            skip,
+            fetch,
+        );
+
+        Ok(offset.statistics()?.num_rows)
+    }
+
     async fn row_number_statistics_for_local_limit(
         num_partitions: usize,
         fetch: usize,
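Taken together, the statistics branches in the first hunk above reduce to one rule: emit min(nr - skip, fetch) rows, and keep the estimate exact only when the input row count is exact. A self-contained sketch of that rule; `RowEstimate` is a simplified stand-in for `datafusion_common::stats::Precision<usize>`, not the actual type:

    /// Simplified stand-in for `Precision<usize>`; illustration only.
    #[derive(Debug, PartialEq)]
    enum RowEstimate {
        Exact(usize),
        Inexact(usize),
        Absent,
    }

    /// Distills the patched `GlobalLimitExec::statistics` row-count logic:
    /// skip rows first, cap at `fetch`, and stay exact only for exact inputs.
    fn limit_num_rows(input: RowEstimate, skip: usize, fetch: Option<usize>) -> RowEstimate {
        let fetch = fetch.unwrap_or(usize::MAX);
        match input {
            // An exact input count yields an exact output count.
            RowEstimate::Exact(nr) => RowEstimate::Exact(nr.saturating_sub(skip).min(fetch)),
            // An inexact input count can only yield an inexact output count.
            RowEstimate::Inexact(nr) => RowEstimate::Inexact(nr.saturating_sub(skip).min(fetch)),
            // With no input estimate at all, `fetch` is only an upper bound.
            RowEstimate::Absent => RowEstimate::Inexact(fetch),
        }
    }

    fn main() {
        // Skipping past the end of the 400-row input yields zero rows, as in the tests above.
        assert_eq!(limit_num_rows(RowEstimate::Exact(400), 400, Some(10)), RowEstimate::Exact(0));
        // A GROUP BY upstream makes the input count inexact, so the limit's estimate is too.
        assert_eq!(limit_num_rows(RowEstimate::Inexact(400), 398, Some(10)), RowEstimate::Inexact(2));
    }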
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/explain.slt
@@ -273,7 +273,7 @@ query TT
 EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10;
 ----
 physical_plan
-GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(10), Bytes=Absent]
+GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent]
 --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent]
 
 # Parquet scan with statistics collected
85 changes: 85 additions & 0 deletions datafusion/sqllogictest/test_files/limit.slt
@@ -294,6 +294,91 @@ query T
 SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101
 ----
 
+#
+# global limit statistics test
+#
+
+statement ok
+CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10);
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than the skip value (OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=11, fetch=3
+----TableScan: t1 projection=[], fetch=14
+physical_plan
+ProjectionExec: expr=[0 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+0
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=3
+----TableScan: t1 projection=[], fetch=11
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+2
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# an OFFSET, but no LIMIT, is specified.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=None
+----TableScan: t1 projection=[]
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+2
+
+# The aggregate needs to be computed because the input statistics are inexact.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=6, fetch=3
+----Filter: t1.a > Int32(3)
+------TableScan: t1 projection=[a]
+physical_plan
+AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)]
+--CoalescePartitionsExec
+----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)]
+------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+--------GlobalLimitExec: skip=6, fetch=3
+----------CoalesceBatchesExec: target_batch_size=8192
+------------FilterExec: a@0 > 3
+--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+1
+
 ########
 # Clean up after the test
 ########
12 changes: 8 additions & 4 deletions datafusion/sqllogictest/test_files/window.slt
@@ -2010,10 +2010,14 @@ Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1
 --------TableScan: aggregate_test_100 projection=[c13]
 physical_plan
 ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1]
---AggregateExec: mode=Single, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
-----GlobalLimitExec: skip=0, fetch=1
-------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
+--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+----CoalescePartitionsExec
+------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+----------GlobalLimitExec: skip=0, fetch=1
+------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
+--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
 
 
 query ?
 SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1)
