Skip to content

Commit

Permalink
[FLINK-22788][table-planner-blink] Support equalisers for many fields
Browse files Browse the repository at this point in the history
When working with hundreds of fields, equalisers can fail to compile
because the method body grows beyond 64kb. With this change, instead of
generating all code into one method, we generate a dedicated method per
field and then call all of those methods. This doesn't entirely remove
the problem, but supports roughly a factor of 10 more fields and is
currently deemed sufficient.

This closes apache#16213.
  • Loading branch information
Airblader authored and twalthr committed Jun 22, 2021
1 parent 6168118 commit cb116be
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,73 +44,16 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
// ignore time zone
val ctx = CodeGeneratorContext(new TableConfig)
val className = newName(name)
val header =
s"""
|if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) {
| return false;
|}
""".stripMargin

val codes = for (i <- fieldTypes.indices) yield {
val fieldType = fieldTypes(i)
val fieldTypeTerm = primitiveTypeTermForType(fieldType)
val result = s"cmp$i"
val leftNullTerm = "leftIsNull$" + i
val rightNullTerm = "rightIsNull$" + i
val leftFieldTerm = "leftField$" + i
val rightFieldTerm = "rightField$" + i

// TODO merge ScalarOperatorGens.generateEquals.
val (equalsCode, equalsResult) = if (isInternalPrimitive(fieldType)) {
("", s"$leftFieldTerm == $rightFieldTerm")
} else if (isCompositeType(fieldType)) {
val equaliserGenerator = new EqualiserCodeGenerator(
getFieldTypes(fieldType).asScala.toArray)
val generatedEqualiser = equaliserGenerator
.generateRecordEqualiser("field$" + i + "GeneratedEqualiser")
val generatedEqualiserTerm = ctx.addReusableObject(
generatedEqualiser, "field$" + i + "GeneratedEqualiser")
val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName
val equaliserTerm = newName("equaliser")
ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;")
ctx.addReusableInitStatement(
s"""
|$equaliserTerm = ($equaliserTypeTerm)
| $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
|""".stripMargin)
("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)")
} else {
val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType)
val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType)
val gen = generateEquals(ctx, left, right)
(gen.code, gen.resultTerm)
}
val leftReadCode = rowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType)
val rightReadCode = rowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType)
s"""
|boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i);
|boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i);
|boolean $result;
|if ($leftNullTerm && $rightNullTerm) {
| $result = true;
|} else if ($leftNullTerm|| $rightNullTerm) {
| $result = false;
|} else {
| $fieldTypeTerm $leftFieldTerm = $leftReadCode;
| $fieldTypeTerm $rightFieldTerm = $rightReadCode;
| $equalsCode
| $result = $equalsResult;
|}
|if (!$result) {
| return false;
|}
""".stripMargin

val equalsMethodCodes = for (idx <- fieldTypes.indices) yield generateEqualsMethod(ctx, idx)
val equalsMethodCalls = for (idx <- fieldTypes.indices) yield {
val methodName = getEqualsMethodName(idx)
s"""result = result && $methodName($LEFT_INPUT, $RIGHT_INPUT);"""
}

val functionCode =
val classCode =
j"""
public final class $className implements $RECORD_EQUALISER {

${ctx.reuseMemberCode()}

public $className(Object[] references) throws Exception {
Expand All @@ -121,17 +64,98 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) {
public boolean equals($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) {
if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) {
return $LEFT_INPUT.equals($RIGHT_INPUT);
} else {
$header
${ctx.reuseLocalVariableCode()}
${codes.mkString("\n")}
return true;
}

if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) {
return false;
}

boolean result = true;
${equalsMethodCalls.mkString("\n")}
return result;
}

${equalsMethodCodes.mkString("\n")}
}
""".stripMargin

new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray)
new GeneratedRecordEqualiser(className, classCode, ctx.references.toArray)
}

private def getEqualsMethodName(idx: Int) = s"""equalsAtIndex$idx"""

