Skip to content

Commit

Permalink
rewrite extends App to def main (joernio#1968)
Browse files Browse the repository at this point in the history
Scala3 doesn't support the 'magic' DelayedInit trait which is used by
`extends App` - the recommended way for cross compilation is to use
regular `def main(args: Array[String])` methods instead.

See https://docs.scala-lang.org/scala3/book/methods-main-methods.html
  • Loading branch information
mpollmeier committed Nov 11, 2022
1 parent 532913a commit 8073690
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,25 @@ class ParsedArguments(arguments: Seq[String]) extends ScallopConf(arguments) {
verify()
}

object Main extends App {
val parsedArguments = new ParsedArguments(args.toSeq)

val ignoreVenvDir =
if (parsedArguments.ignoreVenvDir.toOption.get) {
parsedArguments.venvDir.toOption
} else {
None
}

val py2CpgConfig =
Py2CpgOnFileSystemConfig(
Paths.get(parsedArguments.output.toOption.get),
Paths.get(parsedArguments.input.toOption.get),
ignoreVenvDir.map(Paths.get(_))
)

val cpg = Py2CpgOnFileSystem.buildCpg(py2CpgConfig)
cpg.close()
object Main {
def main(args: Array[String]) = {
val parsedArguments = new ParsedArguments(args.toSeq)

val ignoreVenvDir =
if (parsedArguments.ignoreVenvDir.toOption.get) {
parsedArguments.venvDir.toOption
} else {
None
}

val py2CpgConfig =
Py2CpgOnFileSystemConfig(
Paths.get(parsedArguments.output.toOption.get),
Paths.get(parsedArguments.input.toOption.get),
ignoreVenvDir.map(Paths.get(_))
)

val cpg = Py2CpgOnFileSystem.buildCpg(py2CpgConfig)
cpg.close()
}
}
73 changes: 37 additions & 36 deletions joern-cli/src/main/scala/io/joern/joerncli/JoernFlow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,44 @@ case class FlowConfig(
depth: Int = 1
)

object JoernFlow extends App {
object JoernFlow {
def main(args: Array[String]) = {
parseConfig(args).foreach { config =>
def debugOut(msg: String): Unit = {
if (config.verbose) {
print(msg)
}
}

debugOut("Loading graph... ")
val cpg = CpgBasedTool.loadFromOdb(config.cpgFileName)
debugOut("[DONE]\n")

implicit val resolver: ICallResolver = NoResolve
val sources = params(cpg, config.srcRegex, config.srcParam)
val sinks = params(cpg, config.dstRegex, config.dstParam)

private def parseConfig: Option[FlowConfig] = {
debugOut(s"Number of sources: ${sources.size}\n")
debugOut(s"Number of sinks: ${sinks.size}\n")

implicit val semantics: Semantics = DefaultSemantics()
val engineConfig = EngineConfig(config.depth)
debugOut(s"Analysis depth: ${engineConfig.maxCallDepth}\n")
implicit val context: EngineContext = EngineContext(semantics, engineConfig)

debugOut("Determining flows...")
sinks.foreach { s =>
List(s).to(Traversal).reachableByFlows(sources.to(Traversal)).p.foreach(println)
}
debugOut("[DONE]")

debugOut("Closing graph... ")
cpg.close()
debugOut("[DONE]\n")
}
}

private def parseConfig(args: Array[String]): Option[FlowConfig] = {
new scopt.OptionParser[FlowConfig]("joern-flow") {
head("Find flows")
help("help")
Expand Down Expand Up @@ -62,40 +97,6 @@ object JoernFlow extends App {
}
}.parse(args, FlowConfig())

parseConfig.foreach { config =>
def debugOut(msg: String): Unit = {
if (config.verbose) {
print(msg)
}
}

debugOut("Loading graph... ")
val cpg = CpgBasedTool.loadFromOdb(config.cpgFileName)
debugOut("[DONE]\n")

implicit val resolver: ICallResolver = NoResolve
val sources = params(cpg, config.srcRegex, config.srcParam)
val sinks = params(cpg, config.dstRegex, config.dstParam)

debugOut(s"Number of sources: ${sources.size}\n")
debugOut(s"Number of sinks: ${sinks.size}\n")

implicit val semantics: Semantics = DefaultSemantics()
val engineConfig = EngineConfig(config.depth)
debugOut(s"Analysis depth: ${engineConfig.maxCallDepth}\n")
implicit val context: EngineContext = EngineContext(semantics, engineConfig)

debugOut("Determining flows...")
sinks.foreach { s =>
List(s).to(Traversal).reachableByFlows(sources.to(Traversal)).p.foreach(println)
}
debugOut("[DONE]")

debugOut("Closing graph... ")
cpg.close()
debugOut("[DONE]\n")
}

private def params(cpg: Cpg, methodNameRegex: String, paramIndex: Option[Int]): List[MethodParameterIn] = {
cpg
.method(methodNameRegex)
Expand Down
43 changes: 23 additions & 20 deletions joern-cli/src/main/scala/io/joern/joerncli/JoernParse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,20 @@ import io.shiftleft.codepropertygraph.generated.Languages
import scala.collection.mutable
import scala.jdk.CollectionConverters._

object JoernParse extends App {

object JoernParse {
// Special string used to separate joern-parse opts from frontend-specific opts
val ARGS_DELIMITER = "--frontend-args"
val ARGS_DELIMITER = "--frontend-args"
val DEFAULT_CPG_OUT_FILE = "cpg.bin"
var generator: CpgGenerator = _

def main(args: Array[String]) = {
run(args) match {
case Right(msg) => println(msg)
case Left(errMsg) =>
println(s"Failure: $errMsg")
System.exit(1)
}
}

val optionParser = new scopt.OptionParser[ParserConfig]("joern-parse") {
arg[String]("input")
Expand Down Expand Up @@ -55,22 +65,10 @@ object JoernParse extends App {
note(s"Args specified after the $ARGS_DELIMITER separator will be passed to the front-end verbatim")
}

val DEFAULT_CPG_OUT_FILE = "cpg.bin"

val (parserArgs, frontendArgs) = CpgBasedTool.splitArgs(args)
val installConfig = new InstallConfig()

var generator: CpgGenerator = _

run() match {
case Right(msg) => println(msg)

case Left(errMsg) =>
println(s"Failure: $errMsg")
System.exit(1)
}
private def run(args: Array[String]): Either[String, String] = {
val (parserArgs, frontendArgs) = CpgBasedTool.splitArgs(args)
val installConfig = new InstallConfig()

private def run(): Either[String, String] = {
parseConfig(parserArgs) match {
case Right(config) =>
if (config.listLanguages) {
Expand All @@ -79,7 +77,7 @@ object JoernParse extends App {
for {
_ <- checkInputPath(config)
language <- getLanguage(config)
_ <- generateCpg(config, language)
_ <- generateCpg(installConfig, frontendArgs, config, language)
_ <- applyDefaultOverlays(config)
} yield newCpgCreatedString(config.outputCpgFile)

Expand Down Expand Up @@ -122,7 +120,12 @@ object JoernParse extends App {
}
}

private def generateCpg(config: ParserConfig, language: String): Either[String, String] = {
private def generateCpg(
installConfig: InstallConfig,
frontendArgs: List[String],
config: ParserConfig,
language: String
): Either[String, String] = {
if (config.enhanceOnly) {
Right("No generation required")
} else {
Expand Down
17 changes: 10 additions & 7 deletions joern-cli/src/main/scala/io/joern/joerncli/JoernScan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@ case class JoernScanConfig(
listLanguages: Boolean = false
)

object JoernScan extends App with BridgeBase {
object JoernScan extends BridgeBase {

val implementationVersion = getClass.getPackage.getImplementationVersion

val (scanArgs, frontendArgs) = CpgBasedTool.splitArgs(args)
def main(args: Array[String]) = {
val (scanArgs, frontendArgs) = CpgBasedTool.splitArgs(args)
optionParser.parse(scanArgs, JoernScanConfig()).foreach { config =>
run(config, frontendArgs)
}
}

val optionParser = new scopt.OptionParser[JoernScanConfig]("joern-scan") {
head(
Expand Down Expand Up @@ -106,9 +111,7 @@ object JoernScan extends App with BridgeBase {
note(s"Args specified after the ${CpgBasedTool.ARGS_DELIMITER} separator will be passed to the front-end verbatim")
}

optionParser.parse(scanArgs, JoernScanConfig()).foreach(run)

private def run(config: JoernScanConfig): Unit = {
private def run(config: JoernScanConfig, frontendArgs: List[String]): Unit = {
if (config.dump) {
dumpQueriesAsJson(config.dumpDestination)
} else if (config.listQueryNames) {
Expand All @@ -118,7 +121,7 @@ object JoernScan extends App with BridgeBase {
} else if (config.updateQueryDb) {
updateQueryDatabase(config.queryDbVersion)
} else {
runScanPlugin(config)
runScanPlugin(config, frontendArgs)
}
}

Expand All @@ -143,7 +146,7 @@ object JoernScan extends App with BridgeBase {
println(s.toString())
}

private def runScanPlugin(config: JoernScanConfig): Unit = {
private def runScanPlugin(config: JoernScanConfig, frontendArgs: List[String]): Unit = {

if (config.src == "") {
println(optionParser.usage)
Expand Down
21 changes: 12 additions & 9 deletions joern-cli/src/main/scala/io/joern/joerncli/JoernSlice.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.nodes.CfgNode
import overflowdb.Edge

object JoernSlice extends App {
object JoernSlice {
case class Config(
cpgFileName: String = "cpg.bin",
outFile: String = "slice.bin",
Expand All @@ -22,7 +22,17 @@ object JoernSlice extends App {

case class Slice(nodes: List[CfgNode], edges: Map[CfgNode, List[Edge]])

private def parseConfig: Option[Config] =
def main(args: Array[String]) = {
parseConfig(args).foreach { config =>
Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg =>
val slice = calculateSlice(cpg, config.sourceFile, config.sourceLine)
storeSliceInNewCpg(config.outFile, slice)
}
}

}

private def parseConfig(args: Array[String]): Option[Config] =
new scopt.OptionParser[Config]("joern-slice") {
head("Extract intra-procedural backward slice for a line of code")
help("help")
Expand All @@ -41,13 +51,6 @@ object JoernSlice extends App {
.action((x, c) => c.copy(outFile = x))
}.parse(args, Config())

parseConfig.foreach { config =>
Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg =>
val slice = calculateSlice(cpg, config.sourceFile, config.sourceLine)
storeSliceInNewCpg(config.outFile, slice)
}
}

private def calculateSlice(cpg: Cpg, sourceFile: String, sourceLine: Int): Slice = {
val sinks = cpg.file.nameExact(sourceFile).ast.lineNumber(sourceLine).isCall.argument.l

Expand Down
56 changes: 29 additions & 27 deletions joern-cli/src/main/scala/io/joern/joerncli/JoernVectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,39 @@ trait EmbeddingGenerator[T, S] {

}

object JoernVectors extends App {
object JoernVectors {

implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
case class Config(cpgFileName: String = "cpg.bin", outDir: String = "out", dimToFeature: Boolean = false)

private def parseConfig: Option[Config] =
def main(args: Array[String]) = {
parseConfig(args).foreach { config =>
exitIfInvalid(config.outDir, config.cpgFileName)
Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg =>
val generator = new BagOfPropertiesForNodes()
val embedding = generator.embed(cpg)
println("{")
println("\"objects\":")
traversalToJson(embedding.objects, { x: String => generator.defaultToString(x) })
if (config.dimToFeature) {
println(",\"dimToFeature\": ")
println(Serialization.write(embedding.dimToStructure))
}
println(",\"vectors\":")
traversalToJson(embedding.vectors, generator.vectorToString)
println(",\"edges\":")
traversalToJson(
cpg.graph.edges().map { x =>
Map("src" -> x.outNode().id(), "dst" -> x.inNode().id(), "label" -> x.label())
},
{ x: Map[String, Any] => generator.defaultToString(x) }
)
println("}")
}
}
}

private def parseConfig(args: Array[String]): Option[Config] =
new scopt.OptionParser[Config]("joern-vectors") {
head("Extract vector representations of code from CPG")
help("help")
Expand All @@ -155,31 +182,6 @@ object JoernVectors extends App {
.action((_, c) => c.copy(dimToFeature = true))
}.parse(args, Config())

parseConfig.foreach { config =>
exitIfInvalid(config.outDir, config.cpgFileName)
Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg =>
val generator = new BagOfPropertiesForNodes()
val embedding = generator.embed(cpg)
println("{")
println("\"objects\":")
traversalToJson(embedding.objects, { x: String => generator.defaultToString(x) })
if (config.dimToFeature) {
println(",\"dimToFeature\": ")
println(Serialization.write(embedding.dimToStructure))
}
println(",\"vectors\":")
traversalToJson(embedding.vectors, generator.vectorToString)
println(",\"edges\":")
traversalToJson(
cpg.graph.edges().map { x =>
Map("src" -> x.outNode().id(), "dst" -> x.inNode().id(), "label" -> x.label())
},
{ x: Map[String, Any] => generator.defaultToString(x) }
)
println("}")
}
}

private def traversalToJson[X](trav: Traversal[X], vectorToString: X => String): Unit = {
println("[")
trav.nextOption().foreach { vector => print(vectorToString(vector)) }
Expand Down
Loading

0 comments on commit 8073690

Please sign in to comment.