Skip to content

Commit

Permalink
[FLINK-21225][table-planner-blink] Support OVER window distinct aggregates in Table API
Browse files Browse the repository at this point in the history

This closes apache#14877
  • Loading branch information
LadyForest committed Feb 10, 2021
1 parent ced0fdd commit 31346e8
Show file tree
Hide file tree
Showing 4 changed files with 351 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public Optional<RexNode> convert(CallExpression call, ConvertContext context) {
if (call.getFunctionDefinition() == BuiltInFunctionDefinitions.OVER) {
FlinkTypeFactory typeFactory = context.getTypeFactory();
Expression agg = children.get(0);
FunctionDefinition def = ((CallExpression) agg).getFunctionDefinition();
boolean isDistinct = BuiltInFunctionDefinitions.DISTINCT == def;

SqlAggFunction aggFunc = agg.accept(new SqlAggFunctionVisitor(context.getRelBuilder()));
RelDataType aggResultType =
typeFactory.createFieldTypeFromLogicalType(
Expand All @@ -78,7 +81,16 @@ public Optional<RexNode> convert(CallExpression call, ConvertContext context) {

// assemble exprs by agg children
List<RexNode> aggExprs =
agg.getChildren().stream().map(context::toRexNode).collect(Collectors.toList());
agg.getChildren().stream()
.map(
child -> {
if (isDistinct) {
return context.toRexNode(child.getChildren().get(0));
} else {
return context.toRexNode(child);
}
})
.collect(Collectors.toList());

// assemble order by key
Expression orderKeyExpr = children.get(1);
Expand Down Expand Up @@ -123,7 +135,7 @@ public Optional<RexNode> convert(CallExpression call, ConvertContext context) {
isPhysical,
true,
false,
false));
isDistinct));
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,40 @@ Calc(select=[c, w0$o0 AS _c1, w0$o1 AS _c2])
+- Exchange(distribution=[hash[c]])
+- Calc(select=[a, c, proctime, CAST(a) AS $3])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
<TestCase name="testRowTimeBoundedDistinctWithPartitionedRangeOver">
<Resource name="ast">
<![CDATA[
LogicalProject(c=[$2], _c1=[AS(COUNT(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST RANGE 7200000 PRECEDING), _UTF-16LE'_c1')], _c2=[AS(SUM(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST RANGE 7200000 PRECEDING), _UTF-16LE'_c2')], _c3=[AS(AVG(DISTINCT AS(CAST($0):FLOAT, _UTF-16LE'a')) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST RANGE 7200000 PRECEDING), _UTF-16LE'_c3')])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[c, w0$o0 AS _c1, w0$o1 AS _c2, w0$o2 AS _c3])
+- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ RANG BETWEEN 7200000 PRECEDING AND CURRENT ROW], select=[a, c, rowtime, $3, COUNT(DISTINCT a) AS w0$o0, SUM(DISTINCT a) AS w0$o1, AVG(DISTINCT $3) AS w0$o2])
+- Exchange(distribution=[hash[c]])
+- Calc(select=[a, c, rowtime, CAST(a) AS $3])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
<TestCase name="testRowTimeBoundedDistinctWithPartitionedRowsOver">
<Resource name="ast">
<![CDATA[
LogicalProject(c=[$2], _c1=[AS(COUNT(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS 2 PRECEDING), _UTF-16LE'_c1')], _c2=[AS(SUM(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS 2 PRECEDING), _UTF-16LE'_c2')], _c3=[AS(AVG(DISTINCT AS(CAST($0):FLOAT, _UTF-16LE'a')) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS 2 PRECEDING), _UTF-16LE'_c3')])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[c, w0$o0 AS _c1, w0$o1 AS _c2, w0$o2 AS _c3])
+- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ ROWS BETWEEN 2 PRECEDING AND CURRENT ROW], select=[a, c, rowtime, $3, COUNT(DISTINCT a) AS w0$o0, SUM(DISTINCT a) AS w0$o1, AVG(DISTINCT $3) AS w0$o2])
+- Exchange(distribution=[hash[c]])
+- Calc(select=[a, c, rowtime, CAST(a) AS $3])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
Expand Down Expand Up @@ -217,6 +251,40 @@ Calc(select=[c, w0$o0 AS _c1, w0$o1 AS wAvg])
+- Exchange(distribution=[hash[b]])
+- Calc(select=[b, c, rowtime, CAST(a) AS $3])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
<TestCase name="testRowTimeUnboundedDistinctWithPartitionedRangeOver">
<Resource name="ast">
<![CDATA[
LogicalProject(c=[$2], _c1=[AS(COUNT(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST), _UTF-16LE'_c1')], _c2=[AS(SUM(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST), _UTF-16LE'_c2')], _c3=[AS(AVG(DISTINCT AS(CAST($0):FLOAT, _UTF-16LE'a')) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST), _UTF-16LE'_c3')])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[c, w0$o0 AS _c1, w0$o1 AS _c2, w0$o2 AS _c3])
+- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, $3, COUNT(DISTINCT a) AS w0$o0, SUM(DISTINCT a) AS w0$o1, AVG(DISTINCT $3) AS w0$o2])
+- Exchange(distribution=[hash[c]])
+- Calc(select=[a, c, rowtime, CAST(a) AS $3])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
<TestCase name="testRowTimeUnboundedDistinctWithPartitionedRowsOver">
<Resource name="ast">
<![CDATA[
LogicalProject(c=[$2], _c1=[AS(COUNT(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), _UTF-16LE'_c1')], _c2=[AS(SUM(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), _UTF-16LE'_c2')], _c3=[AS(AVG(DISTINCT AS(CAST($0):FLOAT, _UTF-16LE'a')) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), _UTF-16LE'_c3')])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[c, w0$o0 AS _c1, w0$o1 AS _c2, w0$o2 AS _c3])
+- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, $3, COUNT(DISTINCT a) AS w0$o0, SUM(DISTINCT a) AS w0$o1, AVG(DISTINCT $3) AS w0$o2])
+- Exchange(distribution=[hash[c]])
+- Calc(select=[a, c, rowtime, CAST(a) AS $3])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,55 @@ class OverAggregateTest extends TableTestBase {
streamUtil.verifyExecPlan(result)
}

