Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ruleguard/typematch: implement struct type pattern matching #87

Merged
merged 1 commit into from
Oct 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 124 additions & 5 deletions ruleguard/typematch/typematch.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ const (
opBuiltinType patternOp = iota
opPointer
opVar
opVarSeq
opSlice
opArray
opMap
opChan
opFunc
opStructNoSeq
opStruct
opNamed
)

Expand Down Expand Up @@ -71,8 +74,14 @@ type Context struct {
Itab *ImportsTab
}

const (
varPrefix = `ᐸvarᐳ`
varSeqPrefix = `ᐸvar_seqᐳ`
)

func Parse(ctx *Context, s string) (*Pattern, error) {
noDollars := strings.ReplaceAll(s, "$", "__")
noDollars := strings.ReplaceAll(s, "$*", varSeqPrefix)
noDollars = strings.ReplaceAll(noDollars, "$", varPrefix)
n, err := parser.ParseExpr(noDollars)
if err != nil {
return nil, err
Expand Down Expand Up @@ -126,10 +135,17 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern {
if ok {
return &pattern{op: opBuiltinType, value: basic}
}
if strings.HasPrefix(e.Name, "__") {
name := strings.TrimPrefix(e.Name, "__")
if strings.HasPrefix(e.Name, varPrefix) {
name := strings.TrimPrefix(e.Name, varPrefix)
return &pattern{op: opVar, value: name}
}
if strings.HasPrefix(e.Name, varSeqPrefix) {
name := strings.TrimPrefix(e.Name, varSeqPrefix)
// Only unnamed seq are supported right now.
if name == "_" {
return &pattern{op: opVarSeq, value: name}
}
}

case *ast.SelectorExpr:
pkg, ok := e.X.(*ast.Ident)
Expand Down Expand Up @@ -160,8 +176,8 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern {
subs: []*pattern{elem},
}
}
if id, ok := e.Len.(*ast.Ident); ok && strings.HasPrefix(id.Name, "__") {
name := strings.TrimPrefix(id.Name, "__")
if id, ok := e.Len.(*ast.Ident); ok && strings.HasPrefix(id.Name, varPrefix) {
name := strings.TrimPrefix(id.Name, varPrefix)
return &pattern{
op: opArray,
value: name,
Expand Down Expand Up @@ -254,6 +270,31 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern {
subs: append(params, results...),
}

case *ast.StructType:
hasSeq := false
members := make([]*pattern, 0, len(e.Fields.List))
for _, field := range e.Fields.List {
p := parseExpr(ctx, field.Type)
if p == nil {
return nil
}
if len(field.Names) != 0 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably can be moved above, to prevent parsing when we don't have field.Names

return nil
}
if p.op == opVarSeq {
hasSeq = true
}
members = append(members, p)
}
op := opStructNoSeq
if hasSeq {
op = opStruct
}
return &pattern{
op: op,
subs: members,
}

case *ast.InterfaceType:
if len(e.Methods.List) == 0 {
return &pattern{op: opBuiltinType, value: efaceType}
Expand All @@ -277,6 +318,54 @@ func (p *Pattern) reset() {
}
}

func (p *Pattern) matchIdenticalFielder(subs []*pattern, f fielder) bool {
// TODO: do backtracking.

numFields := f.NumFields()
fieldsMatched := 0

if len(subs) == 0 && numFields != 0 {
return false
}

matchAny := false

i := 0
for i < len(subs) {
pat := subs[i]

if pat.op == opVarSeq {
matchAny = true
}

fieldsLeft := numFields - fieldsMatched
if matchAny {
switch {
// "Nothing left to match" stop condition.
case fieldsLeft == 0:
matchAny = false
i++
// Lookahead for non-greedy matching.
case i+1 < len(subs) && p.matchIdentical(subs[i+1], f.Field(fieldsMatched).Type()):
matchAny = false
i += 2
fieldsMatched++
default:
fieldsMatched++
}
continue
}

if fieldsLeft == 0 || !p.matchIdentical(pat, f.Field(fieldsMatched).Type()) {
return false
}
i++
fieldsMatched++
}

return numFields == fieldsMatched
}

func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool {
switch sub.op {
case opVar:
Expand Down Expand Up @@ -394,7 +483,37 @@ func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool {
}
return true

case opStructNoSeq:
typ, ok := typ.(*types.Struct)
if !ok {
return false
}
if typ.NumFields() != len(sub.subs) {
return false
}
for i, member := range sub.subs {
if !p.matchIdentical(member, typ.Field(i).Type()) {
return false
}
}
return true

case opStruct:
typ, ok := typ.(*types.Struct)
if !ok {
return false
}
if !p.matchIdenticalFielder(sub.subs, typ) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just return thing func call? :)

return false
}
return true

default:
return false
}
}

type fielder interface {
Field(i int) *types.Var
NumFields() int
}
87 changes: 80 additions & 7 deletions ruleguard/typematch/typematch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
)

var (
typeInt = types.Typ[types.Int]
typeString = types.Typ[types.String]
typeInt32 = types.Typ[types.Int32]
typeUint8 = types.Typ[types.Uint8]
typeInt = types.Typ[types.Int]
typeString = types.Typ[types.String]
typeInt32 = types.Typ[types.Int32]
typeUint8 = types.Typ[types.Uint8]
typeEstruct = types.NewStruct(nil, nil)

intVar = types.NewVar(token.NoPos, nil, "", typeInt)
stringVar = types.NewVar(token.NoPos, nil, "", typeString)
intVar = types.NewVar(token.NoPos, nil, "_", typeInt)
int32Var = types.NewVar(token.NoPos, nil, "_", typeInt32)
estructVar = types.NewVar(token.NoPos, nil, "_", typeEstruct)
stringVar = types.NewVar(token.NoPos, nil, "_", typeString)

testContext = &Context{
Itab: NewImportsTab(map[string]string{
Expand All @@ -24,12 +27,16 @@ var (
}
)

func structType(fields ...*types.Var) *types.Struct {
return types.NewStruct(fields, nil)
}

func namedType2(pkgPath, typeName string) *types.Named {
return namedType(pkgPath, path.Base(pkgPath), typeName)
}

func namedType(pkgPath, pkgName, typeName string) *types.Named {
dummy := types.NewStruct(nil, nil)
dummy := typeEstruct
pkg := types.NewPackage(pkgPath, pkgName)
typename := types.NewTypeName(0, pkg, typeName, dummy)
return types.NewNamed(typename, dummy, nil)
Expand Down Expand Up @@ -92,6 +99,37 @@ func TestIdentical(t *testing.T) {

{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(stringVar, stringVar), nil, false)},
{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(intVar, intVar), nil, false)},

{`struct{}`, typeEstruct},
{`struct{int}`, types.NewStruct([]*types.Var{intVar}, nil)},
{`struct{string; int}`, types.NewStruct([]*types.Var{stringVar, intVar}, nil)},
{`struct{$_; string}`, types.NewStruct([]*types.Var{stringVar, stringVar}, nil)},
{`struct{$_; $_}`, types.NewStruct([]*types.Var{stringVar, intVar}, nil)},
{`struct{$x; $x}`, types.NewStruct([]*types.Var{intVar, intVar}, nil)},

// Any struct.
{`struct{$*_}`, typeEstruct},
{`struct{$*_}`, structType(intVar, intVar)},

// Struct has suffix.
{`struct{$*_; int}`, structType(intVar)},
{`struct{$*_; int}`, structType(stringVar, stringVar, intVar)},

// Struct has prefix.
{`struct{int; $*_}`, structType(intVar)},
{`struct{int; $*_}`, structType(intVar, stringVar, stringVar)},

// Struct contains.
{`struct{$*_; int; $*_}`, structType(intVar)},
{`struct{$*_; int; $*_}`, structType(stringVar, intVar)},
{`struct{$*_; int; $*_}`, structType(intVar, stringVar)},
{`struct{$*_; int; $*_}`, structType(stringVar, intVar, stringVar)},

// Struct with dups.
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, intVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, intVar, stringVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, intVar, stringVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, stringVar, intVar)},
}

for _, test := range tests {
Expand Down Expand Up @@ -148,6 +186,41 @@ func TestIdenticalNegative(t *testing.T) {

{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(intVar, stringVar), nil, false)},
{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(stringVar, intVar), nil, false)},

{`struct{}`, typeInt},
{`struct{}`, types.NewStruct([]*types.Var{intVar}, nil)},
{`struct{int}`, typeEstruct},
{`struct{int}`, types.NewStruct([]*types.Var{stringVar}, nil)},
{`struct{string; int}`, types.NewStruct([]*types.Var{intVar, stringVar}, nil)},
{`struct{$_; string}`, types.NewStruct([]*types.Var{stringVar, stringVar, intVar}, nil)},
{`struct{$_; $_}`, types.NewStruct([]*types.Var{stringVar}, nil)},
{`struct{$x; $x}`, types.NewStruct([]*types.Var{intVar, stringVar}, nil)},

// Any struct negative.
{`struct{$*_}`, typeInt},

// Struct has suffix negative.
{`struct{$*_; int}`, typeEstruct},
{`struct{$*_; int}`, structType(stringVar)},

// Struct has prefix negative.
{`struct{int; $*_}`, typeEstruct},
{`struct{int; $*_}`, structType(stringVar)},

// Struct contains negative.
{`struct{$*_; int; $*_}`, typeEstruct},
{`struct{$*_; int; $*_}`, structType(stringVar)},
{`struct{$*_; int; $*_}`, structType(stringVar, int32Var)},

// Struct with dups negative.
{`struct{$*_; $x; $*_; $x; $*_}`, typeEstruct},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(int32Var, intVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, stringVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, estructVar, stringVar)},

// TODO: this should fail as $* is named.
// We don't support named $* now, but they should be supported.
//{`struct{$*x; int; $*x}`, structType(stringVar, intVar, intVar)},
}

for _, test := range tests {
Expand Down