diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DruidExpression.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DruidExpression.java index 7faac91ab096..ec257230e7be 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DruidExpression.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DruidExpression.java @@ -26,6 +26,9 @@ import com.google.common.primitives.Chars; import org.apache.calcite.rel.type.RelDataType; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExprType; +import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.virtual.ExpressionVirtualColumn; @@ -180,6 +183,19 @@ public static DruidExpression ofLiteral( ); } + /** + * Create a literal expression from a {@link DruidLiteral}. + */ + public static DruidExpression ofLiteral(final DruidLiteral literal) + { + if (literal.type() != null && literal.type().is(ExprType.STRING)) { + return ofStringLiteral((String) literal.value()); + } else { + final ColumnType evalColumnType = literal.type() != null ? ExpressionType.toColumnType(literal.type()) : null; + return ofLiteral(evalColumnType, ExprEval.ofType(literal.type(), literal.value()).toExpr().stringify()); + } + } + public static DruidExpression ofStringLiteral(final String s) { return ofLiteral(ColumnType.STRING, stringLiteral(s)); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DruidLiteral.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DruidLiteral.java new file mode 100644 index 000000000000..616530035372 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DruidLiteral.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExpressionType; + +import javax.annotation.Nullable; + +/** + * Literal value, plus a {@link ExpressionType} that represents how to interpret the literal value. + * + * These are similar to {@link ExprEval}, but not identical: unlike {@link ExprEval}, string values in this class + * are not normalized through {@link NullHandling#emptyToNullIfNeeded(String)}. This allows us to differentiate + * between null and empty-string literals even when {@link NullHandling#replaceWithDefault()}. 
+ */ +public class DruidLiteral +{ + @Nullable + private final ExpressionType type; + + @Nullable + private final Object value; + + DruidLiteral(@Nullable final ExpressionType type, @Nullable final Object value) + { + this.type = type; + this.value = value; + } + + @Nullable + public ExpressionType type() + { + return type; + } + + @Nullable + public Object value() + { + return value; + } + + public DruidLiteral castTo(final ExpressionType toType) + { + if (toType.equals(type)) { // "type" may be null (untyped NULL literal), so compare from the non-null side + return this; + } + + final ExprEval castEval = ExprEval.ofType(type, value).castTo(toType); + return new DruidLiteral(castEval.type(), castEval.value()); + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java index 0d3e2505853d..ec518ee4522f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java @@ -38,6 +38,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.expression.TimestampFloorExprMacro; import org.apache.druid.query.extraction.ExtractionFn; @@ -240,7 +241,8 @@ public static DruidExpression toDruidExpressionWithPostAggOperands( } else if (rexNode instanceof RexCall) { return rexCallToDruidExpression(plannerContext, rowSignature, rexNode, postAggregatorVisitor); } else if (kind == SqlKind.LITERAL) { - return literalToDruidExpression(plannerContext, rexNode); + final DruidLiteral eval = calciteLiteralToDruidLiteral(plannerContext, rexNode); + return eval != null ? DruidExpression.ofLiteral(eval) : null; } else { // Can't translate. 
return null; @@ -306,61 +308,85 @@ private static DruidExpression rexCallToDruidExpression( } } + /** + * Create a {@link DruidLiteral} from a literal {@link RexNode}. Necessary because Calcite represents literals using + * different Java classes than Druid does. + * + * @param plannerContext planner context + * @param rexNode Calcite literal + * + * @return converted literal, or null if the literal cannot be converted + */ @Nullable - static DruidExpression literalToDruidExpression( + public static DruidLiteral calciteLiteralToDruidLiteral( final PlannerContext plannerContext, final RexNode rexNode ) { - final SqlTypeName sqlTypeName = rexNode.getType().getSqlTypeName(); + if (rexNode.isA(SqlKind.CAST)) { + if (SqlTypeFamily.DATE.contains(rexNode.getType())) { + // Cast to DATE suggests some timestamp flooring. We don't deal with that here, so return null. + return null; + } + + final DruidLiteral innerLiteral = + calciteLiteralToDruidLiteral(plannerContext, ((RexCall) rexNode).getOperands().get(0)); + if (innerLiteral == null) { + return null; + } + + final ColumnType castToColumnType = Calcites.getColumnTypeForRelDataType(rexNode.getType()); + if (castToColumnType == null) { + return null; + } + + final ExpressionType castToExprType = ExpressionType.fromColumnType(castToColumnType); + if (castToExprType == null) { + return null; + } + + return innerLiteral.castTo(castToExprType); + } // Translate literal. - final ColumnType columnType = Calcites.getColumnTypeForRelDataType(rexNode.getType()); + final SqlTypeName sqlTypeName = rexNode.getType().getSqlTypeName(); + final DruidLiteral retVal; + if (RexLiteral.isNullLiteral(rexNode)) { - return DruidExpression.ofLiteral(columnType, DruidExpression.nullLiteral()); + final ColumnType columnType = Calcites.getColumnTypeForRelDataType(rexNode.getType()); + final ExpressionType expressionType = columnType == null ? 
null : ExpressionType.fromColumnTypeStrict(columnType); + retVal = new DruidLiteral(expressionType, null); } else if (SqlTypeName.INT_TYPES.contains(sqlTypeName)) { final Number number = (Number) RexLiteral.value(rexNode); - return DruidExpression.ofLiteral( - columnType, - number == null ? DruidExpression.nullLiteral() : DruidExpression.longLiteral(number.longValue()) - ); + retVal = new DruidLiteral(ExpressionType.LONG, number == null ? null : number.longValue()); } else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) { // Numeric, non-INT, means we represent it as a double. final Number number = (Number) RexLiteral.value(rexNode); - return DruidExpression.ofLiteral( - columnType, - number == null ? DruidExpression.nullLiteral() : DruidExpression.doubleLiteral(number.doubleValue()) - ); + retVal = new DruidLiteral(ExpressionType.DOUBLE, number == null ? null : number.doubleValue()); } else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) { // Calcite represents DAY-TIME intervals in milliseconds. final long milliseconds = ((Number) RexLiteral.value(rexNode)).longValue(); - return DruidExpression.ofLiteral(columnType, DruidExpression.longLiteral(milliseconds)); + retVal = new DruidLiteral(ExpressionType.LONG, milliseconds); } else if (SqlTypeFamily.INTERVAL_YEAR_MONTH == sqlTypeName.getFamily()) { // Calcite represents YEAR-MONTH intervals in months. 
final long months = ((Number) RexLiteral.value(rexNode)).longValue(); - return DruidExpression.ofLiteral(columnType, DruidExpression.longLiteral(months)); + retVal = new DruidLiteral(ExpressionType.LONG, months); } else if (SqlTypeName.STRING_TYPES.contains(sqlTypeName)) { - return DruidExpression.ofStringLiteral(RexLiteral.stringValue(rexNode)); + final String s = RexLiteral.stringValue(rexNode); + retVal = new DruidLiteral(ExpressionType.STRING, s); } else if (SqlTypeName.TIMESTAMP == sqlTypeName || SqlTypeName.DATE == sqlTypeName) { - if (RexLiteral.isNullLiteral(rexNode)) { - return DruidExpression.ofLiteral(columnType, DruidExpression.nullLiteral()); - } else { - return DruidExpression.ofLiteral( - columnType, - DruidExpression.longLiteral( - Calcites.calciteDateTimeLiteralToJoda(rexNode, plannerContext.getTimeZone()).getMillis() - ) - ); - } - } else if (SqlTypeName.BOOLEAN == sqlTypeName) { - return DruidExpression.ofLiteral( - columnType, - DruidExpression.longLiteral(RexLiteral.booleanValue(rexNode) ? 1 : 0) + retVal = new DruidLiteral( + ExpressionType.LONG, + Calcites.calciteDateTimeLiteralToJoda(rexNode, plannerContext.getTimeZone()).getMillis() ); + } else if (SqlTypeName.BOOLEAN == sqlTypeName) { + retVal = new DruidLiteral(ExpressionType.LONG, RexLiteral.booleanValue(rexNode) ? 1L : 0L); } else { // Can't translate other literals. return null; } + + return retVal; } /** @@ -647,8 +673,8 @@ private static DimFilter toSimpleLeafFilter( final DruidExpression rhsExpression = toDruidExpression(plannerContext, rowSignature, rhs); final Expr rhsParsed = rhsExpression != null - ? plannerContext.parseExpression(rhsExpression.getExpression()) - : null; + ? 
plannerContext.parseExpression(rhsExpression.getExpression()) + : null; // rhs must be a literal if (rhsParsed == null || !rhsParsed.isLiteral()) { return null; @@ -815,7 +841,9 @@ private static DimFilter toSimpleLeafFilter( } } else if (rexNode instanceof RexCall) { final SqlOperator operator = ((RexCall) rexNode).getOperator(); - final SqlOperatorConversion conversion = plannerContext.getPlannerToolbox().operatorTable().lookupOperatorConversion(operator); + final SqlOperatorConversion conversion = plannerContext.getPlannerToolbox() + .operatorTable() + .lookupOperatorConversion(operator); if (conversion == null) { return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java index d3e73cf3d681..38f2b6d5a8fa 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java @@ -25,7 +25,6 @@ import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; -import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.Evals; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; @@ -34,10 +33,8 @@ import org.apache.druid.query.filter.ArrayContainsElementFilter; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.EqualityFilter; -import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.OrDimFilter; -import org.apache.druid.query.filter.TypedInFilter; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import 
org.apache.druid.sql.calcite.expression.DruidExpression; @@ -158,27 +155,13 @@ public DimFilter toDruidFilter( ); } } else { - if (plannerContext.isUseBoundsAndSelectors() || NullHandling.replaceWithDefault() || !simpleExtractionExpr.isDirectColumnAccess()) { - final InDimFilter.ValuesSet valuesSet = InDimFilter.ValuesSet.create(); - for (final Object arrayElement : arrayElements) { - valuesSet.add(Evals.asString(arrayElement)); - } - - return new InDimFilter( - simpleExtractionExpr.getSimpleExtraction().getColumn(), - valuesSet, - simpleExtractionExpr.getSimpleExtraction().getExtractionFn(), - null - ); - } else { - return new TypedInFilter( - simpleExtractionExpr.getSimpleExtraction().getColumn(), - ExpressionType.toColumnType((ExpressionType) exprEval.type().getElementType()), - Arrays.asList(arrayElements), - null, - null - ); - } + return ScalarInArrayOperatorConversion.makeInFilter( + plannerContext, + simpleExtractionExpr.getSimpleExtraction().getColumn(), + simpleExtractionExpr.getSimpleExtraction().getExtractionFn(), + Arrays.asList(arrayElements), + ExpressionType.toColumnType((ExpressionType) exprEval.type().getElementType()) + ); } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ScalarInArrayOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ScalarInArrayOperatorConversion.java index f6e3dcecf9d7..8a18ce73df22 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ScalarInArrayOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ScalarInArrayOperatorConversion.java @@ -19,32 +19,135 @@ package org.apache.druid.sql.calcite.expression.builtin; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import 
org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.druid.math.expr.Evals; +import org.apache.druid.query.extraction.ExtractionFn; +import org.apache.druid.query.filter.DimFilter; +import org.apache.druid.query.filter.InDimFilter; +import org.apache.druid.query.filter.TypedInFilter; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DirectOperatorConversion; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.expression.DruidLiteral; +import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; public class ScalarInArrayOperatorConversion extends DirectOperatorConversion { - private static final SqlFunction SQL_FUNCTION = OperatorConversions - .operatorBuilder("SCALAR_IN_ARRAY") - .operandTypeChecker( - OperandTypes.sequence( - "'SCALAR_IN_ARRAY(expr, array)'", - OperandTypes.or( - OperandTypes.family(SqlTypeFamily.CHARACTER), - OperandTypes.family(SqlTypeFamily.NUMERIC) - ), - OperandTypes.family(SqlTypeFamily.ARRAY) - ) + public static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder("SCALAR_IN_ARRAY") + .operandTypeChecker( + OperandTypes.sequence( + "'SCALAR_IN_ARRAY(expr, array)'", + OperandTypes.or( + OperandTypes.family(SqlTypeFamily.CHARACTER), + OperandTypes.family(SqlTypeFamily.NUMERIC) + ), + OperandTypes.family(SqlTypeFamily.ARRAY) ) - .returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE) - .build(); + ) + .returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE) + .build(); public ScalarInArrayOperatorConversion() { super(SQL_FUNCTION, 
"scalar_in_array"); } + + @Nullable + @Override + public DimFilter toDruidFilter( + final PlannerContext plannerContext, + final RowSignature rowSignature, + @Nullable final VirtualColumnRegistry virtualColumnRegistry, + final RexNode rexNode + ) + { + final RexCall call = (RexCall) rexNode; + final RexNode scalarOperand = call.getOperands().get(0); + final RexNode arrayOperand = call.getOperands().get(1); + final DruidExpression scalarExpression = Expressions.toDruidExpression(plannerContext, rowSignature, scalarOperand); + final String scalarColumn; + final ExtractionFn scalarExtractionFn; + + if (scalarExpression == null) { + return null; + } + + if (scalarExpression.isDirectColumnAccess()) { + scalarColumn = scalarExpression.getDirectColumn(); + scalarExtractionFn = null; + } else if (scalarExpression.isSimpleExtraction() && plannerContext.isUseLegacyInFilter()) { + scalarColumn = scalarExpression.getSimpleExtraction().getColumn(); + scalarExtractionFn = scalarExpression.getSimpleExtraction().getExtractionFn(); + } else { + scalarColumn = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( + scalarExpression, + scalarExpression.getDruidType() + ); + scalarExtractionFn = null; + } + + if (Calcites.isLiteral(arrayOperand, true, true)) { + final RelDataType elementType = arrayOperand.getType().getComponentType(); + final List arrayElements = ((RexCall) arrayOperand).getOperands(); + final List arrayElementLiteralValues = new ArrayList<>(arrayElements.size()); + + for (final RexNode arrayElement : arrayElements) { + final DruidLiteral arrayElementEval = Expressions.calciteLiteralToDruidLiteral(plannerContext, arrayElement); + if (arrayElementEval == null) { + return null; + } + + arrayElementLiteralValues.add(arrayElementEval.value()); + } + + return makeInFilter( + plannerContext, + scalarColumn, + scalarExtractionFn, + arrayElementLiteralValues, + Calcites.getColumnTypeForRelDataType(elementType) + ); + } + + return null; + } + + /** + * Create an {@link 
InDimFilter} or {@link TypedInFilter} based on a list of provided values. + */ + public static DimFilter makeInFilter( + final PlannerContext plannerContext, + final String columnName, + @Nullable final ExtractionFn extractionFn, + final List matchValues, + final ColumnType matchValueType + ) + { + if (plannerContext.isUseLegacyInFilter() || extractionFn != null) { + final InDimFilter.ValuesSet valuesSet = InDimFilter.ValuesSet.create(); + for (final Object matchValue : matchValues) { + valuesSet.add(Evals.asString(matchValue)); + } + + return new InDimFilter(columnName, valuesSet, extractionFn, null); + } else { + return new TypedInFilter(columnName, matchValueType, matchValues, null, null); + } + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeArithmeticOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeArithmeticOperatorConversion.java index 734b98117b11..5f7025bfa89c 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeArithmeticOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeArithmeticOperatorConversion.java @@ -100,7 +100,8 @@ public DruidExpression toDruidExpression( expression -> rightRexNode.isA(SqlKind.LITERAL) ? StringUtils.format("'P%sM'", RexLiteral.value(rightRexNode)) : - StringUtils.format("concat('P', %s, 'M')", expression) + StringUtils.format("concat('P', %s, 'M')", expression), + ColumnType.STRING ), DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(direction > 0 ? 
1 : -1)), DruidExpression.ofStringLiteral(plannerContext.getTimeZone().getID()) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java index cbaa300d5296..281fc66c8aaa 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java @@ -37,6 +37,8 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.filter.InDimFilter; +import org.apache.druid.query.filter.TypedInFilter; import org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.query.lookup.RegisteredLookupExtractionFn; @@ -370,13 +372,21 @@ public boolean isStringifyArrays() * {@link org.apache.druid.query.filter.EqualityFilter}, and {@link org.apache.druid.query.filter.NullFilter} (false). * * Typically true when {@link NullHandling#replaceWithDefault()} and false when {@link NullHandling#sqlCompatible()}. - * Can be overriden by the undocumented context parameter {@link #CTX_SQL_USE_BOUNDS_AND_SELECTORS}. + * Can be overridden by the context parameter {@link #CTX_SQL_USE_BOUNDS_AND_SELECTORS}. */ public boolean isUseBoundsAndSelectors() { return useBoundsAndSelectors; } + /** + * Whether we should use {@link InDimFilter} (true) or {@link TypedInFilter} (false). + */ + public boolean isUseLegacyInFilter() + { + return useBoundsAndSelectors || NullHandling.replaceWithDefault(); + } + /** * Whether we should use {@link AggregatePullUpLookupRule} to pull LOOKUP functions on injective lookups up above * a GROUP BY. 
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index b90721a46542..b5ffd3089afa 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -65,6 +65,7 @@ import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.topn.DimensionTopNMetricSpec; import org.apache.druid.query.topn.TopNQueryBuilder; +import org.apache.druid.segment.VirtualColumns; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.join.JoinType; @@ -1353,69 +1354,128 @@ public void testArrayContainsFilterWithDynamicParameter() @Test public void testScalarInArrayFilter() { - msqIncompatible(); testQuery( - "SELECT dim2 FROM druid.numfoo WHERE SCALAR_IN_ARRAY(dim2, ARRAY['a', 'd']) LIMIT 5", - ImmutableList.of( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE3) - .intervals(querySegmentSpec(Filtration.eternity())) - .filters( - new ExpressionDimFilter("scalar_in_array(\"dim2\",array('a','d'))", ExprMacroTable.nil()) - ) - .columns("dim2") - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .limit(5) - .context(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"a"}, - new Object[]{"a"} - ) + "SELECT dim2 FROM druid.numfoo\n" + + "WHERE\n" + + " SCALAR_IN_ARRAY(dim2, ARRAY['a', 'd'])\n" + + " OR SCALAR_IN_ARRAY(SUBSTRING(dim1, 1, 1), ARRAY[NULL, 'foo', 'bar'])\n" + + " OR SCALAR_IN_ARRAY(cnt * 2, ARRAY[3])\n", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + VirtualColumns.create( + NullHandling.sqlCompatible() + ? 
ImmutableList.of( + expressionVirtualColumn("v0", "substring(\"dim1\", 0, 1)", ColumnType.STRING), + expressionVirtualColumn("v1", "(\"cnt\" * 2)", ColumnType.LONG) + ) + : ImmutableList.of( + expressionVirtualColumn("v0", "(\"cnt\" * 2)", ColumnType.LONG) + ) + ) + ) + .filters( + NullHandling.sqlCompatible() + ? or( + in("dim2", Arrays.asList("a", "d")), + in("v0", Arrays.asList(null, "foo", "bar")), + in("v1", ColumnType.LONG, Collections.singletonList(3L)) + ) + : or( + in("dim2", Arrays.asList("a", "d")), + in("dim1", Arrays.asList(null, "foo", "bar"), new SubstringDimExtractionFn(0, 1)), + in("v0", ColumnType.LONG, Collections.singletonList(3L)) + ) + ) + .columns("dim2") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"a"}, + new Object[]{"a"} + ) + ); + } + + @Test + public void testNotScalarInArrayFilter() + { + testQuery( + "SELECT dim2 FROM druid.numfoo\n" + + "WHERE NOT SCALAR_IN_ARRAY(dim2, ARRAY['a', 'd'])\n", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(not(in("dim2", Arrays.asList("a", "d")))) + .columns("dim2") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{""}, + new Object[]{"abc"} + ) + : ImmutableList.of( + new Object[]{""}, + new Object[]{""}, + new Object[]{"abc"}, + new Object[]{""} + ) ); } @Test public void testArrayScalarInFilter_MVD() { - msqIncompatible(); + // In the fifth row, dim3 is an empty list. The Scan query in MSQ reads this with makeDimensionSelector, whereas + // the Scan query in native reads this with makeColumnValueSelector. Behavior of those selectors is inconsistent. + // The DimensionSelector returns an empty list; the ColumnValueSelector returns a list containing a single null. 
+ final String expectedValueForEmptyMvd = + queryFramework().engine().name().equals("msq-task") + ? NullHandling.defaultStringValue() + : "not abd"; + testBuilder() - .sql( - "SELECT dim3, (CASE WHEN scalar_in_array(dim3, Array['a', 'b', 'd']) THEN 'abd' ELSE 'not abd' END) " + - "FROM druid.numfoo" - ) - .expectedQueries( - ImmutableList.of( - newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE3) - .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns( - new ExpressionVirtualColumn( - "v0", - "case_searched(scalar_in_array(\"dim3\",array('a','b','d')),'abd','not abd')", - ColumnType.STRING, - ExprMacroTable.nil() - ) - ) - .columns("dim3", "v0") - .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) - .context(QUERY_CONTEXT_DEFAULT) - .build() + .sql( + "SELECT dim3, (CASE WHEN scalar_in_array(dim3, Array['a', 'b', 'd']) THEN 'abd' ELSE 'not abd' END) " + + "FROM druid.numfoo" + ) + .expectedQueries( + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn( + "v0", + "case_searched(scalar_in_array(\"dim3\",array('a','b','d')),'abd','not abd')", + ColumnType.STRING + ) ) + .columns("dim3", "v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() ) - .expectedResults(ResultMatchMode.RELAX_NULLS, - ImmutableList.of( - new Object[]{"[\"a\",\"b\"]", "[\"abd\",\"abd\"]"}, - new Object[]{"[\"b\",\"c\"]", "[\"abd\",\"not abd\"]"}, - new Object[]{"d", "abd"}, - new Object[]{"", "not abd"}, - new Object[]{null, "not abd"}, - new Object[]{null, "not abd"} - ) + ) + .expectedResults( + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "[\"abd\",\"abd\"]"}, + new Object[]{"[\"b\",\"c\"]", "[\"abd\",\"not abd\"]"}, + new Object[]{"d", "abd"}, + new Object[]{"", "not abd"}, + new Object[]{NullHandling.defaultStringValue(), expectedValueForEmptyMvd}, + new 
Object[]{NullHandling.defaultStringValue(), "not abd"} ) - .run(); - + ) + .run(); } @Test diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/DruidExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/DruidExpressionTest.java index ba31a431b851..5e6d5b067e63 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/DruidExpressionTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/DruidExpressionTest.java @@ -19,9 +19,12 @@ package org.apache.druid.sql.calcite.expression; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.math.expr.Parser; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.testing.InitializedNullHandlingTest; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; @@ -90,7 +93,7 @@ public void test_longLiteral_asString() } @Test - public void longLiteral_roundTrip() + public void test_longLiteral_roundTrip() { final long[] longs = { 0, @@ -107,4 +110,124 @@ public void longLiteral_roundTrip() Assert.assertEquals(n, ((Number) expr.getLiteralValue()).longValue()); } } + + @Test + public void test_ofLiteral_nullString() + { + final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING, null)); + + Assert.assertEquals(ColumnType.STRING, expression.getDruidType()); + Assert.assertEquals("null", expression.getExpression()); + } + + @Test + public void test_ofLiteral_nullLong() + { + final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.LONG, null)); + + Assert.assertEquals(ColumnType.LONG, expression.getDruidType()); + Assert.assertEquals("null", expression.getExpression()); + } + + @Test + public void test_ofLiteral_nullDouble() + { + final DruidExpression expression = 
DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, null));
+
+    Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
+    Assert.assertEquals("null", expression.getExpression());
+  }
+
+  @Test
+  public void test_ofLiteral_nullArray()
+  {
+    final DruidExpression expression =
+        DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING_ARRAY, null));
+    // A null STRING_ARRAY literal is expected to keep the array column type and render as the bare token "null".
+    Assert.assertEquals(ColumnType.STRING_ARRAY, expression.getDruidType());
+    Assert.assertEquals("null", expression.getExpression());
+  }
+
+  @Test
+  public void test_ofLiteral_string()
+  {
+    final String s = "abcdé\n \\\" ' \uD83E\uDD20 \txyz";
+    final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING, s));
+    // Newline, backslash, quotes, tab and the surrogate pair are expected as Unicode escapes; parsing must give back s.
+    Assert.assertEquals(ColumnType.STRING, expression.getDruidType());
+    Assert.assertEquals("'abcdé\\u000A \\u005C\\u0022 \\u0027 \\uD83E\\uDD20 \\u0009xyz'", expression.getExpression());
+    Assert.assertEquals(s, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
+  }
+
+  @Test
+  public void test_ofLiteral_emptyString()
+  {
+    final String s = "";
+    final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.STRING, s));
+    // Rendered as ''; the parsed-back value is compared against emptyToNullIfNeeded(s) rather than s itself.
+    Assert.assertEquals(ColumnType.STRING, expression.getDruidType());
+    Assert.assertEquals("''", expression.getExpression());
+    Assert.assertEquals(
+        NullHandling.emptyToNullIfNeeded(s),
+        Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue()
+    );
+  }
+
+  @Test
+  public void test_ofLiteral_long()
+  {
+    final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.LONG, -123));
+    // The int -123 is passed with type LONG; the parsed-back literal value is expected to equal -123L.
+    Assert.assertEquals(ColumnType.LONG, expression.getDruidType());
+    Assert.assertEquals("-123", expression.getExpression());
+    Assert.assertEquals(-123L, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
+  }
+
+  @Test
+  public void test_ofLiteral_double()
+  {
+    final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, -123.4));
+    // A finite double is expected to render as plain decimal text and to parse back to the same value.
+    Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
+    Assert.assertEquals("-123.4", expression.getExpression());
+    Assert.assertEquals(-123.4, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
+  }
+
+  @Test
+  public void test_ofLiteral_doubleNan()
+  {
+    final DruidExpression expression = DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, Double.NaN));
+    // NaN is expected to render as the bare token "NaN" and to parse back to Double.NaN.
+    Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
+    Assert.assertEquals("NaN", expression.getExpression());
+    Assert.assertEquals(Double.NaN, Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue());
+  }
+
+  @Test
+  public void test_ofLiteral_doubleNegativeInfinity()
+  {
+    final DruidExpression expression =
+        DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, Double.NEGATIVE_INFINITY));
+    // Negative infinity is expected to render as "-Infinity" and to parse back to Double.NEGATIVE_INFINITY.
+    Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
+    Assert.assertEquals("-Infinity", expression.getExpression());
+    Assert.assertEquals(
+        Double.NEGATIVE_INFINITY,
+        Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue()
+    );
+  }
+
+  @Test
+  public void test_ofLiteral_doublePositiveInfinity()
+  {
+    final DruidExpression expression =
+        DruidExpression.ofLiteral(new DruidLiteral(ExpressionType.DOUBLE, Double.POSITIVE_INFINITY));
+    // Positive infinity is expected to render as "Infinity" and to parse back to Double.POSITIVE_INFINITY.
+    Assert.assertEquals(ColumnType.DOUBLE, expression.getDruidType());
+    Assert.assertEquals("Infinity", expression.getExpression());
+    Assert.assertEquals(
+        Double.POSITIVE_INFINITY,
+        Parser.parse(expression.getExpression(), ExprMacroTable.nil()).getLiteralValue()
+    );
+  }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java
index 489d7f6eb036..86b5d38639be 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java
@@ -23,6 +23,7 @@
 import com.google.common.collect.ImmutableMap;
 import org.apache.calcite.avatica.util.TimeUnit;
 import org.apache.calcite.avatica.util.TimeUnitRange;
