diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala index 25df11b98b3d..c84120a808a9 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstCreator.scala @@ -177,10 +177,10 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid ) case typedExpr: KtArrayAccessExpression => Seq(astForArrayAccess(typedExpr, argIdxMaybe, argNameMaybe)) case typedExpr: KtAnonymousInitializer => astsForExpression(typedExpr.getBody, argIdxMaybe) - case typedExpr: KtBinaryExpression => astsForBinaryExpr(typedExpr, argIdxMaybe, argNameMaybe) + case typedExpr: KtBinaryExpression => astsForBinaryExpr(typedExpr, argIdxMaybe, argNameMaybe, annotations) case typedExpr: KtBlockExpression => astsForBlock(typedExpr, argIdxMaybe, argNameMaybe) case typedExpr: KtBinaryExpressionWithTypeRHS => - Seq(astForBinaryExprWithTypeRHS(typedExpr, argIdxMaybe, argNameMaybe)) + Seq(astForBinaryExprWithTypeRHS(typedExpr, argIdxMaybe, argNameMaybe, annotations)) case typedExpr: KtBreakExpression => Seq(astForBreak(typedExpr)) case typedExpr: KtCallExpression => astsForCall(typedExpr, argIdxMaybe, argNameMaybe, annotations) case typedExpr: KtConstantExpression => Seq(astForLiteral(typedExpr, argIdxMaybe, argNameMaybe, annotations)) diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/KtPsiToAst.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/KtPsiToAst.scala index 0ed183ebc21c..bf6b24cf8dc8 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/KtPsiToAst.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/KtPsiToAst.scala @@ -581,13 +581,17 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) { callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args.toList) } - def astForBinaryExprWithTypeRHS(expr: KtBinaryExpressionWithTypeRHS, argIdx: Option[Int], argName: Option[String])( - implicit typeInfoProvider: TypeInfoProvider - ): Ast = { + def astForBinaryExprWithTypeRHS( + expr: KtBinaryExpressionWithTypeRHS, + argIdx: Option[Int], + argName: Option[String], + annotations: Seq[KtAnnotationEntry] = Seq() + )(implicit typeInfoProvider: TypeInfoProvider): Ast = { registerType(typeInfoProvider.expressionType(expr, TypeConstants.any)) val args = astsForExpression(expr.getLeft, None) ++ Seq(astForTypeReference(expr.getRight, None, None)) val node = operatorCallNode(Operators.cast, expr.getText, None, line(expr), column(expr)) callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args.toList) + .withChildren(annotations.map(astForAnnotationEntry)) } def astForTypeReference(expr: KtTypeReference, argIdx: Option[Int], argName: Option[String])(implicit @@ -2144,9 +2148,12 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) { .withChildren(annotationAsts) } - def astsForBinaryExpr(expr: KtBinaryExpression, argIdx: Option[Int], argNameMaybe: Option[String])(implicit - typeInfoProvider: TypeInfoProvider - ): Seq[Ast] = { + def astsForBinaryExpr( + expr: KtBinaryExpression, + argIdx: Option[Int], + argNameMaybe: Option[String], + annotations: Seq[KtAnnotationEntry] = Seq() + )(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = { val opRef = expr.getOperationReference // TODO: add the rest of the operators @@ -2221,6 +2228,7 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) { val rhsArgs = astsForExpression(expr.getRight, None) lhsArgs.dropRight(1) ++ rhsArgs.dropRight(1) ++ Seq( callAst(withArgumentIndex(node, argIdx).argumentName(argNameMaybe), List(lhsArgs.last, rhsArgs.last)) + .withChildren(annotations.map(astForAnnotationEntry)) ) } diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala index f1e80792e450..de82c52c3eda 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/AnnotationsTests.scala @@ -128,4 +128,36 @@ class AnnotationsTests extends KotlinCode2CpgFixture(withOssDataflow = false) { cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1 } } + + "CPG for code with an annotation on a binary expression" should { + val cpg = code(""" + |package mypkg + |@Target(AnnotationTarget.EXPRESSION) + |@Retention(AnnotationRetention.SOURCE) + |annotation class Fancy + |fun fn1() { + | @Fancy 1 + 1 + |} + |""".stripMargin) + + "contain an ANNOTATION node" in { + cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1 + } + } + + "CPG for code with an annotation on a binary expression with type RHS" should { + val cpg = code(""" + |package mypkg + |@Target(AnnotationTarget.EXPRESSION) + |@Retention(AnnotationRetention.SOURCE) + |annotation class Fancy + |fun fn1() { + | @Fancy 1 is Int + |} + |""".stripMargin) + + "contain an ANNOTATION node" in { + cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1 + } + } }