Fix udf-compiler scala2.13 internal return statements #11553

Open
wants to merge 5 commits into
base: branch-24.12
3 changes: 0 additions & 3 deletions scala2.13/udf-compiler/pom.xml
@@ -130,9 +130,6 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>true</skipTests>
</configuration>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
194 changes: 125 additions & 69 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala
@@ -140,18 +140,40 @@ object CFG {
* Iterate through the code to find out the basic blocks
*/
def apply(lambdaReflection: LambdaReflection): CFG = {
// Find the last return in this lambda expression.
// We need this for Scala 2.13+, which emits RETURN instructions where
// Scala 2.12 emitted GOTO. We undo this while parsing the bytecode so that
// the parsing logic does not have to merge different branches (each RETURN
// would otherwise be a leaf). Turning each internal RETURN into
// GOTO [last return] restores the Scala 2.12 behavior.
val codeIterator = lambdaReflection.codeIterator
codeIterator.begin()
var lastReturnOffset = 0
while(codeIterator.hasNext) {
val offset = codeIterator.next()
val opcode: Int = codeIterator.byteAt(offset)
if (opcode == Opcode.ARETURN ||
opcode == Opcode.FRETURN ||
opcode == Opcode.IRETURN ||
opcode == Opcode.DRETURN ||
opcode == Opcode.LRETURN ||
opcode == Opcode.RETURN) {
lastReturnOffset = offset
}
}

// labels: targets of branching instructions (offset)
// edges: connection between branch instruction offset, and target offsets (successors)
// if ifeq then there would be a true and a false successor
// if return there would be no successors (likely)
// goto has 1 successor
codeIterator.begin()
val (labels, edges) = collectLabelsAndEdges(codeIterator, lambdaReflection.constPool)
val (labels, edges) = collectLabelsAndEdges(
codeIterator, lambdaReflection.constPool, lastReturnOffset)

codeIterator.begin() // rewind
val instructionTable = createInstructionTable(codeIterator, lambdaReflection.constPool)
val instructionTable = createInstructionTable(
codeIterator, lambdaReflection.constPool, lastReturnOffset)

val (basicBlocks, offsetToBB) = createBasicBlocks(labels, instructionTable)

@@ -163,6 +185,7 @@ object CFG {
@tailrec
private def collectLabelsAndEdges(codeIterator: CodeIterator,
constPool: ConstPool,
lastReturnOffset: Int,
labels: SortedSet[Int] = SortedSet(),
edges: SortedMap[Int, List[(Int, Int)]] = SortedMap())
: (SortedSet[Int], SortedMap[Int, List[(Int, Int)]]) = {
@@ -171,72 +194,88 @@ object CFG {
val nextOffset: Int = codeIterator.lookAhead
val opcode: Int = codeIterator.byteAt(offset)
// here we are looking for branching instructions
opcode match {
case Opcode.IF_ICMPEQ | Opcode.IF_ICMPNE | Opcode.IF_ICMPLT |
Opcode.IF_ICMPGE | Opcode.IF_ICMPGT | Opcode.IF_ICMPLE |
Opcode.IFEQ | Opcode.IFNE | Opcode.IFLT | Opcode.IFGE |
Opcode.IFGT | Opcode.IFLE | Opcode.IFNULL | Opcode.IFNONNULL =>
// an if statement has two other offsets, false and true branches.
if ((opcode == Opcode.ARETURN ||
opcode == Opcode.FRETURN ||
opcode == Opcode.IRETURN ||
opcode == Opcode.DRETURN ||
opcode == Opcode.LRETURN ||
opcode == Opcode.RETURN) && offset != lastReturnOffset) {
// if we had any return along the way, we are going to replace it
// with a GOTO [lastReturnOffset]
collectLabelsAndEdges(
codeIterator, constPool, lastReturnOffset,
labels + lastReturnOffset,
edges + (offset -> List((0, lastReturnOffset))))
} else {
opcode match {
case Opcode.IF_ICMPEQ | Opcode.IF_ICMPNE | Opcode.IF_ICMPLT |
Opcode.IF_ICMPGE | Opcode.IF_ICMPGT | Opcode.IF_ICMPLE |
Opcode.IFEQ | Opcode.IFNE | Opcode.IFLT | Opcode.IFGE |
Opcode.IFGT | Opcode.IFLE | Opcode.IFNULL | Opcode.IFNONNULL =>
// an if statement has two other offsets, false and true branches.

// the false offset is the next offset, per the definition of if<cond>
val falseOffset = nextOffset
// the false offset is the next offset, per the definition of if<cond>
val falseOffset = nextOffset

// in jvm, the if<cond> ops are followed by two bytes, which are to be
// used together (s16bitAt does this for us) only for the success case of the if
val trueOffset = offset + codeIterator.s16bitAt(offset + 1)
// in jvm, the if<cond> ops are followed by two bytes, which are to be
// used together (s16bitAt does this for us) only for the success case of the if
val trueOffset = offset + codeIterator.s16bitAt(offset + 1)

// keep iterating, having added the false and true offsets to the labels,
// and having added the edges (if offset -> List(false offset, true offset))
collectLabelsAndEdges(
codeIterator, constPool,
labels + falseOffset + trueOffset,
edges + (offset -> List((0, falseOffset), (1, trueOffset))))
case Opcode.TABLESWITCH =>
val defaultOffset = (offset + 4) / 4 * 4
val default = (-1, offset + codeIterator.s32bitAt(defaultOffset))
val lowOffset = defaultOffset + 4
val low = codeIterator.s32bitAt(lowOffset)
val highOffset = lowOffset + 4
val high = codeIterator.s32bitAt(highOffset)
val tableOffset = highOffset + 4
val table = List.tabulate(high - low + 1) { i =>
(low + i, offset + codeIterator.s32bitAt(tableOffset + i * 4))
} :+ default
collectLabelsAndEdges(
codeIterator, constPool,
labels ++ table.map(_._2),
edges + (offset -> table))
case Opcode.LOOKUPSWITCH =>
val defaultOffset = (offset + 4) / 4 * 4
val default = (-1, offset + codeIterator.s32bitAt(defaultOffset))
val npairsOffset = defaultOffset + 4
val npairs = codeIterator.s32bitAt(npairsOffset)
val tableOffset = npairsOffset + 4
val table = List.tabulate(npairs) { i =>
(codeIterator.s32bitAt(tableOffset + i * 8),
// keep iterating, having added the false and true offsets to the labels,
// and having added the edges (if offset -> List(false offset, true offset))
collectLabelsAndEdges(
codeIterator, constPool, lastReturnOffset,
labels + falseOffset + trueOffset,
edges + (offset -> List((0, falseOffset), (1, trueOffset))))
case Opcode.TABLESWITCH =>
val defaultOffset = (offset + 4) / 4 * 4
val default = (-1, offset + codeIterator.s32bitAt(defaultOffset))
val lowOffset = defaultOffset + 4
val low = codeIterator.s32bitAt(lowOffset)
val highOffset = lowOffset + 4
val high = codeIterator.s32bitAt(highOffset)
val tableOffset = highOffset + 4
val table = List.tabulate(high - low + 1) { i =>
(low + i, offset + codeIterator.s32bitAt(tableOffset + i * 4))
} :+ default
collectLabelsAndEdges(
codeIterator, constPool, lastReturnOffset,
labels ++ table.map(_._2),
edges + (offset -> table))
case Opcode.LOOKUPSWITCH =>
val defaultOffset = (offset + 4) / 4 * 4
val default = (-1, offset + codeIterator.s32bitAt(defaultOffset))
val npairsOffset = defaultOffset + 4
val npairs = codeIterator.s32bitAt(npairsOffset)
val tableOffset = npairsOffset + 4
val table = List.tabulate(npairs) { i =>
(codeIterator.s32bitAt(tableOffset + i * 8),
offset + codeIterator.s32bitAt(tableOffset + i * 8 + 4))
} :+ default
collectLabelsAndEdges(
codeIterator, constPool,
labels ++ table.map(_._2),
edges + (offset -> table))
case Opcode.GOTO | Opcode.GOTO_W =>
// goto statements have a single address target, we must go there
val getOffset = if (opcode == Opcode.GOTO) {
codeIterator.s16bitAt(_)
} else {
codeIterator.s32bitAt(_)
}
val labelOffset = offset + getOffset(offset + 1)
collectLabelsAndEdges(
codeIterator, constPool,
labels + labelOffset,
edges + (offset -> List((0, labelOffset))))
case Opcode.IF_ACMPEQ | Opcode.IF_ACMPNE |
Opcode.JSR | Opcode.JSR_W | Opcode.RET =>
val instructionStr = InstructionPrinter.instructionString(codeIterator, offset, constPool)
throw new SparkException("Unsupported instruction: " + instructionStr)
case _ => collectLabelsAndEdges(codeIterator, constPool, labels, edges)
} :+ default
collectLabelsAndEdges(
codeIterator, constPool, lastReturnOffset,
labels ++ table.map(_._2),
edges + (offset -> table))
case Opcode.GOTO | Opcode.GOTO_W =>
// goto statements have a single address target, we must go there
val getOffset = if (opcode == Opcode.GOTO) {
codeIterator.s16bitAt(_)
} else {
codeIterator.s32bitAt(_)
}
val labelOffset = offset + getOffset(offset + 1)
collectLabelsAndEdges(
codeIterator, constPool, lastReturnOffset,
labels + labelOffset,
edges + (offset -> List((0, labelOffset))))
case Opcode.IF_ACMPEQ | Opcode.IF_ACMPNE |
Opcode.JSR | Opcode.JSR_W | Opcode.RET =>
val instructionStr = InstructionPrinter.instructionString(
codeIterator, offset, constPool)
throw new SparkException("Unsupported instruction: " + instructionStr)
case _ => collectLabelsAndEdges(codeIterator, constPool, lastReturnOffset,
labels, edges)
}
}
} else {
// base case
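The offset arithmetic in this hunk can be checked in isolation. The following is a minimal sketch, not code from the PR: `BranchOffsetSketch`, `alignedOperandStart`, and `branchTarget` are hypothetical names, and the bytes are passed in directly rather than read through javassist's `CodeIterator`. It assumes the JVM layout in which `tableswitch`/`lookupswitch` pad their operands to a 4-byte boundary and `if<cond>`/`goto` carry a signed 16-bit offset relative to the branching instruction.

```scala
object BranchOffsetSketch {
  // First 4-byte-aligned position after the opcode byte at `offset`.
  // Mirrors the expression `(offset + 4) / 4 * 4` used in the diff for
  // TABLESWITCH and LOOKUPSWITCH.
  def alignedOperandStart(offset: Int): Int = (offset + 4) / 4 * 4

  // Target of an if<cond> or goto: the two operand bytes form a signed
  // 16-bit value that is added to the offset of the instruction itself.
  def branchTarget(offset: Int, hi: Byte, lo: Byte): Int =
    offset + ((hi << 8) | (lo & 0xff)).toShort
}
```

For example, a switch opcode at offset 5 has its default operand at offset 8, while one at offset 8 has it at offset 12, because at least the opcode byte itself must be skipped.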
@@ -245,15 +284,32 @@ object CFG {
}

@tailrec
private def createInstructionTable(codeIterator: CodeIterator, constPool: ConstPool,
private def createInstructionTable(
codeIterator: CodeIterator,
constPool: ConstPool,
lastReturnOffset: Int,
instructionTable: SortedMap[Int, Instruction] = SortedMap())
: SortedMap[Int, Instruction] = {
if (codeIterator.hasNext) {
val offset = codeIterator.next
val instructionStr = InstructionPrinter.instructionString(codeIterator, offset, constPool)
val instruction = Instruction(codeIterator, offset, instructionStr)
createInstructionTable(codeIterator, constPool,
instructionTable + (offset -> instruction))
val opcode = codeIterator.byteAt(offset)
if ((opcode == Opcode.ARETURN ||
opcode == Opcode.FRETURN ||
opcode == Opcode.IRETURN ||
opcode == Opcode.DRETURN ||
opcode == Opcode.LRETURN ||
opcode == Opcode.RETURN) && offset != lastReturnOffset) {
// an internal RETURN is replaced by GOTO to the last return of the
// lambda.
val instruction = Instruction(Opcode.GOTO, lastReturnOffset, "GOTO")
createInstructionTable(codeIterator, constPool, lastReturnOffset,
instructionTable + (offset -> instruction))
} else {
val instructionStr = InstructionPrinter.instructionString(codeIterator, offset, constPool)
val instruction = Instruction(codeIterator, offset, instructionStr)
createInstructionTable(codeIterator, constPool, lastReturnOffset,
instructionTable + (offset -> instruction))
}
} else {
instructionTable
}
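The effect of the change to CFG.scala can be illustrated on a simplified model. This is only a sketch under stated assumptions: `Op`, `Return`, `Other`, `Goto`, and `InternalReturnSketch` are hypothetical stand-ins, and the instruction table is a plain map from offset to instruction instead of something read from bytecode. It shows the three pieces of the diff: locating the last return, replacing earlier returns with a jump to it, and recording a single edge for each replaced return.

```scala
import scala.collection.immutable.SortedMap

// Hypothetical, simplified stand-ins for bytecode instructions. The real code
// inspects opcodes through javassist rather than matching on case objects.
sealed trait Op
case object Return extends Op
case object Other extends Op
final case class Goto(target: Int) extends Op

object InternalReturnSketch {
  // Offset of the last return, or 0 when there is none (0 mirrors the
  // initial value of lastReturnOffset in the diff).
  def lastReturnOffset(code: SortedMap[Int, Op]): Int =
    code.toSeq.collect { case (offset, Return) => offset }.lastOption.getOrElse(0)

  // Replace every return except the last one with a jump to the last return.
  def rewrite(code: SortedMap[Int, Op]): SortedMap[Int, Op] = {
    val last = lastReturnOffset(code)
    SortedMap[Int, Op](code.toSeq.map {
      case (offset, Return) if offset != last => offset -> (Goto(last): Op)
      case entry => entry
    }: _*)
  }

  // Each replaced return gets exactly one successor, like a GOTO would.
  def returnEdges(code: SortedMap[Int, Op]): SortedMap[Int, List[(Int, Int)]] = {
    val last = lastReturnOffset(code)
    SortedMap[Int, List[(Int, Int)]](code.toSeq.collect {
      case (offset, Return) if offset != last => offset -> List((0, last))
    }: _*)
  }
}
```

With returns at offsets 3 and 7, the return at 3 becomes `Goto(7)` and contributes the edge `3 -> List((0, 7))`, so the return at 7 remains the only leaf of the graph.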
@@ -36,6 +36,7 @@ class OpcodeSuite extends AnyFunSuite {
val conf: SparkConf = new SparkConf()
.set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
.set("spark.rapids.sql.udfCompiler.enabled", "true")
.set("spark.sql.ansi.enabled", "false")
Member

Curious why we need to disable ANSI here? I assume this is for Spark 4.0, but then I would expect there to be corresponding negative tests where we are testing that we do not replace UDFs because of ANSI expression semantics. Seems like we need a comment and/or pointer to a tracking issue here.

Collaborator Author

Correct, that is due to Spark 4. Let me file something, at least with the failures I see with ANSI enabled.

Collaborator Author

@abellina abellina Oct 18, 2024

Filed #11633 to address this as a separate issue. One nice thing about this failure is that it is detected at plan time, so we won't cause corruption.

.set(RapidsConf.EXPLAIN.key, "true")

val spark: SparkSession =
@@ -2384,7 +2385,8 @@ class OpcodeSuite extends AnyFunSuite {
run(20)
} catch {
case e: RuntimeException =>
assert(e.getMessage == "Fold number must be in range [0, 20), but got 20.")
// in new versions of spark, the message has extra information, so we use contains.
assert(e.getMessage.contains("Fold number must be in range [0, 20), but got 20."))
}
}
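The switch from `==` to `contains` in the assertion above can be sketched as follows. This is an illustration only: `MessageCheckSketch` is a hypothetical name, and the decorated message uses an invented prefix, not the actual text produced by any Spark version.

```scala
object MessageCheckSketch {
  val expected = "Fold number must be in range [0, 20), but got 20."

  // Hypothetical messages: the prefix on the second one is invented for
  // illustration and is not real Spark output.
  val plainMessage: String = expected
  val decoratedMessage: String = "[HYPOTHETICAL_ERROR_CLASS] " + expected

  // Exact comparison breaks as soon as extra information is added.
  def matchesExactly(msg: String): Boolean = msg == expected

  // Substring comparison accepts both the plain and the decorated form.
  def matchesLoosely(msg: String): Boolean = msg.contains(expected)
}
```

The trade-off is that a substring check is less strict, so it would also pass if unrelated text were appended to the message.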
