Skip to content

Commit

Permalink
add union all push down
Browse files Browse the repository at this point in the history
  • Loading branch information
junwen12221 committed Nov 29, 2021
1 parent a5fc417 commit d447de6
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 10 deletions.
54 changes: 49 additions & 5 deletions hbt/src/main/java/io/mycat/calcite/localrel/LocalRules.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.calcite.rel.core.*;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalTableScan;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.ToLogicalConverter;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
Expand All @@ -23,6 +24,8 @@
import java.util.Optional;
import java.util.function.Consumer;

import static io.mycat.calcite.rewriter.SQLRBORewriter.view;

public class LocalRules {
public static final RelBuilderFactory LOCAL_BUILDER =
RelBuilder.proto(
Expand All @@ -42,7 +45,8 @@ public class LocalRules {
LocalRules.AggViewRule.DEFAULT_CONFIG.toRule(),
LocalRules.SortViewRule.DEFAULT_CONFIG.toRule(),
LocalRules.JoinViewRule.DEFAULT_CONFIG.toRule(),
LocalRules.CalcViewRule.DEFAULT_CONFIG.toRule()
LocalRules.CalcViewRule.DEFAULT_CONFIG.toRule(),
LocalRules.UnionAllViewRule.DEFAULT_CONFIG.toRule()
);

public static final List<RelOptRule> CBO_RULES = ImmutableList.of(
Expand Down Expand Up @@ -114,7 +118,7 @@ public UniversalFilterViewRule(Config config) {
public void onMatch(RelOptRuleCall call) {
Filter filter = call.rel(0);
MycatView view = call.rel(1);
SQLRBORewriter.view(view, LocalFilter.create(filter, view)).ifPresent(new Consumer<RelNode>() {
view(view, LocalFilter.create(filter, view)).ifPresent(new Consumer<RelNode>() {
@Override
public void accept(RelNode res) {
call.transformTo(normalize(res));
Expand Down Expand Up @@ -150,7 +154,7 @@ public PrimaryShardingTableFilterViewRule(Config config) {
public void onMatch(RelOptRuleCall call) {
Filter filter = call.rel(0);
MycatView view = call.rel(1);
SQLRBORewriter.view(view, LocalFilter.create(filter, view)).ifPresent(new Consumer<RelNode>() {
view(view, LocalFilter.create(filter, view)).ifPresent(new Consumer<RelNode>() {
@Override
public void accept(RelNode res) {
switch (view.getDistribution().type()) {
Expand Down Expand Up @@ -218,7 +222,7 @@ public ProjectViewRule(Config config) {
public void onMatch(RelOptRuleCall call) {
Project project = call.rel(0);
MycatView view = call.rel(1);
SQLRBORewriter.view(view, LocalProject.create(project, view)).ifPresent(res -> {
view(view, LocalProject.create(project, view)).ifPresent(res -> {
call.transformTo(normalize(res));
});
}
Expand Down Expand Up @@ -359,7 +363,7 @@ public CalcViewRule(Config config) {
public void onMatch(RelOptRuleCall call) {
Calc calc = call.rel(0);
MycatView input = call.rel(1);
SQLRBORewriter.view(input, LocalCalc.create(calc, input)).ifPresent(res -> {
view(input, LocalCalc.create(calc, input)).ifPresent(res -> {
call.transformTo(normalize(res));
});
}
Expand All @@ -379,6 +383,46 @@ default CalcViewRule.Config withOperandFor() {
}
}

public static class UnionAllViewRule extends RelRule<UnionAllViewRule.Config> {

public static final Config DEFAULT_CONFIG = Config.EMPTY.as(UnionAllViewRule.Config.class).withOperandFor();

public UnionAllViewRule(Config config) {
super(config);
}

@Override
public void onMatch(RelOptRuleCall call) {
Union union = call.rel(0);
MycatView left = call.rel(1);
MycatView right = call.rel(2);
List<RelNode> inputs = ImmutableList.of(left, right);
RelNode view = view(inputs, LogicalUnion.create(inputs, union.all));
if (view.getInputs().size()==2&&view.getInput(0) == left && view.getInput(1) == right) {
return;
}
call.transformTo(view);
}

public interface Config extends RelRule.Config {
@Override
default UnionAllViewRule toRule() {
return new UnionAllViewRule(this);
}

default UnionAllViewRule.Config withOperandFor() {
return withOperandSupplier(b0 ->
b0.operand(Union.class).predicate(union -> {
return union.all;
})
.inputs(b1 -> b1.operand(MycatView.class).noInputs(),
b1 -> b1.operand(MycatView.class).noInputs()))
.withDescription("UnionAllViewRule")
.as(UnionAllViewRule.Config.class);
}
}
}

@NotNull
private static ToLogicalConverter getToLogicalConverter(RelNode res) {
ToLogicalConverter toLogicalConverter = new ToLogicalConverter(MycatCalciteSupport.relBuilderFactory.create(res.getCluster(), null)) {
Expand Down
55 changes: 50 additions & 5 deletions hbt/src/main/java/io/mycat/calcite/rewriter/SQLRBORewriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import java.math.BigDecimal;
import java.text.NumberFormat;
import java.util.*;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;


Expand Down Expand Up @@ -588,10 +589,10 @@ public static Optional<RelNode> aggregate(RelNode original, Aggregate aggregate)
}
}
}

private static Optional<RelNode> splitAggregate(MycatView viewNode, Aggregate aggregate) {

AggregatePushContext aggregateContext = AggregatePushContext.split(aggregate);
AggregatePushContext aggregateContext = AggregatePushContext.split(aggregate);

MycatView newView = viewNode.changeTo(
LogicalAggregate.create(viewNode.getRelNode(),
Expand All @@ -614,12 +615,12 @@ private static Optional<RelNode> splitAggregate(MycatView viewNode, Aggregate ag
}


public static RelNode view(List<RelNode> inputs, LogicalUnion union) {
public static RelNode view(List<RelNode> inputs, Union union) {
if (union.all) {
List<RelNode> children = new ArrayList<>();
for (RelNode input : inputs) {
if (input instanceof LogicalUnion) {
LogicalUnion bottomUnion = (LogicalUnion) input;
Union bottomUnion = (Union) input;
if (bottomUnion.all) {
children.addAll(bottomUnion.getInputs());
} else {
Expand All @@ -629,7 +630,51 @@ public static RelNode view(List<RelNode> inputs, LogicalUnion union) {
children.add(input);
}
}
return union.copy(union.getTraitSet(), children);
inputs = children;

List<MycatView> inputViews = new LinkedList<>();
List<RelNode> newViews = new ArrayList<>();

List<RelNode> others = new ArrayList<>();

for (RelNode input : inputs) {
if (input instanceof MycatView) {
inputViews.add((MycatView) input);
} else {
others.add(input);
}
}
if (inputViews.size() > 1) {
MycatView left = inputViews.get(0);
List<MycatView> matchViews = new ArrayList<>();
matchViews.add(left);
List<MycatView> failViews = new ArrayList<>();
Distribution distribution = null;
for (MycatView right : inputViews.subList(1, inputViews.size())) {
Optional<Distribution> distributionOptional = left.getDistribution().join(right.getDistribution());
if (distributionOptional.isPresent()) {
distribution = distributionOptional.get();
matchViews.add(right);
continue;
}
failViews.add(right);
}
if (distribution != null) {
newViews.add(left.changeTo(union.copy(union.getTraitSet(),
(List) matchViews.stream().map(i -> i.getRelNode()).collect(Collectors.toList())),
distribution)
);
}else {
newViews.addAll(matchViews);
}
newViews.addAll(failViews);
} else {
newViews.addAll(inputViews);
}
inputs = (List) ImmutableList.builder().addAll(newViews).addAll(others).build();
}
if (inputs.size() == 1) {
return inputs.get(0);
}
return union.copy(union.getTraitSet(), inputs);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package io.mycat.drdsrunner;

import io.mycat.DrdsConst;
import io.mycat.DrdsExecutorCompiler;
import io.mycat.DrdsSqlCompiler;
import io.mycat.DrdsSqlWithParams;
import io.mycat.calcite.DrdsRunnerHelper;
import io.mycat.calcite.MycatRel;
import io.mycat.calcite.rewriter.OptimizationContext;
import io.mycat.calcite.spm.Plan;
import io.mycat.calcite.spm.PlanImpl;
import io.mycat.calcite.spm.SpecificSql;
import io.mycat.calcite.table.SchemaHandler;
import io.mycat.util.NameMap;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

import javax.annotation.concurrent.NotThreadSafe;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@NotThreadSafe
@net.jcip.annotations.NotThreadSafe
public class UnionAllTest extends DrdsTest {

@BeforeClass
public static void beforeClass() {
DrdsTest.drdsRunner = null;
DrdsTest.metadataManager = null;
}

public static Explain parse(String sql) {
DrdsSqlCompiler drds = getDrds();
DrdsSqlWithParams drdsSqlWithParams = DrdsRunnerHelper.preParse(sql, null);
OptimizationContext optimizationContext = new OptimizationContext();
MycatRel dispatch = drds.dispatch(optimizationContext, drdsSqlWithParams);
Plan plan = new PlanImpl(dispatch, DrdsExecutorCompiler.getCodeExecuterContext(optimizationContext.relNodeContext.getConstantMap(), dispatch, false), drdsSqlWithParams.getAliasList());
return new Explain(plan, drdsSqlWithParams);
}

@Test
public void testSelectTest() throws Exception {
Explain explain = parse("select 1 union all select 1");
Assert.assertEquals("[{columnType=INTEGER, nullable=false, columnName=1}]", explain.getColumnInfo());
Assert.assertEquals("MycatUnion(all=[true]) MycatProject(?=[?0]) MycatValues(tuples=[[{ 0 }]]) MycatProject(?=[?1]) MycatValues(tuples=[[{ 0 }]])", explain.dumpPlan());
}

@Test
public void testSelectNormal() throws Exception {
Explain explain = parse("select 1 from db1.normal union all select 1 from db1.normal");
Assert.assertEquals("[{columnType=INTEGER, nullable=false, columnName=1}]", explain.getColumnInfo());
Assert.assertEquals("MycatView(distribution=[[db1.normal]])", explain.dumpPlan());
}

@Test
public void testSelectNormalGlobal() throws Exception {
Explain explain = parse("select 1 from db1.normal union all select 1 from db1.global");
Assert.assertEquals("[{columnType=INTEGER, nullable=false, columnName=1}]",
explain.getColumnInfo());
Assert.assertEquals("MycatView(distribution=[[db1.global, db1.normal]])", explain.dumpPlan());
}
@Test
public void testSelectNormalSharding() throws Exception {
Explain explain = parse("select 1 from db1.normal union all select 1 from db1.sharding");
Assert.assertEquals("[{columnType=INTEGER, nullable=false, columnName=1}]",
explain.getColumnInfo());
Assert.assertEquals("MycatUnion(all=[true]) MycatView(distribution=[[db1.normal]]) MycatView(distribution=[[db1.sharding]])", explain.dumpPlan());
}

}

0 comments on commit d447de6

Please sign in to comment.