Skip to content

Commit

Permalink
fix some null handling bugs with vector expression processors (#15587)
Browse files Browse the repository at this point in the history
  • Loading branch information
clintropolis committed Dec 19, 2023
1 parent 9f56885 commit 8a45efb
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public DoubleOutLongsInFunctionVectorValueProcessor(
@Override
public ExpressionType getOutputType()
{
return ExpressionType.LONG;
return ExpressionType.DOUBLE;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,11 @@ public void processIndex(
outputNulls[i] = rightNulls[i];
} else {
output[i] = rightInput[i];
outputNulls[i] = false;
}
} else {
output[i] = leftInput[i];
outputNulls[i] = false;
}
}

Expand Down Expand Up @@ -580,9 +582,11 @@ public void processIndex(
outputNulls[i] = rightNulls[i];
} else {
output[i] = rightInput[i];
outputNulls[i] = false;
}
} else {
output[i] = leftInput[i];
outputNulls[i] = false;
}
}

Expand Down Expand Up @@ -744,6 +748,7 @@ public void processIndex(
}
}
output[i] = Evals.asLong(Evals.asBoolean(leftInput[i]) || Evals.asBoolean(rightInput[i]));
outputNulls[i] = false;
}

@Override
Expand Down Expand Up @@ -793,6 +798,7 @@ public void processIndex(
}
}
output[i] = Evals.asLong(Evals.asBoolean(leftInput[i]) || Evals.asBoolean(rightInput[i]));
outputNulls[i] = false;
}

@Override
Expand Down Expand Up @@ -839,6 +845,7 @@ public void processIndex(
return;
}
output[i] = Evals.asLong(Evals.asBoolean((String) leftInput[i]) || Evals.asBoolean((String) rightInput[i]));
outputNulls[i] = false;
}

@Override
Expand Down Expand Up @@ -907,6 +914,7 @@ public void processIndex(
}
}
output[i] = Evals.asLong(Evals.asBoolean(leftInput[i]) && Evals.asBoolean(rightInput[i]));
outputNulls[i] = false;
}

@Override
Expand All @@ -916,7 +924,7 @@ public ExprEvalVector<long[]> asEval()
}
},
() -> new BivariateFunctionVectorProcessor<double[], double[], long[]>(
ExpressionType.DOUBLE,
ExpressionType.LONG,
left.asVectorProcessor(inputTypes),
right.asVectorProcessor(inputTypes)
)
Expand Down Expand Up @@ -956,6 +964,7 @@ public void processIndex(
}
}
output[i] = Evals.asLong(Evals.asBoolean(leftInput[i]) && Evals.asBoolean(rightInput[i]));
outputNulls[i] = false;
}

@Override
Expand All @@ -965,7 +974,7 @@ public ExprEvalVector<long[]> asEval()
}
},
() -> new BivariateFunctionVectorProcessor<Object[], Object[], long[]>(
ExpressionType.STRING,
ExpressionType.LONG,
left.asVectorProcessor(inputTypes),
right.asVectorProcessor(inputTypes)
)
Expand Down Expand Up @@ -1004,6 +1013,7 @@ public void processIndex(
output[i] = Evals.asLong(
Evals.asBoolean((String) leftInput[i]) && Evals.asBoolean((String) rightInput[i])
);
outputNulls[i] = false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.vector.ExprEvalVector;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -113,7 +114,7 @@ public void testUnaryLogicOperators()
public void testBinaryLogicOperators()
{
final String[] functions = new String[]{"&&", "||"};
final String[] templates = new String[]{"d1 %s d2", "l1 %s l2", "boolString1 %s boolString2"};
final String[] templates = new String[]{"d1 %s d2", "l1 %s l2", "boolString1 %s boolString2", "(d1 == d2) %s (l1 == l2)"};
testFunctions(types, templates, functions);
}

Expand Down Expand Up @@ -283,21 +284,17 @@ static void testExpression(String expr, Map<String, ExpressionType> types)
log.debug("[%s]", expr);
Expr parsed = Parser.parse(expr, ExprMacroTable.nil());

NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings;
for (int iterations = 0; iterations < NUM_ITERATIONS; iterations++) {
bindings = makeRandomizedBindings(VECTOR_SIZE, types);
testExpressionWithBindings(expr, parsed, bindings);
}
bindings = makeSequentialBinding(VECTOR_SIZE, types);
testExpressionWithBindings(expr, parsed, bindings);
testExpression(expr, parsed, types, NUM_ITERATIONS);
testSequentialBinding(expr, parsed, types);
}

