Skip to content

Commit

Permalink
feat: detect in nested blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
Crocmagnon committed Jul 23, 2024
1 parent f35e8a2 commit 693b30c
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 54 deletions.
160 changes: 106 additions & 54 deletions pkg/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,64 +36,39 @@ func run(pass *analysis.Pass) (interface{}, error) {
return
}

for _, stmt := range body.List {
assignStmt, ok := stmt.(*ast.AssignStmt)
if !ok {
continue
}

t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
if t == nil {
continue
}

if t.String() != "context.Context" {
continue
}

if assignStmt.Tok == token.DEFINE {
break
}

// allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if obj.Pos() >= body.Pos() && obj.Pos() < body.End() {
continue // definition is within the loop
}
}
}
assignStmt := findNestedContext(pass, body, body.List)
if assignStmt == nil {
return
}

suggestedStmt := ast.AssignStmt{
Lhs: assignStmt.Lhs,
TokPos: assignStmt.TokPos,
Tok: token.DEFINE,
Rhs: assignStmt.Rhs,
}
suggested, err := render(pass.Fset, &suggestedStmt)

var fixes []analysis.SuggestedFix
if err == nil {
fixes = append(fixes, analysis.SuggestedFix{
Message: "replace `=` with `:=`",
TextEdits: []analysis.TextEdit{
{
Pos: assignStmt.Pos(),
End: assignStmt.End(),
NewText: []byte(suggested),
},
suggestedStmt := ast.AssignStmt{
Lhs: assignStmt.Lhs,
TokPos: assignStmt.TokPos,
Tok: token.DEFINE,
Rhs: assignStmt.Rhs,
}
suggested, err := render(pass.Fset, &suggestedStmt)

var fixes []analysis.SuggestedFix
if err == nil {
fixes = append(fixes, analysis.SuggestedFix{
Message: "replace `=` with `:=`",
TextEdits: []analysis.TextEdit{
{
Pos: assignStmt.Pos(),
End: assignStmt.End(),
NewText: []byte(suggested),
},
})
}

pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(),
Message: "nested context in loop",
SuggestedFixes: fixes,
},
})

break
}

pass.Report(analysis.Diagnostic{
Pos: assignStmt.Pos(),
Message: "nested context in loop",
SuggestedFixes: fixes,
})

})

return nil, nil
Expand All @@ -113,6 +88,83 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
return nil, errUnknown
}

func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
for _, stmt := range stmts {
if inner, ok := stmt.(*ast.BlockStmt); ok {
found := findNestedContext(pass, inner, inner.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.IfStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.SwitchStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.CaseClause); ok {
found := findNestedContext(pass, block, inner.Body)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.SelectStmt); ok {
found := findNestedContext(pass, inner.Body, inner.Body.List)
if found != nil {
return found
}
}

if inner, ok := stmt.(*ast.CommClause); ok {
found := findNestedContext(pass, block, inner.Body)
if found != nil {
return found
}
}

assignStmt, ok := stmt.(*ast.AssignStmt)
if !ok {
continue
}

t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
if t == nil {
continue
}

if t.String() != "context.Context" {
continue
}

if assignStmt.Tok == token.DEFINE {
break
}

// allow assignment to non-pointer children of values defined within the loop
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
if obj.Pos() >= block.Pos() && obj.Pos() < block.End() {
continue // definition is within the loop
}
}
}

return assignStmt
}

return nil
}

func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
for {
switch n := node.(type) {
Expand Down
108 changes: 108 additions & 0 deletions testdata/src/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,49 @@ func example() {
ctx = wrapContext(ctx) // want "nested context in loop"
break
}

for {
err := doSomething()
if err != nil {
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
}

switch err {
case nil:
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
default:
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
}

{
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
}

select {
case <-ctx.Done():
ctx := wrapContext(ctx)
ctx = wrapContext(ctx)
default:
}

ctx = wrapContext(ctx) // want "nested context in loop"

break
}
}

func wrapContext(ctx context.Context) context.Context {
return context.WithoutCancel(ctx)
}

func doSomething() error {
return nil
}

// storing contexts in a struct isn't recommended, but local copies of a non-pointer struct should act like local copies of a context.
func inStructs(ctx context.Context) {
for i := 0; i < 10; i++ {
Expand Down Expand Up @@ -71,3 +108,74 @@ func inStructs(ctx context.Context) {
rp[0].Ctx = context.WithValue(rp[0].Ctx, "other", "val")
}
}

func inVariousNestedBlocks(ctx context.Context) {
for {
err := doSomething()
if err != nil {
ctx = wrapContext(ctx) // want "nested context in loop"
}

break
}

for {
err := doSomething()
if err != nil {
if true {
ctx = wrapContext(ctx) // want "nested context in loop"
}
}

break
}

for {
err := doSomething()
switch err {
case nil:
ctx = wrapContext(ctx) // want "nested context in loop"
}

break
}

for {
err := doSomething()
switch err {
default:
ctx = wrapContext(ctx) // want "nested context in loop"
}

break
}

for {
ctx := wrapContext(ctx)

err := doSomething()
if err != nil {
ctx = wrapContext(ctx)
}

break
}

for {
{
ctx = wrapContext(ctx) // want "nested context in loop"
}

break
}

for {
select {
case <-ctx.Done():
ctx = wrapContext(ctx) // want "nested context in loop"
default:
}

break
}
}

0 comments on commit 693b30c

Please sign in to comment.