From cb116be7bd3516778c4494ce1b9c7470977a1281 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20B=C3=BCrk?= Date: Fri, 11 Jun 2021 12:43:09 +0200 Subject: [PATCH] [FLINK-22788][table-planner-blink] Support equalisers for many fields When working with hundreds of fields, equalisers can fail to compile because the generated method body grows beyond the JVM's 64 KB per-method bytecode limit. With this change, instead of generating all code into one method, we generate a dedicated method per field and then call all of those methods. This doesn't entirely remove the problem, but it supports roughly ten times as many fields, which is currently deemed sufficient. This closes #16213. --- .../codegen/EqualiserCodeGenerator.scala | 162 ++++++++++-------- .../codegen/EqualiserCodeGeneratorTest.java | 34 ++++ 2 files changed, 127 insertions(+), 69 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala index be8d480f86152..79d8098d0446c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala @@ -44,73 +44,16 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) { // ignore time zone val ctx = CodeGeneratorContext(new TableConfig) val className = newName(name) - val header = - s""" - |if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) { - | return false; - |} - """.stripMargin - - val codes = for (i <- fieldTypes.indices) yield { - val fieldType = fieldTypes(i) - val fieldTypeTerm = primitiveTypeTermForType(fieldType) - val result = s"cmp$i" - val leftNullTerm = "leftIsNull$" + i - val rightNullTerm = "rightIsNull$" + i - val leftFieldTerm = "leftField$" + i - val rightFieldTerm = "rightField$" + i - 
- // TODO merge ScalarOperatorGens.generateEquals. - val (equalsCode, equalsResult) = if (isInternalPrimitive(fieldType)) { - ("", s"$leftFieldTerm == $rightFieldTerm") - } else if (isCompositeType(fieldType)) { - val equaliserGenerator = new EqualiserCodeGenerator( - getFieldTypes(fieldType).asScala.toArray) - val generatedEqualiser = equaliserGenerator - .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") - val generatedEqualiserTerm = ctx.addReusableObject( - generatedEqualiser, "field$" + i + "GeneratedEqualiser") - val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName - val equaliserTerm = newName("equaliser") - ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") - ctx.addReusableInitStatement( - s""" - |$equaliserTerm = ($equaliserTypeTerm) - | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); - |""".stripMargin) - ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)") - } else { - val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType) - val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType) - val gen = generateEquals(ctx, left, right) - (gen.code, gen.resultTerm) - } - val leftReadCode = rowFieldReadAccess(ctx, i, LEFT_INPUT, fieldType) - val rightReadCode = rowFieldReadAccess(ctx, i, RIGHT_INPUT, fieldType) - s""" - |boolean $leftNullTerm = $LEFT_INPUT.isNullAt($i); - |boolean $rightNullTerm = $RIGHT_INPUT.isNullAt($i); - |boolean $result; - |if ($leftNullTerm && $rightNullTerm) { - | $result = true; - |} else if ($leftNullTerm|| $rightNullTerm) { - | $result = false; - |} else { - | $fieldTypeTerm $leftFieldTerm = $leftReadCode; - | $fieldTypeTerm $rightFieldTerm = $rightReadCode; - | $equalsCode - | $result = $equalsResult; - |} - |if (!$result) { - | return false; - |} - """.stripMargin + + val equalsMethodCodes = for (idx <- fieldTypes.indices) yield generateEqualsMethod(ctx, idx) + val equalsMethodCalls = for (idx <- 
fieldTypes.indices) yield { + val methodName = getEqualsMethodName(idx) + s"""result = result && $methodName($LEFT_INPUT, $RIGHT_INPUT);""" } - val functionCode = + val classCode = j""" public final class $className implements $RECORD_EQUALISER { - ${ctx.reuseMemberCode()} public $className(Object[] references) throws Exception { @@ -121,17 +64,98 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) { public boolean equals($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) { if ($LEFT_INPUT instanceof $BINARY_ROW && $RIGHT_INPUT instanceof $BINARY_ROW) { return $LEFT_INPUT.equals($RIGHT_INPUT); - } else { - $header - ${ctx.reuseLocalVariableCode()} - ${codes.mkString("\n")} - return true; } + + if ($LEFT_INPUT.getRowKind() != $RIGHT_INPUT.getRowKind()) { + return false; + } + + boolean result = true; + ${equalsMethodCalls.mkString("\n")} + return result; } + + ${equalsMethodCodes.mkString("\n")} } """.stripMargin - new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray) + new GeneratedRecordEqualiser(className, classCode, ctx.references.toArray) + } + + private def getEqualsMethodName(idx: Int) = s"""equalsAtIndex$idx""" + + private def generateEqualsMethod(ctx: CodeGeneratorContext, idx: Int): String = { + val methodName = getEqualsMethodName(idx) + ctx.startNewLocalVariableStatement(methodName) + + val Seq(leftNullTerm, rightNullTerm) = ctx.addReusableLocalVariables( + ("boolean", "isNullLeft"), + ("boolean", "isNullRight") + ) + + val fieldType = fieldTypes(idx) + val fieldTypeTerm = primitiveTypeTermForType(fieldType) + val Seq(leftFieldTerm, rightFieldTerm) = ctx.addReusableLocalVariables( + (fieldTypeTerm, "leftField"), + (fieldTypeTerm, "rightField") + ) + + val leftReadCode = rowFieldReadAccess(ctx, idx, LEFT_INPUT, fieldType) + val rightReadCode = rowFieldReadAccess(ctx, idx, RIGHT_INPUT, fieldType) + + val (equalsCode, equalsResult) = generateEqualsCode(ctx, fieldType, + leftFieldTerm, rightFieldTerm, leftNullTerm, 
rightNullTerm) + + s""" + |private boolean $methodName($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) { + | ${ctx.reuseLocalVariableCode(methodName)} + | + | $leftNullTerm = $LEFT_INPUT.isNullAt($idx); + | $rightNullTerm = $RIGHT_INPUT.isNullAt($idx); + | if ($leftNullTerm && $rightNullTerm) { + | return true; + | } + | + | if ($leftNullTerm || $rightNullTerm) { + | return false; + | } + | + | $leftFieldTerm = $leftReadCode; + | $rightFieldTerm = $rightReadCode; + | $equalsCode + | + | return $equalsResult; + |} + """.stripMargin + } + + private def generateEqualsCode(ctx: CodeGeneratorContext, fieldType: LogicalType, + leftFieldTerm: String, rightFieldTerm: String, + leftNullTerm: String, rightNullTerm: String) = { + // TODO merge ScalarOperatorGens.generateEquals. + if (isInternalPrimitive(fieldType)) { + ("", s"$leftFieldTerm == $rightFieldTerm") + } else if (isCompositeType(fieldType)) { + val equaliserGenerator = new EqualiserCodeGenerator( + getFieldTypes(fieldType).asScala.toArray) + val generatedEqualiser = equaliserGenerator.generateRecordEqualiser("fieldGeneratedEqualiser") + val generatedEqualiserTerm = ctx.addReusableObject( + generatedEqualiser, "fieldGeneratedEqualiser") + val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName + val equaliserTerm = newName("equaliser") + ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") + ctx.addReusableInitStatement( + s""" + |$equaliserTerm = ($equaliserTypeTerm) + | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); + |""".stripMargin) + ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)") + } else { + val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType) + val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType) + val gen = generateEquals(ctx, left, right) + (gen.code, gen.resultTerm) + } } @tailrec diff --git 
a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java index fe0e9a6c631d9..9ba9d003e7279 100644 --- a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/codegen/EqualiserCodeGeneratorTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.data.StringData; import org.apache.flink.table.data.TimestampData; import org.apache.flink.table.data.binary.BinaryRowData; import org.apache.flink.table.data.writer.BinaryRowWriter; @@ -30,13 +31,16 @@ import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.TimestampType; import org.apache.flink.table.types.logical.TypeInformationRawType; +import org.apache.flink.table.types.logical.VarCharType; import org.junit.Assert; import org.junit.Test; import java.util.function.Function; +import java.util.stream.IntStream; import static org.apache.flink.table.data.TimestampData.fromEpochMillis; +import static org.junit.Assert.assertTrue; /** Test for {@link EqualiserCodeGenerator}. 
*/ public class EqualiserCodeGeneratorTest { @@ -81,6 +85,36 @@ public void testTimestamp() { assertBoolean(equaliser, func, fromEpochMillis(1024), fromEpochMillis(1025), false); } + @Test + public void testManyFields() { + final LogicalType[] fieldTypes = + IntStream.range(0, 999) + .mapToObj(i -> new VarCharType()) + .toArray(LogicalType[]::new); + + RecordEqualiser equaliser; + try { + equaliser = + new EqualiserCodeGenerator(fieldTypes) + .generateRecordEqualiser("ManyFields") + .newInstance(Thread.currentThread().getContextClassLoader()); + } catch (Exception e) { + Assert.fail("Expected compilation to succeed"); + + // Unreachable + throw e; + } + + final StringData[] fields = + IntStream.range(0, 999) + .mapToObj(i -> StringData.fromString("Entry " + i)) + .toArray(StringData[]::new); + assertTrue( + equaliser.equals( + GenericRowData.of((Object[]) fields), + GenericRowData.of((Object[]) fields))); + } + private static void assertBoolean( RecordEqualiser equaliser, Function toBinaryRow,