From b787354dc9044f4ba6585f068e57decd58dc43a5 Mon Sep 17 00:00:00 2001 From: Iskander Sharipov Date: Sun, 20 Dec 2020 16:02:49 +0300 Subject: [PATCH] ruleguard: filter parsing improvements Refactored Where() filter parsing code. Implemented || operator while at it. Fixes #115 !(A && B) conditions now work properly Fixes #26 A || B conditions are now implemented --- analyzer/testdata/src/regression/issue115.go | 13 + analyzer/testdata/src/regression/rules.go | 6 + ruleguard/debug_test.go | 53 +++- ruleguard/filters.go | 246 +++++++++++++++ ruleguard/gorule.go | 41 +-- ruleguard/parser.go | 313 +++++++------------ ruleguard/runner.go | 73 +++-- 7 files changed, 494 insertions(+), 251 deletions(-) create mode 100644 analyzer/testdata/src/regression/issue115.go create mode 100644 ruleguard/filters.go diff --git a/analyzer/testdata/src/regression/issue115.go b/analyzer/testdata/src/regression/issue115.go new file mode 100644 index 00000000..ba8d82ee --- /dev/null +++ b/analyzer/testdata/src/regression/issue115.go @@ -0,0 +1,13 @@ +package regression + +func testIssue115() { + intFunc := func() int { return 19 } + stringFunc := func() string { return "19" } + + println(13) + println(43 + 5) + + println("foo") // want `\Q"foo" is not a constexpr int` + println(intFunc()) // want `\QintFunc() is not a constexpr int` + println(stringFunc()) // want `\QstringFunc() is not a constexpr int` +} diff --git a/analyzer/testdata/src/regression/rules.go b/analyzer/testdata/src/regression/rules.go index d866124f..273693a0 100644 --- a/analyzer/testdata/src/regression/rules.go +++ b/analyzer/testdata/src/regression/rules.go @@ -19,3 +19,9 @@ func issue72(m fluent.Matcher) { `fmt.Sprintf("%s<%s>", $name, $email)`). Report("use net/mail Address.String() instead of fmt.Sprintf()") } + +func issue115(m fluent.Matcher) { + m.Match(`println($x)`). + Where(!(m["x"].Const && m["x"].Type.Is("int"))). + Report("$x is not a constexpr int") +} \ No newline at end of file diff --git a/ruleguard/debug_test.go b/ruleguard/debug_test.go index 2e84b651..90650c8f 100644 --- a/ruleguard/debug_test.go +++ b/ruleguard/debug_test.go @@ -40,12 +40,11 @@ func TestDebug(t *testing.T) { }, }, - // TODO(quasilyte): don't lose "!" in the debug output. `m.Match("$x + $_").Where(!m["x"].Type.Is("int"))`: { `sink = "a" + "b"`: nil, `sink = int(10) + 20`: { - `input.go:4: [rules.go:5] rejected by m["x"].Type.Is("int")`, + `input.go:4: [rules.go:5] rejected by !m["x"].Type.Is("int")`, ` $x int: int(10)`, }, }, @@ -91,6 +90,56 @@ func TestDebug(t *testing.T) { ` $x interface{}: f((10))`, }, }, + + // When debugging OR, the last alternative will be reported as the failure reason, + // although it should be obvious that all operands are falsy. + // We don't return the entire OR expression as a reason to avoid the output cluttering. + `m.Match("_ = $x").Where(m["x"].Type.Is("int") || m["x"].Type.Is("string"))`: { + `_ = ""`: nil, + `_ = 10`: nil, + + `_ = []int{}`: { + `input.go:4: [rules.go:5] rejected by m["x"].Type.Is("string")`, + ` $x []int: []int{}`, + }, + + `_ = int32(0)`: { + `input.go:4: [rules.go:5] rejected by m["x"].Type.Is("string")`, + ` $x int32: int32(0)`, + }, + }, + + // Using 3 operands for || and different ()-groupings. + `m.Match("_ = $x").Where(m["x"].Type.Is("int") || m["x"].Type.Is("string") || m["x"].Text == "f()")`: { + `_ = ""`: nil, + `_ = 10`: nil, + `_ = f()`: nil, + + `_ = []string{"x"}`: { + `input.go:4: [rules.go:5] rejected by m["x"].Text == "f()"`, + ` $x []string: []string{"x"}`, + }, + }, + `m.Match("_ = $x").Where(m["x"].Type.Is("int") || (m["x"].Type.Is("string") || m["x"].Text == "f()"))`: { + `_ = ""`: nil, + `_ = 10`: nil, + `_ = f()`: nil, + + `_ = []string{"x"}`: { + `input.go:4: [rules.go:5] rejected by m["x"].Text == "f()"`, + ` $x []string: []string{"x"}`, + }, + }, + `m.Match("_ = $x").Where((m["x"].Type.Is("int") || m["x"].Type.Is("string")) || m["x"].Text == "f()")`: { + `_ = ""`: nil, + `_ = 10`: nil, + `_ = f()`: nil, + + `_ = []string{"x"}`: { + `input.go:4: [rules.go:5] rejected by m["x"].Text == "f()"`, + ` $x []string: []string{"x"}`, + }, + }, } exprToRules := func(s string) *GoRuleSet { diff --git a/ruleguard/filters.go b/ruleguard/filters.go new file mode 100644 index 00000000..711d71dc --- /dev/null +++ b/ruleguard/filters.go @@ -0,0 +1,246 @@ +package ruleguard + +import ( + "go/ast" + "go/constant" + "go/token" + "go/types" + "path/filepath" + "regexp" + + "github.com/quasilyte/go-ruleguard/ruleguard/typematch" +) + +const filterSuccess = matchFilterResult("") + +func filterFailure(reason string) matchFilterResult { + return matchFilterResult(reason) +} + +func makeNotFilter(src string, x matchFilter) filterFunc { + return func(params *filterParams) matchFilterResult { + if x.fn(params).Matched() { + return matchFilterResult(src) + } + return "" + } +} + +func makeAndFilter(lhs, rhs matchFilter) filterFunc { + return func(params *filterParams) matchFilterResult { + if lhsResult := lhs.fn(params); !lhsResult.Matched() { + return lhsResult + } + return rhs.fn(params) + } +} + +func makeOrFilter(lhs, rhs matchFilter) filterFunc { + return func(params *filterParams) matchFilterResult { + if lhsResult := lhs.fn(params); lhsResult.Matched() { + return filterSuccess + } + return rhs.fn(params) + } +} + +func makeFileImportsFilter(src, pkgPath string) filterFunc { + return func(params *filterParams) matchFilterResult { + _, imported := params.imports[pkgPath] + if imported { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeFilePkgPathMatchesFilter(src string, re *regexp.Regexp) filterFunc { + return func(params *filterParams) matchFilterResult { + pkgPath := params.ctx.Pkg.Path() + if re.MatchString(pkgPath) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeFileNameMatchesFilter(src string, re *regexp.Regexp) filterFunc { + return func(params *filterParams) matchFilterResult { + if re.MatchString(filepath.Base(params.filename)) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makePureFilter(src, varname string) filterFunc { + return func(params *filterParams) matchFilterResult { + n := params.subExpr(varname) + if isPure(params.ctx.Types, n) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeConstFilter(src, varname string) filterFunc { + return func(params *filterParams) matchFilterResult { + n := params.subExpr(varname) + if isConstant(params.ctx.Types, n) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeAddressableFilter(src, varname string) filterFunc { + return func(params *filterParams) matchFilterResult { + n := params.subExpr(varname) + if isAddressable(params.ctx.Types, n) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTypeImplementsFilter(src, varname string, iface *types.Interface) filterFunc { + return func(params *filterParams) matchFilterResult { + typ := params.typeofNode(params.subExpr(varname)) + if types.Implements(typ, iface) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTypeIsFilter(src, varname string, underlying bool, pat *typematch.Pattern) filterFunc { + if underlying { + return func(params *filterParams) matchFilterResult { + typ := params.typeofNode(params.subExpr(varname)).Underlying() + if pat.MatchIdentical(typ) { + return filterSuccess + } + return filterFailure(src) + } + } + return func(params *filterParams) matchFilterResult { + typ := params.typeofNode(params.subExpr(varname)) + if pat.MatchIdentical(typ) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTypeConvertibleToFilter(src, varname string, dstType types.Type) filterFunc { + return func(params *filterParams) matchFilterResult { + typ := params.typeofNode(params.subExpr(varname)) + if types.ConvertibleTo(typ, dstType) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTypeAssignableToFilter(src, varname string, dstType types.Type) filterFunc { + return func(params *filterParams) matchFilterResult { + typ := params.typeofNode(params.subExpr(varname)) + if types.AssignableTo(typ, dstType) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTypeSizeConstFilter(src, varname string, op token.Token, rhsValue constant.Value) filterFunc { + return func(params *filterParams) matchFilterResult { + typ := params.typeofNode(params.subExpr(varname)) + lhsValue := constant.MakeInt64(params.ctx.Sizes.Sizeof(typ)) + if constant.Compare(lhsValue, op, rhsValue) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeValueIntConstFilter(src, varname string, op token.Token, rhsValue constant.Value) filterFunc { + return func(params *filterParams) matchFilterResult { + lhsValue := intValueOf(params.ctx.Types, params.subExpr(varname)) + if lhsValue == nil { + return filterFailure(src) // The value is unknown + } + if constant.Compare(lhsValue, op, rhsValue) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeValueIntFilter(src, varname string, op token.Token, rhsVarname string) filterFunc { + return func(params *filterParams) matchFilterResult { + lhsValue := intValueOf(params.ctx.Types, params.subExpr(varname)) + if lhsValue == nil { + return filterFailure(src) + } + rhsValue := intValueOf(params.ctx.Types, params.subExpr(rhsVarname)) + if rhsValue == nil { + return filterFailure(src) + } + if constant.Compare(lhsValue, op, rhsValue) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTextConstFilter(src, varname string, op token.Token, rhsValue constant.Value) filterFunc { + return func(params *filterParams) matchFilterResult { + s := params.nodeText(params.subExpr(varname)) + lhsValue := constant.MakeString(string(s)) + if constant.Compare(lhsValue, op, rhsValue) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTextFilter(src, varname string, op token.Token, rhsVarname string) filterFunc { + return func(params *filterParams) matchFilterResult { + s1 := params.nodeText(params.subExpr(varname)) + lhsValue := constant.MakeString(string(s1)) + s2 := params.nodeText(params.values[rhsVarname]) + rhsValue := constant.MakeString(string(s2)) + if constant.Compare(lhsValue, op, rhsValue) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeTextMatchesFilter(src, varname string, re *regexp.Regexp) filterFunc { + return func(params *filterParams) matchFilterResult { + if re.Match(params.nodeText(params.subExpr(varname))) { + return filterSuccess + } + return filterFailure(src) + } +} + +func makeNodeIsFilter(src, varname string, cat nodeCategory) filterFunc { + return func(params *filterParams) matchFilterResult { + n := params.subExpr(varname) + var matched bool + switch cat { + case nodeExpr: + _, matched = n.(ast.Expr) + case nodeStmt: + _, matched = n.(ast.Stmt) + default: + matched = (cat == categorizeNode(n)) + } + if matched { + return filterSuccess + } + return filterFailure(src) + } +} diff --git a/ruleguard/gorule.go b/ruleguard/gorule.go index 03161333..4b4ed069 100644 --- a/ruleguard/gorule.go +++ b/ruleguard/gorule.go @@ -25,40 +25,41 @@ type goRule struct { filter matchFilter } -type matchFilter struct { - fileFilters []fileFilter - subFilters map[string][]nodeFilter -} +type matchFilterResult string + +func (s matchFilterResult) Matched() bool { return s == "" } + +func (s matchFilterResult) RejectReason() string { return string(s) } + +type filterFunc func(*filterParams) matchFilterResult -type fileFilter struct { - src string - pred func(*fileFilterParams) bool +type matchFilter struct { + src string + fn func(*filterParams) matchFilterResult } -type fileFilterParams struct { +type filterParams struct { ctx *Context filename string imports map[string]struct{} -} -type nodeFilter struct { - src string - pred func(*nodeFilterParams) bool -} - -type nodeFilterParams struct { - ctx *Context - n ast.Expr values map[string]ast.Node nodeText func(n ast.Node) []byte } -func (params *nodeFilterParams) nodeType() types.Type { - return params.typeofNode(params.n) +func (params *filterParams) subExpr(name string) ast.Expr { + switch n := params.values[name].(type) { + case ast.Expr: + return n + case *ast.ExprStmt: + return n.X + default: + return nil + } } -func (params *nodeFilterParams) typeofNode(n ast.Node) types.Type { +func (params *filterParams) typeofNode(n ast.Node) types.Type { if e, ok := n.(ast.Expr); ok { if typ := params.ctx.Types.TypeOf(e); typ != nil { return typ diff --git a/ruleguard/parser.go b/ruleguard/parser.go index 843c02f8..ee9b01dc 100644 --- a/ruleguard/parser.go +++ b/ruleguard/parser.go @@ -1,16 +1,15 @@ package ruleguard import ( + "errors" "fmt" "go/ast" - "go/constant" "go/importer" "go/parser" "go/token" "go/types" "io" "path" - "path/filepath" "regexp" "strconv" @@ -342,10 +341,6 @@ func (p *rulesParser) parseRule(matcher string, call *ast.CallExpr) error { filename: p.filename, line: p.fset.Position(origCall.Pos()).Line, group: p.group, - filter: matchFilter{ - fileFilters: []fileFilter{}, - subFilters: map[string][]nodeFilter{}, - }, } var alternatives []string @@ -361,9 +356,11 @@ func (p *rulesParser) parseRule(matcher string, call *ast.CallExpr) error { } if whereArgs != nil { - if err := p.walkFilter(&proto.filter, (*whereArgs)[0], false); err != nil { + filter, err := p.parseFilter((*whereArgs)[0]) + if err != nil { return err } + proto.filter = filter } if suggestArgs != nil { @@ -418,307 +415,239 @@ func (p *rulesParser) parseRule(matcher string, call *ast.CallExpr) error { return nil } -func (p *rulesParser) appendFileFilter(dst *matchFilter, e ast.Expr, pred func(*fileFilterParams) bool) { - dst.fileFilters = append(dst.fileFilters, fileFilter{ - src: sprintNode(p.fset, e), - pred: pred, - }) +type filterParseError string + +func (p *rulesParser) parseFilter(root ast.Expr) (result matchFilter, err error) { + defer func() { + rv := recover() + if rv == nil { + return + } + if parseErr, ok := rv.(filterParseError); ok { + err = errors.New(string(parseErr)) + return + } + panic(rv) // not our panic + }() + + f := p.parseFilterExpr(root) + return f, nil // error is set via defer } -func (p *rulesParser) appendSubFilter(dst *matchFilter, e ast.Expr, key string, pred func(*nodeFilterParams) bool) { - dst.subFilters[key] = append(dst.subFilters[key], nodeFilter{ - src: sprintNode(p.fset, e), - pred: pred, - }) +func (p *rulesParser) filterError(n ast.Node, format string, args ...interface{}) filterParseError { + loc := p.fset.Position(n.Pos()) + message := fmt.Sprintf("%s:%d: %s: %s", loc.Filename, loc.Line, sprintNode(p.fset, n), fmt.Sprintf(format, args...)) + return filterParseError(message) } -func (p *rulesParser) walkFilter(dst *matchFilter, e ast.Expr, negate bool) error { +func (p *rulesParser) parseFilterExpr(e ast.Expr) matchFilter { + result := matchFilter{src: sprintNode(p.fset, e)} + switch e := e.(type) { + case *ast.ParenExpr: + return p.parseFilterExpr(e.X) + case *ast.UnaryExpr: + x := p.parseFilterExpr(e.X) if e.Op == token.NOT { - return p.walkFilter(dst, e.X, !negate) + result.fn = makeNotFilter(result.src, x) + return result } + panic(p.filterError(e, "unsupported unary op")) + case *ast.BinaryExpr: switch e.Op { case token.LAND: - err := p.walkFilter(dst, e.X, negate) - if err != nil { - return err - } - return p.walkFilter(dst, e.Y, negate) + result.fn = makeAndFilter(p.parseFilterExpr(e.X), p.parseFilterExpr(e.Y)) + return result + case token.LOR: + result.fn = makeOrFilter(p.parseFilterExpr(e.X), p.parseFilterExpr(e.Y)) + return result case token.GEQ, token.LEQ, token.LSS, token.GTR, token.EQL, token.NEQ: operand := p.toFilterOperand(e.X) rhs := p.toFilterOperand(e.Y) rhsValue := p.types.Types[e.Y].Value - expectedResult := !negate if operand.path == "Type.Size" && rhsValue != nil { - p.appendSubFilter(dst, e, operand.varName, func(params *nodeFilterParams) bool { - x := constant.MakeInt64(params.ctx.Sizes.Sizeof(params.nodeType())) - return expectedResult == constant.Compare(x, e.Op, rhsValue) - }) - return nil + result.fn = makeTypeSizeConstFilter(result.src, operand.varName, e.Op, rhsValue) + return result } if operand.path == "Value.Int" && rhsValue != nil { - p.appendSubFilter(dst, e, operand.varName, func(params *nodeFilterParams) bool { - x := intValueOf(params.ctx.Types, params.n) - if x == nil { - return false // The value is unknown - } - return expectedResult == constant.Compare(x, e.Op, rhsValue) - }) - return nil + result.fn = makeValueIntConstFilter(result.src, operand.varName, e.Op, rhsValue) + return result } if operand.path == "Value.Int" && rhs.path == "Value.Int" && rhs.varName != "" { - p.appendSubFilter(dst, e, operand.varName, func(params *nodeFilterParams) bool { - x := intValueOf(params.ctx.Types, params.n) - if x == nil { - return false // The value is unknown - } - y := intValueOf(params.ctx.Types, params.values[rhs.varName].(ast.Expr)) - if y == nil { - return false - } - return expectedResult == constant.Compare(x, e.Op, y) - }) - return nil + result.fn = makeValueIntFilter(result.src, operand.varName, e.Op, rhs.varName) + return result } if operand.path == "Text" && rhsValue != nil { - p.appendSubFilter(dst, e, operand.varName, func(params *nodeFilterParams) bool { - s := params.nodeText(params.n) - x := constant.MakeString(string(s)) - return expectedResult == constant.Compare(x, e.Op, rhsValue) - }) - return nil + result.fn = makeTextConstFilter(result.src, operand.varName, e.Op, rhsValue) + return result } if operand.path == "Text" && rhs.path == "Text" && rhs.varName != "" { - p.appendSubFilter(dst, e, operand.varName, func(params *nodeFilterParams) bool { - s := params.nodeText(params.n) - x := constant.MakeString(string(s)) - s2 := params.nodeText(params.values[rhs.varName]) - y := constant.MakeString(string(s2)) - return expectedResult == constant.Compare(x, e.Op, y) - }) - return nil + result.fn = makeTextFilter(result.src, operand.varName, e.Op, rhs.varName) + return result } } - case *ast.ParenExpr: - return p.walkFilter(dst, e.X, negate) + panic(p.filterError(e, "unsupported binary op")) } - origExpr := e - appendFileFilter := func(pred func(*fileFilterParams) bool) { - p.appendFileFilter(dst, origExpr, pred) - } - appendSubFilter := func(sub string, pred func(*nodeFilterParams) bool) { - p.appendSubFilter(dst, origExpr, sub, pred) - } - - // TODO(quasilyte): refactor and extend. operand := p.toFilterOperand(e) args := operand.args switch operand.path { default: - return p.errorf(e, "%s is not a valid filter expression", sprintNode(p.fset, e)) + panic(p.filterError(e, "unsupported expr")) + case "File.Imports": pkgPath, ok := p.toStringValue(args[0]) if !ok { - return p.errorf(args[0], "expected a string literal argument") + panic(p.filterError(args[0], "expected a string literal argument")) } - wantImported := !negate - appendFileFilter(func(params *fileFilterParams) bool { - _, imported := params.imports[pkgPath] - return wantImported == imported - }) - return nil + result.fn = makeFileImportsFilter(result.src, pkgPath) + case "File.PkgPath.Matches": patternString, ok := p.toStringValue(args[0]) if !ok { - return p.errorf(args[0], "expected a string literal argument") + panic(p.filterError(args[0], "expected a string literal argument")) } re, err := regexp.Compile(patternString) if err != nil { - return p.errorf(args[0], "parse regexp: %v", err) + panic(p.filterError(args[0], "parse regexp: %v", err)) } - wantMatched := !negate - appendFileFilter(func(params *fileFilterParams) bool { - pkgPath := params.ctx.Pkg.Path() - return wantMatched == re.MatchString(pkgPath) - }) - return nil + result.fn = makeFilePkgPathMatchesFilter(result.src, re) + case "File.Name.Matches": patternString, ok := p.toStringValue(args[0]) if !ok { - return p.errorf(args[0], "expected a string literal argument") + panic(p.filterError(args[0], "expected a string literal argument")) } re, err := regexp.Compile(patternString) if err != nil { - return p.errorf(args[0], "parse regexp: %v", err) + panic(p.filterError(args[0], "parse regexp: %v", err)) } - wantMatched := !negate - appendFileFilter(func(params *fileFilterParams) bool { - return wantMatched == re.MatchString(filepath.Base(params.filename)) - }) - return nil + result.fn = makeFileNameMatchesFilter(result.src, re) case "Pure": - wantPure := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantPure == isPure(params.ctx.Types, params.n) - }) + result.fn = makePureFilter(result.src, operand.varName) case "Const": - wantConst := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantConst == isConstant(params.ctx.Types, params.n) - }) + result.fn = makeConstFilter(result.src, operand.varName) case "Addressable": - wantAddressable := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantAddressable == isAddressable(params.ctx.Types, params.n) - }) - - case "Text.Matches": - patternString, ok := p.toStringValue(args[0]) - if !ok { - return p.errorf(args[0], "expected a string literal argument") - } - re, err := regexp.Compile(patternString) - if err != nil { - return p.errorf(args[0], "parse regexp: %v", err) - } - wantMatched := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantMatched == re.Match(params.nodeText(params.n)) - }) - - case "Node.Is": - typeString, ok := p.toStringValue(args[0]) - if !ok { - return p.errorf(args[0], "expected a string literal argument") - } - cat := categorizeNodeString(typeString) - if cat == nodeUnknown { - return p.errorf(args[0], "%s is not a valid go/ast type name", typeString) - } - wantMatched := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - switch cat { - case nodeExpr: - _, ok := params.n.(ast.Expr) - return wantMatched == ok - case nodeStmt: - _, ok := params.n.(ast.Stmt) - return wantMatched == ok - default: - ok := cat == categorizeNode(params.n) - return wantMatched == ok - } - }) + result.fn = makeAddressableFilter(result.src, operand.varName) case "Type.Is", "Type.Underlying.Is": typeString, ok := p.toStringValue(args[0]) if !ok { - return p.errorf(args[0], "expected a string literal argument") + panic(p.filterError(args[0], "expected a string literal argument")) } ctx := typematch.Context{Itab: p.itab} pat, err := typematch.Parse(&ctx, typeString) if err != nil { - return p.errorf(args[0], "parse type expr: %v", err) - } - wantIdentical := !negate - if operand.path == "Type.Underlying.Is" { - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantIdentical == pat.MatchIdentical(params.nodeType().Underlying()) - }) - } else { - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantIdentical == pat.MatchIdentical(params.nodeType()) - }) + panic(p.filterError(args[0], "parse type expr: %v", err)) } + underlying := operand.path == "Type.Underlying.Is" + result.fn = makeTypeIsFilter(result.src, operand.varName, underlying, pat) + case "Type.ConvertibleTo": typeString, ok := p.toStringValue(args[0]) if !ok { - return p.errorf(args[0], "expected a string literal argument") + panic(p.filterError(args[0], "expected a string literal argument")) } - y, err := typeFromString(typeString) + dstType, err := typeFromString(typeString) if err != nil { - return p.errorf(args[0], "parse type expr: %v", err) + panic(p.filterError(args[0], "parse type expr: %v", err)) } - if y == nil { - return p.errorf(args[0], "can't convert %s into a type constraint yet", typeString) + if dstType == nil { + panic(p.filterError(args[0], "can't convert %s into a type constraint yet", typeString)) } - wantConvertible := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantConvertible == types.ConvertibleTo(params.nodeType(), y) - }) + result.fn = makeTypeConvertibleToFilter(result.src, operand.varName, dstType) + case "Type.AssignableTo": typeString, ok := p.toStringValue(args[0]) if !ok { - return p.errorf(args[0], "expected a string literal argument") + panic(p.filterError(args[0], "expected a string literal argument")) } - y, err := typeFromString(typeString) + dstType, err := typeFromString(typeString) if err != nil { - return p.errorf(args[0], "parse type expr: %v", err) + panic(p.filterError(args[0], "parse type expr: %v", err)) } - if y == nil { - return p.errorf(args[0], "can't convert %s into a type constraint yet", typeString) + if dstType == nil { + panic(p.filterError(args[0], "can't convert %s into a type constraint yet", typeString)) } - wantAssignable := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantAssignable == types.AssignableTo(params.nodeType(), y) - }) + result.fn = makeTypeAssignableToFilter(result.src, operand.varName, dstType) + case "Type.Implements": typeString, ok := p.toStringValue(args[0]) if !ok { - return p.errorf(args[0], "expected a string literal argument") + panic(p.filterError(args[0], "expected a string literal argument")) } n, err := parser.ParseExpr(typeString) if err != nil { - return p.errorf(args[0], "parse type expr: %v", err) + panic(p.filterError(args[0], "parse type expr: %v", err)) } var iface *types.Interface switch n := n.(type) { case *ast.Ident: if n.Name != `error` { - return p.errorf(n, "only `error` unqualified type is recognized") + panic(p.filterError(n, "only `error` unqualified type is recognized")) } iface = types.Universe.Lookup("error").Type().Underlying().(*types.Interface) case *ast.SelectorExpr: pkgName, ok := n.X.(*ast.Ident) if !ok { - return p.errorf(n.X, "invalid package name") + panic(p.filterError(n.X, "invalid package name")) } pkgPath, ok := p.itab.Lookup(pkgName.Name) if !ok { - return p.errorf(n.X, "package %s is not imported", pkgName.Name) + panic(p.filterError(n.X, "package %s is not imported", pkgName.Name)) } pkg, err := p.stdImporter.Import(pkgPath) if err != nil { pkg, err = p.srcImporter.Import(pkgPath) if err != nil { - return p.errorf(n, "can't load %s: %v", pkgPath, err) + panic(p.filterError(n, "can't load %s: %v", pkgPath, err)) } } obj := pkg.Scope().Lookup(n.Sel.Name) if obj == nil { - return p.errorf(n, "%s is not found in %s", n.Sel.Name, pkgPath) + panic(p.filterError(n, "%s is not found in %s", n.Sel.Name, pkgPath)) } iface, ok = obj.Type().Underlying().(*types.Interface) if !ok { - return p.errorf(n, "%s is not an interface type", n.Sel.Name) + panic(p.filterError(n, "%s is not an interface type", n.Sel.Name)) } default: - return p.errorf(args[0], "only qualified names (and `error`) are supported") + panic(p.filterError(args[0], "only qualified names (and `error`) are supported")) } + result.fn = makeTypeImplementsFilter(result.src, operand.varName, iface) - wantImplemented := !negate - appendSubFilter(operand.varName, func(params *nodeFilterParams) bool { - return wantImplemented == types.Implements(params.nodeType(), iface) - }) + case "Text.Matches": + patternString, ok := p.toStringValue(args[0]) + if !ok { + panic(p.filterError(args[0], "expected a string literal argument")) + } + re, err := regexp.Compile(patternString) + if err != nil { + panic(p.filterError(args[0], "parse regexp: %v", err)) + } + result.fn = makeTextMatchesFilter(result.src, operand.varName, re) + + case "Node.Is": + typeString, ok := p.toStringValue(args[0]) + if !ok { + panic(p.filterError(args[0], "expected a string literal argument")) + } + cat := categorizeNodeString(typeString) + if cat == nodeUnknown { + panic(p.filterError(args[0], "%s is not a valid go/ast type name", typeString)) + } + result.fn = makeNodeIsFilter(result.src, operand.varName, cat) } - return nil + if result.fn == nil { + panic("bug: nil func for the filter") // Should never happen + } + return result } func (p *rulesParser) toStringValue(x ast.Node) (string, bool) { diff --git a/ruleguard/runner.go b/ruleguard/runner.go index b8496b09..b3b90fc8 100644 --- a/ruleguard/runner.go +++ b/ruleguard/runner.go @@ -24,23 +24,19 @@ type rulesRunner struct { // A slice that is used to do a nodes keys sorting in renderMessage(). sortScratch []string - fileFilterParams fileFilterParams - nodeFilterParams nodeFilterParams + filterParams filterParams } func newRulesRunner(ctx *Context, rules *GoRuleSet) *rulesRunner { rr := &rulesRunner{ ctx: ctx, rules: rules, - fileFilterParams: fileFilterParams{ - ctx: ctx, - }, - nodeFilterParams: nodeFilterParams{ + filterParams: filterParams{ ctx: ctx, }, sortScratch: make([]string, 0, 8), } - rr.nodeFilterParams.nodeText = rr.nodeText + rr.filterParams.nodeText = rr.nodeText return rr } @@ -80,7 +76,7 @@ func (rr *rulesRunner) run(f *ast.File) error { // TODO(quasilyte): run local rules as well. rr.filename = rr.ctx.Fset.Position(f.Pos()).Filename - rr.fileFilterParams.filename = rr.filename + rr.filterParams.filename = rr.filename rr.collectImports(f) for _, rule := range rr.rules.universal.uncategorized { @@ -109,10 +105,6 @@ func (rr *rulesRunner) run(f *ast.File) error { } func (rr *rulesRunner) reject(rule goRule, reason string, m gogrep.MatchData) { - // Note: we accept reason and sub args instead of formatted or - // concatenated string so it's cheaper for us to call this - // function is debugging is not enabled. - if rule.group != rr.ctx.Debug { return // This rule is not being debugged } @@ -157,33 +149,40 @@ func (rr *rulesRunner) reject(rule goRule, reason string, m gogrep.MatchData) { } func (rr *rulesRunner) handleMatch(rule goRule, m gogrep.MatchData) bool { - for _, f := range rule.filter.fileFilters { - if !f.pred(&rr.fileFilterParams) { - rr.reject(rule, f.src, m) + if rule.filter.fn != nil { + rr.filterParams.values = m.Values + filterResult := rule.filter.fn(&rr.filterParams) + if !filterResult.Matched() { + rr.reject(rule, filterResult.RejectReason(), m) return false } } - rr.nodeFilterParams.values = m.Values - for name, node := range m.Values { - var expr ast.Expr - switch node := node.(type) { - case ast.Expr: - expr = node - case *ast.ExprStmt: - expr = node.X - default: - continue - } - - rr.nodeFilterParams.n = expr - for _, f := range rule.filter.subFilters[name] { - if !f.pred(&rr.nodeFilterParams) { - rr.reject(rule, f.src, m) - return false - } - } - } + // for _, f := range rule.filter.fileFilters { + // if !f.pred(&rr.fileFilterParams) { + // rr.reject(rule, f.src, m) + // return false + // } + // } + // rr.nodeFilterParams.values = m.Values + // for name, node := range m.Values { + // var expr ast.Expr + // switch node := node.(type) { + // case ast.Expr: + // expr = node + // case *ast.ExprStmt: + // expr = node.X + // default: + // continue + // } + // rr.nodeFilterParams.n = expr + // for _, f := range rule.filter.subFilters[name] { + // if !f.pred(&rr.nodeFilterParams) { + // rr.reject(rule, f.src, m) + // return false + // } + // } + // } prefix := "" if rule.severity != "" { @@ -210,13 +209,13 @@ func (rr *rulesRunner) handleMatch(rule goRule, m gogrep.MatchData) bool { } func (rr *rulesRunner) collectImports(f *ast.File) { - rr.fileFilterParams.imports = make(map[string]struct{}, len(f.Imports)) + rr.filterParams.imports = make(map[string]struct{}, len(f.Imports)) for _, spec := range f.Imports { s, err := strconv.Unquote(spec.Path.Value) if err != nil { continue } - rr.fileFilterParams.imports[s] = struct{}{} + rr.filterParams.imports[s] = struct{}{} } }