Skip to content

Commit

Permalink
kotlin2cpg: handle annotations for various exprs
Browse files Browse the repository at this point in the history
  • Loading branch information
ursachec committed Aug 15, 2023
1 parent c3e276e commit 7f28c41
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,25 +196,30 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid
case typedExpr: KtDoWhileExpression => Seq(astForDoWhile(typedExpr))
case typedExpr: KtForExpression => Seq(astForFor(typedExpr))
case typedExpr: KtIfExpression => Seq(astForIf(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtIsExpression => Seq(astForIsExpression(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtLabeledExpression => astsForExpression(typedExpr.getBaseExpression, argIdxMaybe, argNameMaybe)
case typedExpr: KtLambdaExpression => Seq(astForLambda(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtIsExpression => Seq(astForIsExpression(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtLabeledExpression =>
astsForExpression(typedExpr.getBaseExpression, argIdxMaybe, argNameMaybe, annotations)
case typedExpr: KtLambdaExpression => Seq(astForLambda(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtNameReferenceExpression if typedExpr.getReferencedNameElementType == KtTokens.IDENTIFIER =>
Seq(astForNameReference(typedExpr, argIdxMaybe, argNameMaybe))
// TODO: callable reference
case _: KtNameReferenceExpression => Seq()
case typedExpr: KtObjectLiteralExpression =>
Seq(astForObjectLiteralExpr(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtParenthesizedExpression => astsForExpression(typedExpr.getExpression, argIdxMaybe, argNameMaybe)
case typedExpr: KtPostfixExpression => Seq(astForPostfixExpression(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtPrefixExpression => Seq(astForPrefixExpression(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtParenthesizedExpression =>
astsForExpression(typedExpr.getExpression, argIdxMaybe, argNameMaybe, annotations)
case typedExpr: KtPostfixExpression =>
Seq(astForPostfixExpression(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtPrefixExpression =>
Seq(astForPrefixExpression(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtProperty if typedExpr.isLocal => astsForProperty(typedExpr)
case typedExpr: KtReturnExpression => Seq(astForReturnExpression(typedExpr))
case typedExpr: KtStringTemplateExpression => Seq(astForStringTemplate(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtSuperExpression => Seq(astForSuperExpression(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtThisExpression => Seq(astForThisExpression(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtStringTemplateExpression =>
Seq(astForStringTemplate(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtSuperExpression => Seq(astForSuperExpression(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtThisExpression => Seq(astForThisExpression(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtThrowExpression => Seq(astForUnknown(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtTryExpression => Seq(astForTry(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtTryExpression => Seq(astForTry(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtWhenExpression => Seq(astForWhen(typedExpr, argIdxMaybe, argNameMaybe))
case typedExpr: KtWhileExpression => Seq(astForWhile(typedExpr))
case typedExpr: KtNamedFunction if Option(typedExpr.getName).isEmpty =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,14 +573,18 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
returnAst(returnNode(expr, expr.getText), children.toList)
}

def astForIsExpression(expr: KtIsExpression, argIdx: Option[Int], argName: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
def astForIsExpression(
expr: KtIsExpression,
argIdx: Option[Int],
argName: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
registerType(typeInfoProvider.expressionType(expr, TypeConstants.any))
val args = astsForExpression(expr.getLeftHandSide, None) ++
Seq(astForTypeReference(expr.getTypeReference, None, argName))
val node = operatorCallNode(Operators.is, expr.getText, None, line(expr), column(expr))
callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args.toList)
.withChildren(annotations.map(astForAnnotationEntry))
}

def astForBinaryExprWithTypeRHS(
Expand All @@ -604,26 +608,34 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
Ast(withArgumentName(withArgumentIndex(node, argIdx), argName))
}

def astForSuperExpression(expr: KtSuperExpression, argIdx: Option[Int], argName: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
def astForSuperExpression(
expr: KtSuperExpression,
argIdx: Option[Int],
argName: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any))
val node = withArgumentName(
withArgumentIndex(identifierNode(expr, expr.getText, expr.getText, typeFullName), argIdx),
argName
)
astWithRefEdgeMaybe(expr.getText, node)
.withChildren(annotations.map(astForAnnotationEntry))
}

def astForThisExpression(expr: KtThisExpression, argIdx: Option[Int], argName: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
def astForThisExpression(
expr: KtThisExpression,
argIdx: Option[Int],
argName: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any))
val node = withArgumentName(
withArgumentIndex(identifierNode(expr, expr.getText, expr.getText, typeFullName), argIdx),
argName
)
astWithRefEdgeMaybe(expr.getText, node)
.withChildren(annotations.map(astForAnnotationEntry))
}

def astForClassLiteral(
Expand Down Expand Up @@ -871,9 +883,12 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
callAst(withArgumentName(withArgumentIndex(callNode, argIdx), argName), List(identifierAst) ++ astsForIndexExpr)
}

def astForPostfixExpression(expr: KtPostfixExpression, argIdx: Option[Int], argName: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
def astForPostfixExpression(
expr: KtPostfixExpression,
argIdx: Option[Int],
argName: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val operatorType = ktTokenToOperator(forPostfixExpr = true).applyOrElse(
KtPsiUtil.getOperationToken(expr),
{ (token: KtToken) =>
Expand All @@ -886,11 +901,15 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
.filterNot(_.root == null)
val node = operatorCallNode(operatorType, expr.getText, Option(typeFullName), line(expr), column(expr))
callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args)
.withChildren(annotations.map(astForAnnotationEntry))
}

def astForPrefixExpression(expr: KtPrefixExpression, argIdx: Option[Int], argName: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
def astForPrefixExpression(
expr: KtPrefixExpression,
argIdx: Option[Int],
argName: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val operatorType = ktTokenToOperator(forPostfixExpr = false).applyOrElse(
KtPsiUtil.getOperationToken(expr),
{ (token: KtToken) =>
Expand All @@ -903,6 +922,7 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
.filterNot(_.root == null)
val node = operatorCallNode(operatorType, expr.getText, Option(typeFullName), line(expr), column(expr))
callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args)
.withChildren(annotations.map(astForAnnotationEntry))
}

private def astsForDestructuringDeclarationWithRHS(
Expand Down Expand Up @@ -1097,29 +1117,35 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
Ast(withArgumentIndex(node, argIdx).argumentName(argNameMaybe))
}

def astForStringTemplate(expr: KtStringTemplateExpression, argIdx: Option[Int], argName: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
def astForStringTemplate(
expr: KtStringTemplateExpression,
argIdx: Option[Int],
argName: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any))
if (expr.hasInterpolation) {
val args = expr.getEntries.filter(_.getExpression != null).zipWithIndex.map { case (entry, idx) =>
val entryTypeFullName = registerType(typeInfoProvider.expressionType(entry.getExpression, TypeConstants.any))
val valueCallNode = operatorCallNode(
Operators.formattedValue,
entry.getExpression.getText,
Option(entryTypeFullName),
line(entry.getExpression),
column(entry.getExpression)
)
val valueArgs = astsForExpression(entry.getExpression, Some(idx + 1))
callAst(valueCallNode, valueArgs.toList)
val outAst =
if (expr.hasInterpolation) {
val args = expr.getEntries.filter(_.getExpression != null).zipWithIndex.map { case (entry, idx) =>
val entryTypeFullName = registerType(typeInfoProvider.expressionType(entry.getExpression, TypeConstants.any))
val valueCallNode = operatorCallNode(
Operators.formattedValue,
entry.getExpression.getText,
Option(entryTypeFullName),
line(entry.getExpression),
column(entry.getExpression)
)
val valueArgs = astsForExpression(entry.getExpression, Some(idx + 1))
callAst(valueCallNode, valueArgs.toList)
}
val node =
operatorCallNode(Operators.formatString, expr.getText, Option(typeFullName), line(expr), column(expr))
callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args.toIndexedSeq.toList)
} else {
val node = literalNode(expr, expr.getText, typeFullName)
Ast(withArgumentName(withArgumentIndex(node, argIdx), argName))
}
val node = operatorCallNode(Operators.formatString, expr.getText, Option(typeFullName), line(expr), column(expr))
callAst(withArgumentName(withArgumentIndex(node, argIdx), argName), args.toIndexedSeq.toList)
} else {
val node = literalNode(expr, expr.getText, typeFullName)
Ast(withArgumentName(withArgumentIndex(node, argIdx), argName))
}
outAst.withChildren(annotations.map(astForAnnotationEntry))
}

private def astForQualifiedExpressionFieldAccess(
Expand Down Expand Up @@ -1441,9 +1467,12 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
controlStructureAst(node, None, tryAstOption :: (clauseAsts ++ finallyAsts).toList)
}

private def astForTryAsExpression(expr: KtTryExpression, argIdx: Option[Int], argNameMaybe: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
private def astForTryAsExpression(
expr: KtTryExpression,
argIdx: Option[Int],
argNameMaybe: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val typeFullName = registerType(
// TODO: remove the `last`
typeInfoProvider.expressionType(expr.getTryBlock.getStatements.asScala.last, TypeConstants.any)
Expand All @@ -1457,14 +1486,18 @@ trait KtPsiToAst(implicit withSchemaValidation: ValidationMode) {
.argumentName(argNameMaybe)

callAst(withArgumentIndex(node, argIdx), List(tryBlockAst) ++ clauseAsts)
.withChildren(annotations.map(astForAnnotationEntry))
}

// TODO: handle parameters passed to the clauses
def astForTry(expr: KtTryExpression, argIdx: Option[Int], argNameMaybe: Option[String])(implicit
typeInfoProvider: TypeInfoProvider
): Ast = {
def astForTry(
expr: KtTryExpression,
argIdx: Option[Int],
argNameMaybe: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
if (KtPsiUtil.isStatement(expr)) astForTryAsStatement(expr)
else astForTryAsExpression(expr, argIdx, argNameMaybe)
else astForTryAsExpression(expr, argIdx, argNameMaybe, annotations)
}

def astForWhile(expr: KtWhileExpression)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,152 @@ class AnnotationsTests extends KotlinCode2CpgFixture(withOssDataflow = false) {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on an is-expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|fun fn1() {
| @Fancy 1 is Int
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on a labeled expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|fun fn1() {
| @Fancy albel@ 1
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on a parenthesized expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|fun fn1() {
| @Fancy() (1)
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy()").size shouldBe 1
}
}

"CPG for code with an annotation on a prefix expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|fun fn1() {
| @Fancy ++1
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on a postfix expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|fun fn1() {
| @Fancy 1++
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on a string template" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|fun fn1() {
| @Fancy "something"
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on a try expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|fun fn1() {
| @Fancy try { "a" } catch (e: Exception) { "b" }
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on a super expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|class A {
| fun fn1() {
| println(@Fancy super.toString())
| }
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}

"CPG for code with an annotation on a this expression" should {
val cpg = code("""
|package mypkg
|@Target(AnnotationTarget.EXPRESSION)
|@Retention(AnnotationRetention.SOURCE)
|annotation class Fancy
|class A {
| fun fn1() {
| println(@Fancy this.toString())
| }
|}
|""".stripMargin)

"contain an ANNOTATION node" in {
cpg.all.collectAll[Annotation].codeExact("@Fancy").size shouldBe 1
}
}
}

0 comments on commit 7f28c41

Please sign in to comment.