+import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql.SqlIntervalQualifier;
 import org.apache.calcite.sql.SqlOperator;
@@ -33,6 +34,8 @@
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.error.DruidException;
 import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.math.expr.ExpressionType;
 import org.apache.druid.query.expression.TestExprMacroTable;
 import org.apache.druid.query.extraction.RegexDimExtractionFn;
 import org.apache.druid.query.extraction.SubstringDimExtractionFn;
@@ -65,12 +68,17 @@
 import org.apache.druid.sql.calcite.expression.builtin.TimeParseOperatorConversion;
 import org.apache.druid.sql.calcite.expression.builtin.TimeShiftOperatorConversion;
 import org.apache.druid.sql.calcite.expression.builtin.TruncateOperatorConversion;
+import org.apache.druid.sql.calcite.planner.Calcites;
 import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
+import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.sql.calcite.util.CalciteTestBase;
+import org.joda.time.DateTimeZone;
 import org.joda.time.Period;
 import org.junit.Assert;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
 
 import java.math.BigDecimal;
 import java.util.Collections;
@@ -104,29 +112,32 @@ public class ExpressionsTest extends CalciteTestBase
       .build();
 
   private static final Map BINDINGS = ImmutableMap.builder()
-      .put("t", DateTimes.of("2000-02-03T04:05:06").getMillis())
-      .put("a", 10)
-      .put("b", 25)
-      .put("p", 3)
-      .put("x", 2.25)
-      .put("y", 3.0)
-      .put("z", -2.25)
-      .put("o", 0)
-      .put("nan", Double.NaN)
-      .put("inf", Double.POSITIVE_INFINITY)
-      .put("-inf", Double.NEGATIVE_INFINITY)
-      .put("fnan", Float.NaN)
-      .put("finf", Float.POSITIVE_INFINITY)
-      .put("-finf", Float.NEGATIVE_INFINITY)
-      .put("s", "foo")
-      .put("hexstr", "EF")
-      .put("intstr", "-100")
-      .put("spacey", " hey there ")
-      .put("newliney", "beep\nboop")
-      .put("tstr", "2000-02-03 04:05:06")
-      .put("dstr", "2000-02-03")
-      .put("timezone", "America/Los_Angeles")
-      .build();
+      .put(
+          "t",
+          DateTimes.of("2000-02-03T04:05:06").getMillis()
+      )
+      .put("a", 10)
+      .put("b", 25)
+      .put("p", 3)
+      .put("x", 2.25)
+      .put("y", 3.0)
+      .put("z", -2.25)
+      .put("o", 0)
+      .put("nan", Double.NaN)
+      .put("inf", Double.POSITIVE_INFINITY)
+      .put("-inf", Double.NEGATIVE_INFINITY)
+      .put("fnan", Float.NaN)
+      .put("finf", Float.POSITIVE_INFINITY)
+      .put("-finf", Float.NEGATIVE_INFINITY)
+      .put("s", "foo")
+      .put("hexstr", "EF")
+      .put("intstr", "-100")
+      .put("spacey", " hey there ")
+      .put("newliney", "beep\nboop")
+      .put("tstr", "2000-02-03 04:05:06")
+      .put("dstr", "2000-02-03")
+      .put("timezone", "America/Los_Angeles")
+      .build();
 
   private ExpressionTestHelper testHelper;
 