private def generateEqualsMethod(ctx: CodeGeneratorContext, idx: Int): String = {
val methodName = getEqualsMethodName(idx)
ctx.startNewLocalVariableStatement(methodName)

val Seq(leftNullTerm, rightNullTerm) = ctx.addReusableLocalVariables(
("boolean", "isNullLeft"),
("boolean", "isNullRight")
)

val fieldType = fieldTypes(idx)
val fieldTypeTerm = primitiveTypeTermForType(fieldType)
val Seq(leftFieldTerm, rightFieldTerm) = ctx.addReusableLocalVariables(
(fieldTypeTerm, "leftField"),
(fieldTypeTerm, "rightField")
)

val leftReadCode = rowFieldReadAccess(ctx, idx, LEFT_INPUT, fieldType)
val rightReadCode = rowFieldReadAccess(ctx, idx, RIGHT_INPUT, fieldType)

val (equalsCode, equalsResult) = generateEqualsCode(ctx, fieldType,
leftFieldTerm, rightFieldTerm, leftNullTerm, rightNullTerm)

s"""
|private boolean $methodName($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) {
| ${ctx.reuseLocalVariableCode(methodName)}
|
| $leftNullTerm = $LEFT_INPUT.isNullAt($idx);
| $rightNullTerm = $RIGHT_INPUT.isNullAt($idx);
| if ($leftNullTerm && $rightNullTerm) {
| return true;
| }
|
| if ($leftNullTerm || $rightNullTerm) {
| return false;
| }
|
| $leftFieldTerm = $leftReadCode;
| $rightFieldTerm = $rightReadCode;
| $equalsCode
|
| return $equalsResult;
|}
""".stripMargin
}

private def generateEqualsCode(ctx: CodeGeneratorContext, fieldType: LogicalType,
leftFieldTerm: String, rightFieldTerm: String,
leftNullTerm: String, rightNullTerm: String) = {
// TODO merge ScalarOperatorGens.generateEquals.
if (isInternalPrimitive(fieldType)) {
("", s"$leftFieldTerm == $rightFieldTerm")
} else if (isCompositeType(fieldType)) {
val equaliserGenerator = new EqualiserCodeGenerator(
getFieldTypes(fieldType).asScala.toArray)
val generatedEqualiser = equaliserGenerator.generateRecordEqualiser("fieldGeneratedEqualiser")
val generatedEqualiserTerm = ctx.addReusableObject(
generatedEqualiser, "fieldGeneratedEqualiser")
val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName
val equaliserTerm = newName("equaliser")
ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;")
ctx.addReusableInitStatement(
s"""
|$equaliserTerm = ($equaliserTypeTerm)
| $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
|""".stripMargin)
("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)")
} else {
val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType)
val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType)
val gen = generateEquals(ctx, left, right)
(gen.code, gen.resultTerm)
}
}

@tailrec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RawValueData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.TimestampData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.writer.BinaryRowWriter;
Expand All @@ -30,13 +31,16 @@
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.TimestampType;
import org.apache.flink.table.types.logical.TypeInformationRawType;
import org.apache.flink.table.types.logical.VarCharType;

import org.junit.Assert;
import org.junit.Test;

import java.util.function.Function;
import java.util.stream.IntStream;

import static org.apache.flink.table.data.TimestampData.fromEpochMillis;
import static org.junit.Assert.assertTrue;

/** Test for {@link EqualiserCodeGenerator}. */
public class EqualiserCodeGeneratorTest {
Expand Down Expand Up @@ -81,6 +85,36 @@ public void testTimestamp() {
assertBoolean(equaliser, func, fromEpochMillis(1024), fromEpochMillis(1025), false);
}

@Test
public void testManyFields() {
final LogicalType[] fieldTypes =
IntStream.range(0, 999)
.mapToObj(i -> new VarCharType())
.toArray(LogicalType[]::new);

RecordEqualiser equaliser;
try {
equaliser =
new EqualiserCodeGenerator(fieldTypes)
.generateRecordEqualiser("ManyFields")
.newInstance(Thread.currentThread().getContextClassLoader());
} catch (Exception e) {
Assert.fail("Expected compilation to succeed");

// Unreachable
throw e;
}

final StringData[] fields =
IntStream.range(0, 999)
.mapToObj(i -> StringData.fromString("Entry " + i))
.toArray(StringData[]::new);
assertTrue(
equaliser.equals(
GenericRowData.of((Object[]) fields),
GenericRowData.of((Object[]) fields)));
}

private static <T> void assertBoolean(
RecordEqualiser equaliser,
Function<T, BinaryRowData> toBinaryRow,
Expand Down

0 comments on commit cb116be

Please sign in to comment.