Skip to content

Commit

Permalink
fix bkajoin step 1
Browse files Browse the repository at this point in the history
  • Loading branch information
junwen12221 committed May 26, 2022
1 parent 1291e65 commit f2f10c5
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 31 deletions.
30 changes: 19 additions & 11 deletions hbt/src/main/java/io/mycat/calcite/logical/MycatView.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public MycatView(RelTraitSet relTrait, RelNode input, Distribution dataNode, Rex
// if (input instanceof MycatRel) {
// input = input.accept(new ToLocalConverter());
// }
if (input instanceof MycatRel){
if (input instanceof MycatRel) {
LOGGER.debug("may be a bug,MycatView input is MycatRel");
}
ToLocalConverter toLocalConverter = new ToLocalConverter();
Expand Down Expand Up @@ -131,6 +131,7 @@ public RexNode visitCall(RexCall call) {
}

public static ProjectIndexMapping project(ShardingIndexTable shardingIndexTable, List<Integer> projects) {
ShardingTable primaryTable = shardingIndexTable.getPrimaryTable();
ArrayList<String> restColumnListBuilder = new ArrayList<>();
ArrayList<String> indexColumnListBuilder = new ArrayList<>();
ShardingTable factTable = shardingIndexTable.getFactTable();
Expand All @@ -139,10 +140,17 @@ public static ProjectIndexMapping project(ShardingIndexTable shardingIndexTable,
Objects.requireNonNull(index);
String columnName = factTable.getColumns().get(index).getColumnName();
boolean covering = shardingIndexTable.getColumnByName(columnName) != null;
if (covering) {
boolean mustColumn = primaryTable.getColumns().stream()
.anyMatch(c -> c.getColumnName().equals(columnName) && (c.isShardingKey() || c.isPrimaryKey()));
if (mustColumn) {
indexColumnListBuilder.add(columnName);
} else {
restColumnListBuilder.add(columnName);
} else {
if (covering) {
indexColumnListBuilder.add(columnName);
} else {
restColumnListBuilder.add(columnName);
}
}
}
List<String> indexEqualKeys = (List) ImmutableList.builder().addAll(shardingIndexTable.getLogicTable().getShardingKeys())
Expand Down Expand Up @@ -252,10 +260,10 @@ public static List<RelNode> produceIndexViews(
RelNode rightProject = createMycatProject(primaryTableScan, indexMapping.getFactColumns());

String[] primaryOrShardingKeys = shardingTable.getLogicTable().getRawColumns().stream()
.filter(i -> i.isPrimaryKey() || i.isShardingKey()).filter(i->
leftProject.getRowType().getFieldNames().contains(i.getColumnName())&&
.filter(i -> i.isPrimaryKey() || i.isShardingKey()).filter(i ->
leftProject.getRowType().getFieldNames().contains(i.getColumnName()) &&
rightProject.getRowType().getFieldNames().contains(i.getColumnName())
).map(i->i.getColumnName()).toArray(n -> new String[n]);
).map(i -> i.getColumnName()).toArray(n -> new String[n]);

Join relNode = (Join) relBuilder

Expand All @@ -271,7 +279,7 @@ public static List<RelNode> produceIndexViews(

if (RelOptUtil.areRowTypesEqual(orginalRowType, mycatProject.getRowType(), false)) {
tableArrayList.add(mycatProject);
}else {
} else {
RelNode newRel = RelOptUtil.createCastRel(mycatProject, orginalRowType, true, (input, hints, childExprs, fieldNames) -> MycatProject.create(input, childExprs, orginalRowType));
if (RelOptUtil.areRowTypesEqual(orginalRowType, newRel.getRowType(), false)) {
tableArrayList.add(newRel);
Expand All @@ -282,8 +290,8 @@ public static List<RelNode> produceIndexViews(
}
}
return (List) tableArrayList;
}catch (Throwable throwable){
LOGGER.error("",throwable);
} catch (Throwable throwable) {
LOGGER.error("", throwable);
return Collections.emptyList();
}
}
Expand Down Expand Up @@ -338,11 +346,11 @@ public static MycatRel createMycatProject(RelNode indexTableScan, List<String> i
// }
project = MycatProject.create(project.getInput(0), projects, project.getRowType());
}
if (project instanceof MycatProject){
if (project instanceof MycatProject) {
MycatProject mycatProject = (MycatProject) project;
if (mycatProject.getInput() instanceof MycatView) {
MycatView mycatProjectInput = (MycatView) mycatProject.getInput();
return mycatProjectInput.changeTo(LocalProject.create(mycatProject,mycatProjectInput.getRelNode()));
return mycatProjectInput.changeTo(LocalProject.create(mycatProject, mycatProjectInput.getRelNode()));
}
}
return (MycatRel) project;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@
import lombok.Getter;
import org.apache.calcite.adapter.enumerable.*;
import org.apache.calcite.linq4j.tree.*;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.plan.*;
import org.apache.calcite.rel.*;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.*;
import org.apache.calcite.runtime.NewMycatDataContext;
import org.apache.calcite.sql.*;
Expand All @@ -38,7 +37,9 @@
import org.jetbrains.annotations.NotNull;

import java.lang.reflect.Method;
import java.time.LocalDateTime;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

import static io.mycat.calcite.MycatImplementor.MYCAT_SQL_LOOKUP_IN;
Expand Down Expand Up @@ -247,7 +248,8 @@ public static Observable<Object[]> dispatchRightObservable(NewMycatDataContext c
if (argsList.isEmpty()) {
return Observable.empty();
}
RexShuttle rexShuttle = argSolver(argsList);
List<String> rightFieldNames = RelOptUtil.findAllTables(rightView.getRelNode()).get(0).getRowType().getFieldNames();
RexShuttle rexShuttle = argSolver(tableLookup.getInput(), argsList, rightFieldNames);
RelNode mycatInnerRelNode = rightView.getRelNode().accept(new RelShuttleImpl() {
@Override
public RelNode visit(RelNode other) {
Expand All @@ -274,7 +276,7 @@ public RelNode visit(RelNode other) {
}

@NotNull
private static RexShuttle argSolver(List<Object[]> argsList) {
private static RexShuttle argSolver(RelNode left, List<Object[]> argsList, List<String> rightFieldNames) {
return new RexShuttle() {
@Override
public void visitEach(Iterable<? extends RexNode> exprs) {
Expand All @@ -298,26 +300,68 @@ public RexNode visitDynamicParam(RexDynamicParam dynamicParam) {

@Override
public RexNode visitCall(RexCall call) {
RexBuilder rexBuilder = MycatCalciteSupport.RexBuilder;
if (call.getOperator() == MYCAT_SQL_LOOKUP_IN) {
List<RexNode> operands = call.getOperands();
RexCall exprRow = (RexCall) operands.get(0);
RexCall valueRow = (RexCall) operands.get(1);
if (argsList.size() == 1) {
List<String> columnNames = exprRow.getOperands().stream().map(i -> ((RexInputRef) i).getIndex()).map(i -> rightFieldNames.get(i)).collect(Collectors.toList());
List<String> valueNames = left.getRowType().getFieldNames();
columnNames = columnNames.stream().filter(i -> valueNames.contains(i)).collect(Collectors.toList());
List<RexNode> collect = exprRow.getOperands().stream().filter(rexNode -> {
int index = ((RexInputRef) rexNode).getIndex();
return valueNames.contains( rightFieldNames.get(index));
}).collect(Collectors.toList());
exprRow = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.ROW, collect);
List<List<RexLiteral>> rowList = new ArrayList<>();
for (Object[] args : argsList) {
List<RexLiteral> row = new ArrayList<>();
for (int i = 0; i < columnNames.size(); i++) {
String needColumnName = columnNames.get(i);


for (int j = 0; j < valueNames.size(); j++) {
String fieldName = valueNames.get(j);
if (needColumnName.equals(fieldName)) {
RelDataTypeField relDataTypeField = left.getRowType().getFieldList().get(j);
RelDataType fieldType = relDataTypeField.getType();
RexLiteral right = (RexLiteral) MycatCalciteSupport.RexBuilder.makeLiteral(Optional.ofNullable(args[j]).map(new Function<Object, String>() {
@Override
public String apply(Object o) {
if (o instanceof LocalDateTime) {
return java.sql.Timestamp.valueOf((LocalDateTime) o).toString();
}
return o.toString();
}
}).orElse(null));
row.add(right);
}
}
}
rowList.add(row);
}
if (rowList.get(0).size() != columnNames.size()) {
throw new UnsupportedOperationException("may be a bug");
}

if (rowList.size() == 1) {
ArrayList<RexNode> ands = new ArrayList<>();
for (int i = 0; i < exprRow.getOperands().size(); i++) {
RexNode right = MycatCalciteSupport.RexBuilder.makeLiteral(Objects.toString(argsList.get(0)[i]));
RexNode right = rowList.get(0).get(i);
RexNode left = exprRow.getOperands().get(i);
ands.add(MycatCalciteSupport.RexBuilder
.makeCall(SqlStdOperatorTable.EQUALS,left,MycatCalciteSupport.RexBuilder.makeCast(left.getType(), right)));
.makeCall(SqlStdOperatorTable.EQUALS, left, MycatCalciteSupport.RexBuilder.makeCast(left.getType(), right)));
}
if (ands.size()==1){
if (ands.size() == 1) {
return ands.get(0);
}
return MycatCalciteSupport.RexBuilder.makeCall(SqlStdOperatorTable.AND, ands);
return rexBuilder.makeCall(SqlStdOperatorTable.AND, ands);
}
List<RexNode> rowRexNodeList = new ArrayList<>();
for (List<RexLiteral> rexLiterals : rowList) {
RexNode rexNode = rexBuilder.makeCall(SqlStdOperatorTable.ROW, rexLiterals);
rowRexNodeList.add(rexNode);
}
LinkedList<RexNode> accept = MycatTableLookupValues.apply(true, argsList, valueRow.getOperands());
RexNode rexNode1 = MycatCalciteSupport.RexBuilder.makeIn(exprRow, accept);
return RexUtil.expandSearch(MycatCalciteSupport.RexBuilder, null, rexNode1);
return rexBuilder.makeIn(exprRow, rowRexNodeList);
}
return super.visitCall(call);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,19 @@ public void onMatch(RelOptRuleCall call) {
if (orginalTableSet.size() > 1) {
return;//右表不能是多个
}
String columnName = mycatLogicTable.getRowType().getFieldNames().get(originColumnOrdinal);
RexInputRef rexInputRef = new RexInputRef(originColumnOrdinal, mycatLogicTable.getRowType().getFieldList().get(originColumnOrdinal).getType());
rightExprs.add(rexInputRef);

CorrelationId correl = cluster.createCorrel();
correlationIds.add(correl);
RelDataType type = left.getRowType().getFieldList().get(pair.source).getType();
RexNode rexNode = rexBuilder.makeCorrel(typeFactory.createUnknownType(), correl);
leftExprs.add(rexBuilder.makeCast(type, rexNode));

} else {
continue;//不是原始字段,跳过
}
CorrelationId correl = cluster.createCorrel();
correlationIds.add(correl);
RelDataType type = left.getRowType().getFieldList().get(pair.source).getType();
RexNode rexNode = rexBuilder.makeCorrel(typeFactory.createUnknownType(), correl);
leftExprs.add(rexBuilder.makeCast(type, rexNode));
}
if (rightExprs.isEmpty()) {
return;
Expand Down

0 comments on commit f2f10c5

Please sign in to comment.