@@ -1923,7 +1934,7 @@ public void testTimeMinusDayTimeInterval()
             (args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")",
             ImmutableList.of(
                 DruidExpression.ofColumn(ColumnType.LONG, "t"),
-                DruidExpression.ofLiteral(ColumnType.STRING, "90060000")
+                DruidExpression.ofLiteral(ColumnType.LONG, "90060000")
             )
         ),
         DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis()
@@ -2815,4 +2826,173 @@ public void testHumanReadableDecimalByteFormat()
         "45.678 KB"
     );
   }
+
+  @Test
+  public void testCalciteLiteralToDruidLiteral()
+  {
+    final RexBuilder rexBuilder = new RexBuilder(DruidTypeSystem.TYPE_FACTORY);
+    final PlannerContext plannerContext = Mockito.mock(PlannerContext.class);
+    Mockito.when(plannerContext.getTimeZone()).thenReturn(DateTimeZone.UTC); // only getTimeZone() is stubbed
+    // NULL of type VARCHAR -> STRING-typed literal with a null value.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.STRING, null),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeNullLiteral(rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR))
+        )
+    );
+    // Empty character literal -> STRING-typed empty string (expected to differ from the null case above).
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.STRING, ""),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeLiteral("")
+        )
+    );
+    // NULL of type BIGINT -> LONG-typed literal with a null value.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.LONG, null),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeNullLiteral(rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT))
+        )
+    );
+    // NULL of SQL type NULL -> literal with neither a type nor a value.
+    assertDruidLiteral(
+        new DruidLiteral(null, null),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeNullLiteral(rexBuilder.getTypeFactory().createSqlType(SqlTypeName.NULL))
+        )
+    );
+    // Non-empty character literal -> STRING.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.STRING, "abc"),
+        Expressions.calciteLiteralToDruidLiteral(plannerContext, rexBuilder.makeLiteral("abc"))
+    );
+    // Boolean TRUE -> LONG 1.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.LONG, 1L),
+        Expressions.calciteLiteralToDruidLiteral(plannerContext, rexBuilder.makeLiteral(true))
+    );
+    // Exact numeric literal typed INTEGER -> LONG.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.LONG, 123L),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeExactLiteral(
+                BigDecimal.valueOf(123L),
+                rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER)
+            )
+        )
+    );
+    // Exact numeric literal typed DECIMAL -> DOUBLE.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.DOUBLE, 123.0),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeExactLiteral(
+                BigDecimal.valueOf(123L),
+                rexBuilder.getTypeFactory().createSqlType(SqlTypeName.DECIMAL)
+            )
+        )
+    );
+    // TIMESTAMP literal for "2000" in UTC -> LONG holding DateTimes.of("2000").getMillis().
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.LONG, DateTimes.of("2000").getMillis()),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            Calcites.jodaToCalciteTimestampLiteral(
+                rexBuilder,
+                DateTimes.of("2000"),
+                DateTimeZone.UTC,
+                DruidTypeSystem.DEFAULT_TIMESTAMP_PRECISION
+            )
+        )
+    );
+    // DATE literal for "2000" in UTC -> LONG holding the same getMillis() value.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.LONG, DateTimes.of("2000").getMillis()),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeDateLiteral(Calcites.jodaToCalciteDateString(DateTimes.of("2000"), DateTimeZone.UTC))
+        )
+    );
+    // DAY TO HOUR interval literal built from 3 -> LONG 3.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.LONG, 3L),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeIntervalLiteral(
+                BigDecimal.valueOf(3),
+                new SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.HOUR, SqlParserPos.ZERO)
+            )
+        )
+    );
+    // YEAR TO MONTH interval literal built from 3 -> LONG 3.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.LONG, 3), // NOTE(review): int here, 3L above; equal only via %s formatting
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeIntervalLiteral(
+                BigDecimal.valueOf(3),
+                new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
+            )
+        )
+    );
+    // CAST to VARCHAR of an INTEGER-typed exact literal built from 123.7 -> STRING "123".
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.STRING, "123"),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeCast(
+                rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
+                rexBuilder.makeExactLiteral(
+                    BigDecimal.valueOf(123.7),
+                    rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER)
+                )
+            )
+        )
+    );
+    // CAST to DOUBLE of an INTEGER-typed exact literal 123 -> DOUBLE 123.0.
+    assertDruidLiteral(
+        new DruidLiteral(ExpressionType.DOUBLE, 123.0),
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeCast(
+                rexBuilder.getTypeFactory().createSqlType(SqlTypeName.DOUBLE),
+                rexBuilder.makeExactLiteral(
+                    BigDecimal.valueOf(123L),
+                    rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER)
+                )
+            )
+        )
+    );
+    // CAST of a TIMESTAMP literal to DATE: the conversion is expected to return null instead of a literal.
+    Assert.assertNull(
+        Expressions.calciteLiteralToDruidLiteral(
+            plannerContext,
+            rexBuilder.makeCast(
+                rexBuilder.getTypeFactory().createSqlType(SqlTypeName.DATE),
+                Calcites.jodaToCalciteTimestampLiteral(
+                    rexBuilder,
+                    DateTimes.of("2000-01-02T03:04:05"),
+                    DateTimeZone.UTC,
+                    DruidTypeSystem.DEFAULT_TIMESTAMP_PRECISION
+                )
+            )
+        )
+    );
+  }
+
+  private void assertDruidLiteral(
+      final DruidLiteral expected,
+      final DruidLiteral actual
+  )
+  {
+    Assert.assertEquals( // compares the "%s: %s" renderings of (type, value), not the literal objects themselves
+        StringUtils.format("%s: %s", expected.type(), expected.value()),
+        StringUtils.format("%s: %s", actual.type(), actual.value())
+    );
+  }
 }