diff --git a/go/analysis/passes/copylock/copylock.go b/go/analysis/passes/copylock/copylock.go index b3ca8ada40a..ff2b41ac4aa 100644 --- a/go/analysis/passes/copylock/copylock.go +++ b/go/analysis/passes/copylock/copylock.go @@ -242,29 +242,23 @@ func lockPathRhs(pass *analysis.Pass, x ast.Expr) typePath { // lockPath returns a typePath describing the location of a lock value // contained in typ. If there is no contained lock, it returns nil. // -// The seenTParams map is used to short-circuit infinite recursion via type -// parameters. -func lockPath(tpkg *types.Package, typ types.Type, seenTParams map[*typeparams.TypeParam]bool) typePath { - if typ == nil { +// The seen map is used to short-circuit infinite recursion due to type cycles. +func lockPath(tpkg *types.Package, typ types.Type, seen map[types.Type]bool) typePath { + if typ == nil || seen[typ] { return nil } + if seen == nil { + seen = make(map[types.Type]bool) + } + seen[typ] = true if tpar, ok := typ.(*typeparams.TypeParam); ok { - if seenTParams == nil { - // Lazily allocate seenTParams, since the common case will not involve - // any type parameters. - seenTParams = make(map[*typeparams.TypeParam]bool) - } - if seenTParams[tpar] { - return nil - } - seenTParams[tpar] = true terms, err := typeparams.StructuralTerms(tpar) if err != nil { return nil // invalid type } for _, term := range terms { - subpath := lockPath(tpkg, term.Type(), seenTParams) + subpath := lockPath(tpkg, term.Type(), seen) if len(subpath) > 0 { if term.Tilde() { // Prepend a tilde to our lock path entry to clarify the resulting @@ -298,7 +292,7 @@ func lockPath(tpkg *types.Package, typ types.Type, seenTParams map[*typeparams.T ttyp, ok := typ.Underlying().(*types.Tuple) if ok { for i := 0; i < ttyp.Len(); i++ { - subpath := lockPath(tpkg, ttyp.At(i).Type(), seenTParams) + subpath := lockPath(tpkg, ttyp.At(i).Type(), seen) if subpath != nil { return append(subpath, typ.String()) } @@ -332,7 +326,7 @@ func lockPath(tpkg *types.Package, typ types.Type, seenTParams map[*typeparams.T nfields := styp.NumFields() for i := 0; i < nfields; i++ { ftyp := styp.Field(i).Type() - subpath := lockPath(tpkg, ftyp, seenTParams) + subpath := lockPath(tpkg, ftyp, seen) if subpath != nil { return append(subpath, typ.String()) } diff --git a/go/analysis/passes/copylock/testdata/src/a/issue61678.go b/go/analysis/passes/copylock/testdata/src/a/issue61678.go new file mode 100644 index 00000000000..9856b5b4ba7 --- /dev/null +++ b/go/analysis/passes/copylock/testdata/src/a/issue61678.go @@ -0,0 +1,30 @@ +package a + +import "sync" + +// These examples are taken from golang/go#61678, modified so that A and B +// contain a mutex. + +type A struct { + a A + mu sync.Mutex +} + +type B struct { + a A + b B + mu sync.Mutex +} + +func okay(x A) {} +func sure() { var x A; nop(x) } + +var fine B + +func what(x B) {} // want `passes lock by value` +func bad() { var x B; nop(x) } // want `copies lock value` +func good() { nop(B{}) } +func stillgood() { nop(B{b: B{b: B{b: B{}}}}) } +func nope() { nop(B{}.b) } // want `copies lock value` + +func nop(any) {} // only used to get around unused variable errors