Skip to content

Commit

Permalink
ruleguard/typematch: implement struct type pattern matching (#87)
Browse files Browse the repository at this point in the history
	struct{$*_} - arbitrary struct
	struct{$_; $_} - struct of any 2 fields
	struct{$x; $x} - struct of 2 fields of type $x
	struct{$*_; $x} - struct that ends with $x-typed field
	struct{$x; $*_} - struct that starts with $x-typed field
	struct{$*_; $x; $*_} - struct that contains $x-typed field

This is not a direct solution for https://twitter.com/dgryski/status/1317245210041012224,
but it makes us get a little closer to it.

There are several interpretations of Type.Contains() and we need
to decide what should be traversed and whatnot (and how deep).

Until we decide on the exact Type.Contains() semantics, struct
type pattern matching can be a temporary solution to this problem.

I'll add analyzer tests later to see whether it really can be
used in the requested context. I would expect that we need
things like Type.Underlying(), etc. to make it work correctly.
But that's a different story.

Signed-off-by: Iskander Sharipov <quasilyte@gmail.com>
  • Loading branch information
quasilyte authored Oct 17, 2020
1 parent 56b9139 commit 043f687
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 12 deletions.
129 changes: 124 additions & 5 deletions ruleguard/typematch/typematch.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ const (
opBuiltinType patternOp = iota
opPointer
opVar
opVarSeq
opSlice
opArray
opMap
opChan
opFunc
opStructNoSeq
opStruct
opNamed
)

Expand Down Expand Up @@ -71,8 +74,14 @@ type Context struct {
Itab *ImportsTab
}

const (
varPrefix = `ᐸvarᐳ`
varSeqPrefix = `ᐸvar_seqᐳ`
)

func Parse(ctx *Context, s string) (*Pattern, error) {
noDollars := strings.ReplaceAll(s, "$", "__")
noDollars := strings.ReplaceAll(s, "$*", varSeqPrefix)
noDollars = strings.ReplaceAll(noDollars, "$", varPrefix)
n, err := parser.ParseExpr(noDollars)
if err != nil {
return nil, err
Expand Down Expand Up @@ -126,10 +135,17 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern {
if ok {
return &pattern{op: opBuiltinType, value: basic}
}
if strings.HasPrefix(e.Name, "__") {
name := strings.TrimPrefix(e.Name, "__")
if strings.HasPrefix(e.Name, varPrefix) {
name := strings.TrimPrefix(e.Name, varPrefix)
return &pattern{op: opVar, value: name}
}
if strings.HasPrefix(e.Name, varSeqPrefix) {
name := strings.TrimPrefix(e.Name, varSeqPrefix)
// Only unnamed seq are supported right now.
if name == "_" {
return &pattern{op: opVarSeq, value: name}
}
}

case *ast.SelectorExpr:
pkg, ok := e.X.(*ast.Ident)
Expand Down Expand Up @@ -160,8 +176,8 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern {
subs: []*pattern{elem},
}
}
if id, ok := e.Len.(*ast.Ident); ok && strings.HasPrefix(id.Name, "__") {
name := strings.TrimPrefix(id.Name, "__")
if id, ok := e.Len.(*ast.Ident); ok && strings.HasPrefix(id.Name, varPrefix) {
name := strings.TrimPrefix(id.Name, varPrefix)
return &pattern{
op: opArray,
value: name,
Expand Down Expand Up @@ -254,6 +270,31 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern {
subs: append(params, results...),
}

case *ast.StructType:
hasSeq := false
members := make([]*pattern, 0, len(e.Fields.List))
for _, field := range e.Fields.List {
p := parseExpr(ctx, field.Type)
if p == nil {
return nil
}
if len(field.Names) != 0 {
return nil
}
if p.op == opVarSeq {
hasSeq = true
}
members = append(members, p)
}
op := opStructNoSeq
if hasSeq {
op = opStruct
}
return &pattern{
op: op,
subs: members,
}

case *ast.InterfaceType:
if len(e.Methods.List) == 0 {
return &pattern{op: opBuiltinType, value: efaceType}
Expand All @@ -277,6 +318,54 @@ func (p *Pattern) reset() {
}
}

func (p *Pattern) matchIdenticalFielder(subs []*pattern, f fielder) bool {
// TODO: do backtracking.

numFields := f.NumFields()
fieldsMatched := 0

if len(subs) == 0 && numFields != 0 {
return false
}

matchAny := false

i := 0
for i < len(subs) {
pat := subs[i]

if pat.op == opVarSeq {
matchAny = true
}

fieldsLeft := numFields - fieldsMatched
if matchAny {
switch {
// "Nothing left to match" stop condition.
case fieldsLeft == 0:
matchAny = false
i++
// Lookahead for non-greedy matching.
case i+1 < len(subs) && p.matchIdentical(subs[i+1], f.Field(fieldsMatched).Type()):
matchAny = false
i += 2
fieldsMatched++
default:
fieldsMatched++
}
continue
}

if fieldsLeft == 0 || !p.matchIdentical(pat, f.Field(fieldsMatched).Type()) {
return false
}
i++
fieldsMatched++
}

return numFields == fieldsMatched
}

func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool {
switch sub.op {
case opVar:
Expand Down Expand Up @@ -394,7 +483,37 @@ func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool {
}
return true

case opStructNoSeq:
typ, ok := typ.(*types.Struct)
if !ok {
return false
}
if typ.NumFields() != len(sub.subs) {
return false
}
for i, member := range sub.subs {
if !p.matchIdentical(member, typ.Field(i).Type()) {
return false
}
}
return true

case opStruct:
typ, ok := typ.(*types.Struct)
if !ok {
return false
}
if !p.matchIdenticalFielder(sub.subs, typ) {
return false
}
return true

default:
return false
}
}

type fielder interface {
Field(i int) *types.Var
NumFields() int
}
87 changes: 80 additions & 7 deletions ruleguard/typematch/typematch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
)

var (
typeInt = types.Typ[types.Int]
typeString = types.Typ[types.String]
typeInt32 = types.Typ[types.Int32]
typeUint8 = types.Typ[types.Uint8]
typeInt = types.Typ[types.Int]
typeString = types.Typ[types.String]
typeInt32 = types.Typ[types.Int32]
typeUint8 = types.Typ[types.Uint8]
typeEstruct = types.NewStruct(nil, nil)

intVar = types.NewVar(token.NoPos, nil, "", typeInt)
stringVar = types.NewVar(token.NoPos, nil, "", typeString)
intVar = types.NewVar(token.NoPos, nil, "_", typeInt)
int32Var = types.NewVar(token.NoPos, nil, "_", typeInt32)
estructVar = types.NewVar(token.NoPos, nil, "_", typeEstruct)
stringVar = types.NewVar(token.NoPos, nil, "_", typeString)

testContext = &Context{
Itab: NewImportsTab(map[string]string{
Expand All @@ -24,12 +27,16 @@ var (
}
)

func structType(fields ...*types.Var) *types.Struct {
return types.NewStruct(fields, nil)
}

func namedType2(pkgPath, typeName string) *types.Named {
return namedType(pkgPath, path.Base(pkgPath), typeName)
}

func namedType(pkgPath, pkgName, typeName string) *types.Named {
dummy := types.NewStruct(nil, nil)
dummy := typeEstruct
pkg := types.NewPackage(pkgPath, pkgName)
typename := types.NewTypeName(0, pkg, typeName, dummy)
return types.NewNamed(typename, dummy, nil)
Expand Down Expand Up @@ -92,6 +99,37 @@ func TestIdentical(t *testing.T) {

{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(stringVar, stringVar), nil, false)},
{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(intVar, intVar), nil, false)},

{`struct{}`, typeEstruct},
{`struct{int}`, types.NewStruct([]*types.Var{intVar}, nil)},
{`struct{string; int}`, types.NewStruct([]*types.Var{stringVar, intVar}, nil)},
{`struct{$_; string}`, types.NewStruct([]*types.Var{stringVar, stringVar}, nil)},
{`struct{$_; $_}`, types.NewStruct([]*types.Var{stringVar, intVar}, nil)},
{`struct{$x; $x}`, types.NewStruct([]*types.Var{intVar, intVar}, nil)},

// Any struct.
{`struct{$*_}`, typeEstruct},
{`struct{$*_}`, structType(intVar, intVar)},

// Struct has suffix.
{`struct{$*_; int}`, structType(intVar)},
{`struct{$*_; int}`, structType(stringVar, stringVar, intVar)},

// Struct has prefix.
{`struct{int; $*_}`, structType(intVar)},
{`struct{int; $*_}`, structType(intVar, stringVar, stringVar)},

// Struct contains.
{`struct{$*_; int; $*_}`, structType(intVar)},
{`struct{$*_; int; $*_}`, structType(stringVar, intVar)},
{`struct{$*_; int; $*_}`, structType(intVar, stringVar)},
{`struct{$*_; int; $*_}`, structType(stringVar, intVar, stringVar)},

// Struct with dups.
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, intVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, intVar, stringVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, intVar, stringVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, stringVar, intVar)},
}

for _, test := range tests {
Expand Down Expand Up @@ -148,6 +186,41 @@ func TestIdenticalNegative(t *testing.T) {

{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(intVar, stringVar), nil, false)},
{`func($t, $t)`, types.NewSignature(nil, types.NewTuple(stringVar, intVar), nil, false)},

{`struct{}`, typeInt},
{`struct{}`, types.NewStruct([]*types.Var{intVar}, nil)},
{`struct{int}`, typeEstruct},
{`struct{int}`, types.NewStruct([]*types.Var{stringVar}, nil)},
{`struct{string; int}`, types.NewStruct([]*types.Var{intVar, stringVar}, nil)},
{`struct{$_; string}`, types.NewStruct([]*types.Var{stringVar, stringVar, intVar}, nil)},
{`struct{$_; $_}`, types.NewStruct([]*types.Var{stringVar}, nil)},
{`struct{$x; $x}`, types.NewStruct([]*types.Var{intVar, stringVar}, nil)},

// Any struct negative.
{`struct{$*_}`, typeInt},

// Struct has suffix negative.
{`struct{$*_; int}`, typeEstruct},
{`struct{$*_; int}`, structType(stringVar)},

// Struct has prefix negative.
{`struct{int; $*_}`, typeEstruct},
{`struct{int; $*_}`, structType(stringVar)},

// Struct contains negative.
{`struct{$*_; int; $*_}`, typeEstruct},
{`struct{$*_; int; $*_}`, structType(stringVar)},
{`struct{$*_; int; $*_}`, structType(stringVar, int32Var)},

// Struct with dups negative.
{`struct{$*_; $x; $*_; $x; $*_}`, typeEstruct},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(int32Var, intVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, stringVar)},
{`struct{$*_; $x; $*_; $x; $*_}`, structType(intVar, int32Var, estructVar, stringVar)},

// TODO: this should fail as $* is named.
// We don't support named $* now, but they should be supported.
//{`struct{$*x; int; $*x}`, structType(stringVar, intVar, intVar)},
}

for _, test := range tests {
Expand Down

0 comments on commit 043f687

Please sign in to comment.