@Test
def testRowTimeBoundedDistinctWithPartitionedRangeOver(): Unit = {
  // Distinct aggregates over a row-time, partitioned, bounded RANGE window
  // (2 hours preceding up to the current range).
  val window = Over
    .partitionBy('c)
    .orderBy('rowtime)
    .preceding(2.hours)
    .following(CURRENT_RANGE)
    .as('w)

  // COUNT/SUM DISTINCT on 'a, plus AVG DISTINCT on 'a cast to FLOAT.
  val projected = table
    .window(window)
    .select(
      'c,
      'a.count.distinct over 'w,
      'a.sum.distinct over 'w,
      ('a.cast(DataTypes.FLOAT) as 'a).avg.distinct over 'w)

  streamUtil.verifyExecPlan(projected)
}

@Test
def testRowTimeUnboundedDistinctWithPartitionedRangeOver(): Unit = {
  // Distinct aggregates over a row-time, partitioned, UNBOUNDED RANGE window.
  val window = Over
    .partitionBy('c)
    .orderBy('rowtime)
    .preceding(UNBOUNDED_RANGE)
    .as('w)

  // COUNT/SUM DISTINCT on 'a, plus AVG DISTINCT on 'a cast to FLOAT.
  val projected = table
    .window(window)
    .select(
      'c,
      'a.count.distinct over 'w,
      'a.sum.distinct over 'w,
      ('a.cast(DataTypes.FLOAT) as 'a).avg.distinct over 'w)

  streamUtil.verifyExecPlan(projected)
}

@Test
def testRowTimeBoundedDistinctWithPartitionedRowsOver(): Unit = {
  // Distinct aggregates over a row-time, partitioned, bounded ROWS window
  // (2 rows preceding up to the current row).
  val window = Over
    .partitionBy('c)
    .orderBy('rowtime)
    .preceding(2.rows)
    .following(CURRENT_ROW)
    .as('w)

  // COUNT/SUM DISTINCT on 'a, plus AVG DISTINCT on 'a cast to FLOAT.
  val projected = table
    .window(window)
    .select(
      'c,
      'a.count.distinct over 'w,
      'a.sum.distinct over 'w,
      ('a.cast(DataTypes.FLOAT) as 'a).avg.distinct over 'w)

  streamUtil.verifyExecPlan(projected)
}

@Test
def testRowTimeUnboundedDistinctWithPartitionedRowsOver(): Unit = {
  // Distinct aggregates over a row-time, partitioned, UNBOUNDED ROWS window
  // (unbounded preceding up to the current row).
  val window = Over
    .partitionBy('c)
    .orderBy('rowtime)
    .preceding(UNBOUNDED_ROW)
    .following(CURRENT_ROW)
    .as('w)

  // COUNT/SUM DISTINCT on 'a, plus AVG DISTINCT on 'a cast to FLOAT.
  val projected = table
    .window(window)
    .select(
      'c,
      'a.count.distinct over 'w,
      'a.sum.distinct over 'w,
      ('a.cast(DataTypes.FLOAT) as 'a).avg.distinct over 'w)

  streamUtil.verifyExecPlan(projected)
}

@Test
def testRowTimeUnboundedPartitionedRowsOver(): Unit = {
val weightedAvg = new WeightedAvgWithRetract
Expand Down
Loading

0 comments on commit 31346e8

Please sign in to comment.