Skip to content

Commit

Permalink
Dataflow Engine: Slowdown When Gathering Large Number of Starting Poi…
Browse files Browse the repository at this point in the history
…nts (joernio#1946)

- Expressed the problem of converting source traversals to starting points as a fork-join problem
- Deep search spaces that were previously handled by a single thread are now naturally split up among the available tasks
  • Loading branch information
DavidBakerEffendi committed Nov 4, 2022
1 parent 66162dc commit 066112d
Showing 1 changed file with 96 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -1,35 +1,32 @@
package io.joern.dataflowengineoss.language

import io.joern.dataflowengineoss.DefaultSemantics
import io.shiftleft.codepropertygraph.generated.nodes.{
Call,
CfgNode,
Expression,
FieldIdentifier,
Identifier,
Literal,
Member,
StoredNode,
TypeDecl
}
import io.joern.dataflowengineoss.queryengine.{Engine, EngineContext, PathElement, ReachableByResult}
import io.joern.dataflowengineoss.semanticsloader.Semantics
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.Operators
import overflowdb.traversal._
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.semanticcpg.language._
import org.slf4j.LoggerFactory
import overflowdb.traversal._

import java.util.concurrent.{
ForkJoinPool,
ForkJoinTask,
RecursiveTask,
RejectedExecutionException,
RejectedExecutionHandler
}
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

/** Pairs a starting point with the source node it was derived from.
  *
  * @param startingPoint
  *   a CFG node produced for `source` by the starting-point computation
  * @param source
  *   the stored node that was originally passed in as a source
  */
case class StartingPointWithSource(startingPoint: CfgNode, source: StoredNode)

/** Base class for nodes that can occur in data flows
*/
class ExtendedCfgNode(val traversal: Traversal[CfgNode]) extends AnyVal {

import ExtendedCfgNode._

def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Traversal[CfgNode] = {
val cache = mutable.HashMap[CfgNode, Vector[PathElement]]()
val result = traversal.flatMap(x => x.ddgIn(Vector(PathElement(x)), withInvisible = false, cache))
Expand All @@ -45,13 +42,14 @@ class ExtendedCfgNode(val traversal: Traversal[CfgNode]) extends AnyVal {
}

/** Converts the given source traversals into starting points, runs `reachableByInternal` on them once, and returns
  * the starting points of the results cast to `NodeType`.
  *
  * @param sourceTravs
  *   the source traversals
  * @param context
  *   the engine context passed implicitly to `reachableByInternal`
  * @return
  *   a traversal over the `startingPoint` of each result
  */
def reachableBy[NodeType](sourceTravs: Traversal[NodeType]*)(implicit context: EngineContext): Traversal[NodeType] = {
  val sources = ExtendedCfgNode.sourceTravsToStartingPoints(sourceTravs: _*)
  // Run the engine exactly once; a leftover pre-refactoring call used to run it a second time and discard a result.
  val reachedSources = reachableByInternal(sources).map(_.startingPoint)
  Traversal.from(reachedSources).cast[NodeType]
}

def reachableByFlows[A](sourceTravs: Traversal[A]*)(implicit context: EngineContext): Traversal[Path] = {
val sources = sourceTravsToStartingPoints(sourceTravs)
val sources = ExtendedCfgNode.sourceTravsToStartingPoints(sourceTravs: _*)
val startingPoints = sources.map(_.startingPoint)
val paths = reachableByInternal(sources)
.map { result =>
Expand All @@ -75,7 +73,8 @@ class ExtendedCfgNode(val traversal: Traversal[CfgNode]) extends AnyVal {
/** Converts the given source traversals into starting points and returns the raw results of `reachableByInternal`
  * for them.
  *
  * @param sourceTravs
  *   the source traversals
  * @param context
  *   the engine context passed implicitly to `reachableByInternal`
  * @return
  *   the list of results produced by `reachableByInternal`
  */
def reachableByDetailed[NodeType](
  sourceTravs: Traversal[NodeType]*
)(implicit context: EngineContext): List[ReachableByResult] = {
  // A leftover pre-refactoring call used to run the engine here and throw the result away; only one run remains.
  val sources = ExtendedCfgNode.sourceTravsToStartingPoints(sourceTravs: _*)
  reachableByInternal(sources)
}

private def removeConsecutiveDuplicates[T](l: Vector[T]): List[T] = {
Expand Down Expand Up @@ -107,65 +106,27 @@ class ExtendedCfgNode(val traversal: Traversal[CfgNode]) extends AnyVal {

object ExtendedCfgNode {

/** The code below deals with member variables, and specifically with the situation where literals that initialize
* static members are passed to `reachableBy` as sources. In this case, we determine the first usages of this member
* in each method, traversing the AST from left to right. This isn't fool-proof, e.g., goto-statements would be
* problematic, but it works quite well in practice.
*/
def sourceToStartingPoints[NodeType](src: NodeType): List[CfgNode] = {
src match {
case lit: Literal =>
List(lit) ++ usages(targetsToClassIdentifierPair(literalToInitializedMembers(lit)))
case member: Member =>
val initializedMember = memberToInitializedMembers(member)
usages(targetsToClassIdentifierPair(initializedMember))
case x => List(x).collect { case y: CfgNode => y }
private val log = LoggerFactory.getLogger(ExtendedCfgNode.getClass)

/** Computes the starting points for all given source traversals by invoking a [[SourceTravsToStartingPointsTask]] on
  * the common fork/join pool.
  *
  * @param sourceTravs
  *   the source traversals
  * @return
  *   the result of the task, or an empty list if the pool rejects the task (the rejection is logged)
  */
def sourceTravsToStartingPoints[NodeType](sourceTravs: Traversal[NodeType]*): List[StartingPointWithSource] =
  try {
    // The common pool is shared JVM-wide and is unaffected by shutdown(), so no shutdown is attempted here.
    ForkJoinPool.commonPool().invoke(new SourceTravsToStartingPointsTask(sourceTravs: _*))
  } catch {
    case e: RejectedExecutionException =>
      log.error("Unable to execute 'SourceTravsToStartingPoints' task", e)
      List()
  }
}

/** Flattens the given source traversals, keeps the stored nodes, removes duplicates, orders them by id, and pairs
  * every source with each of the starting points that `sourceToStartingPoints` yields for it.
  */
def sourceTravsToStartingPoints[NodeType](sourceTravs: Seq[Traversal[NodeType]]): List[StartingPointWithSource] = {
  val storedNodes = sourceTravs
    .flatMap(_.toList)
    .collect { case n: StoredNode => n }
  val orderedSources = storedNodes.dedup.toList.sortBy(_.id)
  for {
    source        <- orderedSources
    startingPoint <- sourceToStartingPoints(source)
  } yield StartingPointWithSource(startingPoint, source)
}

/** Finds the members a literal initializes: looks at the assignments the literal occurs in, keeps those located in
  * a static-initializer or constructor method, and returns their targets. A target is either an identifier or, for
  * a field access, the field identifiers found in the access' AST; any other kind of target is dropped.
  */
private def literalToInitializedMembers(lit: Literal): List[Expression] = {
  val initializerAssignments =
    lit.inAssignment.where(_.method.nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName))
  val initializedMembers = initializerAssignments.target.flatMap {
    case ident: Identifier => List(ident)
    case access: Call if access.name == Operators.fieldAccess => access.ast.isFieldIdentifier.l
    case _ => List[Expression]()
  }
  initializedMembers.l
}

/** Walks the ASTs of the static-initializer and constructor methods of the member's type declaration and collects
  * the expressions that refer to the member inside an assignment: identifiers carrying the member's name at argument
  * index 1, and field identifiers whose canonical name equals the member's name.
  */
private def memberToInitializedMembers(member: Member): List[Expression] = {
  val initializerAst = member.typeDecl.method
    .nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName)
    .ast
  initializerAst.flatMap {
    case ident: Identifier if ident.name == member.name =>
      Traversal(ident).argumentIndex(1).where(_.inAssignment).l
    case fieldIdent: FieldIdentifier if fieldIdent.canonicalName == member.head.name =>
      Traversal(fieldIdent).where(_.inAssignment).l
    case _ => List[Expression]()
  }.l
}
/** The code below deals with member variables, and specifically with the situation where literals that initialize
* static members are passed to `reachableBy` as sources. In this case, we determine the first usages of this member in
* each method, traversing the AST from left to right. This isn't fool-proof, e.g., goto-statements would be
* problematic, but it works quite well in practice.
*/
class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode]] {

private def usages(pairs: List[(TypeDecl, Expression)]): List[CfgNode] = {
pairs.flatMap { case (typeDecl, expression) =>
Expand Down Expand Up @@ -204,12 +165,74 @@ object ExtendedCfgNode {
}
}

/** Finds the members a literal initializes: looks at the assignments the literal occurs in, keeps those located in
  * a static-initializer or constructor method, and returns their targets. A target is either an identifier or, for
  * a field access, the field identifiers found in the access' AST; any other kind of target is dropped.
  */
private def literalToInitializedMembers(lit: Literal): List[Expression] = {
  val initializerAssignments =
    lit.inAssignment.where(_.method.nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName))
  val initializedMembers = initializerAssignments.target.flatMap {
    case ident: Identifier => List(ident)
    case access: Call if access.name == Operators.fieldAccess => access.ast.isFieldIdentifier.l
    case _ => List[Expression]()
  }
  initializedMembers.l
}

/** Walks the ASTs of the static-initializer and constructor methods of the member's type declaration and collects
  * the expressions that refer to the member inside an assignment: identifiers carrying the member's name at argument
  * index 1, and field identifiers whose canonical name equals the member's name.
  */
private def memberToInitializedMembers(member: Member): List[Expression] = {
  val initializerAst = member.typeDecl.method
    .nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName)
    .ast
  initializerAst.flatMap {
    case ident: Identifier if ident.name == member.name =>
      Traversal(ident).argumentIndex(1).where(_.inAssignment).l
    case fieldIdent: FieldIdentifier if fieldIdent.canonicalName == member.head.name =>
      Traversal(fieldIdent).where(_.inAssignment).l
    case _ => List[Expression]()
  }.l
}

/** Returns false only when the expression has argument index 1 and occurs in an assignment; true otherwise. */
private def notLeftHandOfAssignment(x: Expression): Boolean =
  x.argumentIndex != 1 || x.inAssignment.isEmpty

/** Pairs every target with each type declaration of the method it occurs in; targets without one yield no pair. */
private def targetsToClassIdentifierPair(targets: List[Expression]): List[(TypeDecl, Expression)] =
  for {
    target   <- targets
    typeDecl <- target.method.typeDecl
  } yield (typeDecl, target)
/** Determines the starting points for `src`:
  *   - a literal yields itself followed by the usages of the members it initializes,
  *   - a member yields the usages of its initialized members,
  *   - any other CFG node yields itself, and anything else yields nothing.
  */
override def compute(): List[CfgNode] = src match {
  case lit: Literal =>
    val initializedMembers = literalToInitializedMembers(lit)
    lit :: usages(targetsToClassIdentifierPair(initializedMembers))
  case member: Member =>
    usages(targetsToClassIdentifierPair(memberToInitializedMembers(member)))
  case cfgNode: CfgNode => List(cfgNode)
  case _                => List()
}
}

/** Fork/join task that turns source traversals into starting points.
  *
  * The stored nodes of all traversals are collected, de-duplicated and ordered by id. One [[SourceToStartingPoints]]
  * sub-task is forked per source, and the results are then gathered in that same order. A sub-task whose result
  * cannot be obtained is logged and contributes no starting points.
  */
class SourceTravsToStartingPointsTask[NodeType](sourceTravs: Traversal[NodeType]*)
    extends RecursiveTask[List[StartingPointWithSource]] {

  private val log = LoggerFactory.getLogger(this.getClass)

  override def compute(): List[StartingPointWithSource] = {
    val storedNodes = sourceTravs
      .flatMap(_.toList)
      .collect { case n: StoredNode => n }
    val orderedSources: List[StoredNode] = storedNodes.dedup.toList.sortBy(_.id)

    // Fork every sub-task before waiting on any of them, so they can run concurrently.
    val forked = orderedSources.map(source => (source, new SourceToStartingPoints(source).fork()))

    forked.flatMap { case (source, task) =>
      Try(task.get()) match {
        case Success(startingPoints) =>
          startingPoints.map(startingPoint => StartingPointWithSource(startingPoint, source))
        case Failure(e) =>
          log.error("Unable to complete 'SourceToStartingPoints' task", e)
          List()
      }
    }
  }
}

0 comments on commit 066112d

Please sign in to comment.