Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
nishanths committed Nov 14, 2021
1 parent 877044d commit 7596586
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 37 deletions.
4 changes: 2 additions & 2 deletions comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ func isGeneratedFile(file *ast.File) bool {
return false
}

var generatedCodeRx = regexp.MustCompile(`^// Code generated .* DO NOT EDIT\.$`)
var generatedCodeRe = regexp.MustCompile(`^// Code generated .* DO NOT EDIT\.$`)

func isGeneratedFileComment(s string) bool {
return generatedCodeRx.MatchString(s)
return generatedCodeRe.MatchString(s)
}

// ignoreDirective is used to exclude checking of specific switch statements.
Expand Down
19 changes: 9 additions & 10 deletions enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,6 @@ func TestEnumMembers_add(t *testing.T) {
// TODO(testing): add tests for iota, repeated values, ...
}

var testdataEnumPkg = func() *packages.Package {
cfg := &packages.Config{Mode: packages.NeedTypesInfo | packages.NeedTypes | packages.NeedSyntax}
pkgs, err := packages.Load(cfg, "./testdata/src/enum")
if err != nil {
panic(err)
}
return pkgs[0]
}()

func TestFindEnums(t *testing.T) {
transform := func(in map[enumType]enumMembers) []checkEnum {
var out []checkEnum
Expand All @@ -79,6 +70,14 @@ func TestFindEnums(t *testing.T) {
return out
}

testdataEnumPkg := func() *packages.Package {
cfg := &packages.Config{Mode: packages.NeedTypesInfo | packages.NeedTypes | packages.NeedSyntax}
pkgs, err := packages.Load(cfg, "./testdata/src/enum")
if err != nil {
panic(err)
}
return pkgs[0]
}()
inspect := inspector.New(testdataEnumPkg.Syntax)

for _, pkgOnly := range [...]bool{false, true} {
Expand All @@ -89,7 +88,7 @@ func TestFindEnums(t *testing.T) {
}
}

// See checkEnums.
// See func checkEnums.
type checkEnum struct {
typeName string
members enumMembers
Expand Down
8 changes: 4 additions & 4 deletions exhaustive.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (v *regexpFlag) Set(expr string) error {
func (v *regexpFlag) value() *regexp.Regexp { return v.r }

func init() {
Analyzer.Flags.BoolVar(&fCheckGeneratedFiles, CheckGeneratedFlag, false, "check switch statements in generated files")
Analyzer.Flags.BoolVar(&fCheckGenerated, CheckGeneratedFlag, false, "check switch statements in generated files")
Analyzer.Flags.BoolVar(&fDefaultSignifiesExhaustive, DefaultSignifiesExhaustiveFlag, false, "presence of \"default\" case in switch statements satisfies exhaustiveness, even if all enum members are not listed")
Analyzer.Flags.Var(&fIgnoreEnumMembers, IgnoreEnumMembersFlag, "enum members matching `regex` do not have to be listed in switch statements to satisfy exhaustiveness")
Analyzer.Flags.BoolVar(&fPackageScopeOnly, PackageScopeOnlyFlag, false, "consider enums only in package scopes, not in inner scopes")
Expand All @@ -218,7 +218,7 @@ const (
)

var (
fCheckGeneratedFiles bool
fCheckGenerated bool
fDefaultSignifiesExhaustive bool
fIgnoreEnumMembers regexpFlag
fPackageScopeOnly bool
Expand All @@ -227,7 +227,7 @@ var (
// resetFlags resets the flag variables to their default values.
// Useful in tests.
func resetFlags() {
fCheckGeneratedFiles = false
fCheckGenerated = false
fDefaultSignifiesExhaustive = false
fIgnoreEnumMembers = regexpFlag{}
fPackageScopeOnly = false
Expand All @@ -250,7 +250,7 @@ func run(pass *analysis.Pass) (interface{}, error) {

cfg := config{
defaultSignifiesExhaustive: fDefaultSignifiesExhaustive,
checkGeneratedFiles: fCheckGeneratedFiles,
checkGeneratedFiles: fCheckGenerated,
ignoreEnumMembers: fIgnoreEnumMembers.value(),
}
checkSwitchStatements(pass, inspect, cfg)
Expand Down
2 changes: 1 addition & 1 deletion exhaustive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func TestExhaustive(t *testing.T) {

// Tests for the -check-generated flag.
run(t, "generated-file/check-generated-off/...")
run(t, "generated-file/check-generated-on/...", func() { fCheckGeneratedFiles = true })
run(t, "generated-file/check-generated-on/...", func() { fCheckGenerated = true })

// Tests for the -default-signifies-exhaustive flag.
// (For tests with this flag off, see other testdata packages
Expand Down
3 changes: 1 addition & 2 deletions fact.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ func exportFact(pass *analysis.Pass, enumTyp enumType, members enumMembers) {
// An (_, false) return indicates that the enum type is not a known one.
func importFact(pass *analysis.Pass, possibleEnumType enumType) (enumMembers, bool) {
var f enumMembersFact
ok := pass.ImportObjectFact(possibleEnumType.factObject(), &f)
if !ok {
if !pass.ImportObjectFact(possibleEnumType.factObject(), &f) {
return enumMembers{}, false
}
return f.Members, true
Expand Down
1 change: 0 additions & 1 deletion fact_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ func checkOneFactType(t *testing.T, fact analysis.Fact) {
checkTypeEnumMembersFact(t, reflect.TypeOf(v).Elem())
default:
t.Errorf("unhandled type %T", v)
return
}
})
}
Expand Down
31 changes: 16 additions & 15 deletions switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
// switchStmtChecker returns a node visitor that checks exhaustiveness
// of enum switch statements for the supplied pass, and reports diagnostics for
// switch statements that are non-exhaustive.
// It expects to only see *ast.SwitchStmt nodes.
func switchStmtChecker(pass *analysis.Pass, cfg config) nodeVisitor {
generated := make(map[*ast.File]bool) // cached results
comments := make(map[*ast.File]ast.CommentMap) // cached results
Expand Down Expand Up @@ -108,7 +109,7 @@ func switchStmtChecker(pass *analysis.Pass, cfg config) nodeVisitor {
checkUnexported := samePkg // we want to include unexported members in the exhaustiveness check only if we're in the same package
checklist := makeChecklist(members, tagPkg, checkUnexported, cfg.ignoreEnumMembers)

hasDefaultCase := analyzeSwitchClauses(sw, tagPkg, members.NameToValue, pass.TypesInfo, func(val constantValue) {
hasDefaultCase := analyzeSwitchClauses(sw, pass.TypesInfo, func(val constantValue) {
checklist.found(val)
})

Expand Down Expand Up @@ -163,28 +164,25 @@ func denotesPackage(ident *ast.Ident, info *types.Info) (*types.Package, bool) {
}

// analyzeSwitchClauses analyzes the clauses in the supplied switch statement.
//
// tagPkg is the package of the switch statement's tag value's type.
// The info param should typically be pass.TypesInfo. The found function is
// called for each enum member name found in the switch statement.
//
// The hasDefaultCase return value indicates whether the switch statement has a
// default clause.
func analyzeSwitchClauses(sw *ast.SwitchStmt, tagPkg *types.Package, members map[string]constantValue, info *types.Info, found func(val constantValue)) (hasDefaultCase bool) {
func analyzeSwitchClauses(sw *ast.SwitchStmt, info *types.Info, found func(val constantValue)) (hasDefaultCase bool) {
for _, stmt := range sw.Body.List {
caseCl := stmt.(*ast.CaseClause)
if isDefaultCase(caseCl) {
hasDefaultCase = true
continue // nothing more to do if it's the default case
}
for _, expr := range caseCl.List {
analyzeCaseClauseExpr(expr, tagPkg, members, info, found)
analyzeCaseClauseExpr(expr, info, found)
}
}
return hasDefaultCase
}

func analyzeCaseClauseExpr(e ast.Expr, tagPkg *types.Package, members map[string]constantValue, info *types.Info, found func(val constantValue)) {
func analyzeCaseClauseExpr(e ast.Expr, info *types.Info, found func(val constantValue)) {
handleIdent := func(ident *ast.Ident) {
obj := info.Uses[ident]
if obj == nil {
Expand Down Expand Up @@ -282,6 +280,9 @@ func diagnosticEnumTypeName(enumType *types.TypeName, samePkg bool) string {
return enumType.Pkg().Name() + "." + enumType.Name()
}

// Makes a "missing cases in switch" diagnostic.
// samePkg should be true if the enum type and the switch statement are defined
// in the same package.
func makeDiagnostic(sw *ast.SwitchStmt, samePkg bool, enumTyp enumType, allMembers enumMembers, missingMembers map[string]struct{}) analysis.Diagnostic {
message := fmt.Sprintf("missing cases in switch of type %s: %s",
diagnosticEnumTypeName(enumTyp.TypeName, samePkg),
Expand All @@ -304,12 +305,12 @@ func makeDiagnostic(sw *ast.SwitchStmt, samePkg bool, enumTyp enumType, allMembe
// The remaining method returns the member names not accounted for.
//
type checklist struct {
em enumMembers
names map[string]struct{}
em enumMembers
checkl map[string]struct{}
}

func makeChecklist(em enumMembers, enumPkg *types.Package, includeUnexported bool, ignore *regexp.Regexp) *checklist {
names := make(map[string]struct{})
checkl := make(map[string]struct{})

add := func(memberName string) {
if memberName == "_" {
Expand All @@ -325,26 +326,26 @@ func makeChecklist(em enumMembers, enumPkg *types.Package, includeUnexported boo
if ignore != nil && ignore.MatchString(enumPkg.Path()+"."+memberName) {
return
}
names[memberName] = struct{}{}
checkl[memberName] = struct{}{}
}

for _, name := range em.Names {
add(name)
}

return &checklist{
em: em,
names: names,
em: em,
checkl: checkl,
}
}

func (c *checklist) found(val constantValue) {
// Delete all of the same-valued names.
for _, name := range c.em.ValueToNames[val] {
delete(c.names, name)
delete(c.checkl, name)
}
}

func (c *checklist) remaining() map[string]struct{} {
return c.names
return c.checkl
}
3 changes: 1 addition & 2 deletions switch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,9 @@ func TestAnalyzeSwitchClauses(t *testing.T) {

assertFoundNames := func(t *testing.T, sw *ast.SwitchStmt, info *types.Info, want []constantValue, wantDefaultExists bool) {
t.Helper()
tagType := info.Types[sw.Tag].Type.(*types.Named)

var got []constantValue
gotDefaultExists := analyzeSwitchClauses(sw, tagType.Obj().Pkg(), m, info, func(val constantValue) {
gotDefaultExists := analyzeSwitchClauses(sw, info, func(val constantValue) {
got = append(got, val)
})

Expand Down

0 comments on commit 7596586

Please sign in to comment.