public static void testExpressionWithBindings(
public static void testSequentialBinding(
String expr,
Expr parsed,
NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings
Map<String, ExpressionType> types
)
{
NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings = makeSequentialBinding(VECTOR_SIZE, types);
Assert.assertTrue(StringUtils.format("Cannot vectorize %s", expr), parsed.canVectorize(bindings.rhs));
ExpressionType outputType = parsed.getOutputType(bindings.rhs);
ExprEvalVector<?> vectorEval = parsed.asVectorProcessor(bindings.rhs).evalVector(bindings.rhs);
Expand All @@ -320,6 +317,55 @@ public static void testExpressionWithBindings(
}
}

public static void testExpression(
String expr,
Expr parsed,
Map<String, ExpressionType> types,
int numIterations
)
{
Expr.InputBindingInspector inspector = InputBindings.inspectorFromTypeMap(types);
Expr.VectorInputBindingInspector vectorInputBindingInspector = new Expr.VectorInputBindingInspector()
{
@Override
public int getMaxVectorSize()
{
return VECTOR_SIZE;
}

@Nullable
@Override
public ExpressionType getType(String name)
{
return inspector.getType(name);
}
};
Assert.assertTrue(StringUtils.format("Cannot vectorize %s", expr), parsed.canVectorize(inspector));
ExpressionType outputType = parsed.getOutputType(inspector);
final ExprVectorProcessor processor = parsed.asVectorProcessor(vectorInputBindingInspector);
// 'null' expressions can have an output type of null, but still evaluate in default mode, so skip type checks
if (outputType != null) {
Assert.assertEquals(expr, outputType, processor.getOutputType());
}
for (int iterations = 0; iterations < numIterations; iterations++) {
NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings = makeRandomizedBindings(VECTOR_SIZE, types);
ExprEvalVector<?> vectorEval = processor.evalVector(bindings.rhs);
final Object[] vectorVals = vectorEval.getObjectVector();
for (int i = 0; i < VECTOR_SIZE; i++) {
ExprEval<?> eval = parsed.eval(bindings.lhs[i]);
// 'null' expressions can have an output type of null, but still evaluate in default mode, so skip type checks
if (outputType != null && !eval.isNumericNull()) {
Assert.assertEquals(eval.type(), outputType);
}
Assert.assertEquals(
StringUtils.format("Values do not match for row %s for expression %s", i, expr),
eval.valueOrDefault(),
vectorVals[i]
);
}
}
}

public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeRandomizedBindings(
int vectorSize,
Map<String, ExpressionType> types
Expand All @@ -332,7 +378,7 @@ public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeRan
types,
() -> r.nextLong(Integer.MAX_VALUE - 1),
r::nextDouble,
r::nextBoolean,
() -> r.nextDouble(0, 1.0) > 0.9,
() -> String.valueOf(r.nextInt())
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
Expand Down Expand Up @@ -62,13 +61,8 @@ public class VectorExpressionsSanityTest extends InitializedNullHandlingTest
static void testExpression(String expr, Expr parsed, Map<String, ExpressionType> types)
{
log.debug("[%s]", expr);
NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings;
for (int iterations = 0; iterations < NUM_ITERATIONS; iterations++) {
bindings = VectorExprSanityTest.makeRandomizedBindings(VECTOR_SIZE, types);
VectorExprSanityTest.testExpressionWithBindings(expr, parsed, bindings);
}
bindings = VectorExprSanityTest.makeSequentialBinding(VECTOR_SIZE, types);
VectorExprSanityTest.testExpressionWithBindings(expr, parsed, bindings);
VectorExprSanityTest.testExpression(expr, parsed, types, NUM_ITERATIONS);
VectorExprSanityTest.testSequentialBinding(expr, parsed, types);
}

@Test
Expand Down

0 comments on commit 8a45efb

Please sign in to comment.