From e8433ee6f34a318de5e2216d4403d1056f30a3f7 Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Thu, 25 Feb 2021 16:03:12 -0700 Subject: [PATCH] refactor mockgen and cleanup --- gomock/call.go | 38 +++++----- gomock/callset_test.go | 2 +- gomock/matchers.go | 1 - mockgen/mockgen.go | 50 ++++++++++++-- mockgen/mockgen_test.go | 85 +++++++++++++++++++++++ mockgen/model/model.go | 5 +- mockgen/parse.go | 111 +++++++++-------------------- mockgen/parse_test.go | 150 ---------------------------------------- mockgen/reflect.go | 76 ++++++++++---------- 9 files changed, 220 insertions(+), 298 deletions(-) diff --git a/gomock/call.go b/gomock/call.go index b18cc2d6..bd994266 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -50,16 +50,16 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle t.Helper() // TODO: check arity, types. - margs := make([]Matcher, len(args)) + mArgs := make([]Matcher, len(args)) for i, arg := range args { if m, ok := arg.(Matcher); ok { - margs[i] = m + mArgs[i] = m } else if arg == nil { // Handle nil specially so that passing a nil interface value // will match the typed nils of concrete args. - margs[i] = Nil() + mArgs[i] = Nil() } else { - margs[i] = Eq(arg) + mArgs[i] = Eq(arg) } } @@ -76,7 +76,7 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle return rets }} return &Call{t: t, receiver: receiver, method: method, methodType: methodType, - args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} + args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} } // AnyTimes allows the expectation to be called 0 or more times @@ -113,19 +113,19 @@ func (c *Call) DoAndReturn(f interface{}) *Call { v := reflect.ValueOf(f) c.addAction(func(args []interface{}) []interface{} { - vargs := make([]reflect.Value, len(args)) + vArgs := make([]reflect.Value, len(args)) ft := v.Type() for i := 0; i < len(args); i++ { if args[i] != nil { - vargs[i] = reflect.ValueOf(args[i]) + vArgs[i] = reflect.ValueOf(args[i]) } else { // Use the zero value for the arg. - vargs[i] = reflect.Zero(ft.In(i)) + vArgs[i] = reflect.Zero(ft.In(i)) } } - vrets := v.Call(vargs) - rets := make([]interface{}, len(vrets)) - for i, ret := range vrets { + vRets := v.Call(vArgs) + rets := make([]interface{}, len(vRets)) + for i, ret := range vRets { rets[i] = ret.Interface() } return rets @@ -142,17 +142,17 @@ func (c *Call) Do(f interface{}) *Call { v := reflect.ValueOf(f) c.addAction(func(args []interface{}) []interface{} { - vargs := make([]reflect.Value, len(args)) + vArgs := make([]reflect.Value, len(args)) ft := v.Type() for i := 0; i < len(args); i++ { if args[i] != nil { - vargs[i] = reflect.ValueOf(args[i]) + vArgs[i] = reflect.ValueOf(args[i]) } else { // Use the zero value for the arg. - vargs[i] = reflect.Zero(ft.In(i)) + vArgs[i] = reflect.Zero(ft.In(i)) } } - v.Call(vargs) + v.Call(vArgs) return nil }) return c @@ -353,12 +353,12 @@ func (c *Call) matches(args []interface{}) error { // matches all the remaining arguments or the lack of any. // Convert the remaining arguments, if any, into a slice of the // expected type. - vargsType := c.methodType.In(c.methodType.NumIn() - 1) - vargs := reflect.MakeSlice(vargsType, 0, len(args)-i) + vArgsType := c.methodType.In(c.methodType.NumIn() - 1) + vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i) for _, arg := range args[i:] { - vargs = reflect.Append(vargs, reflect.ValueOf(arg)) + vArgs = reflect.Append(vArgs, reflect.ValueOf(arg)) } - if m.Matches(vargs.Interface()) { + if m.Matches(vArgs.Interface()) { // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any()) // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher) // Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any()) diff --git a/gomock/callset_test.go b/gomock/callset_test.go index a2835f34..fe053af7 100644 --- a/gomock/callset_test.go +++ b/gomock/callset_test.go @@ -84,7 +84,7 @@ func TestCallSetFindMatch(t *testing.T) { c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func)) cs.exhausted = map[callSetKey][]*Call{ - callSetKey{receiver: receiver, fname: method}: []*Call{c1}, + {receiver: receiver, fname: method}: {c1}, } _, err := cs.FindMatch(receiver, method, args) diff --git a/gomock/matchers.go b/gomock/matchers.go index 770aba5a..5638efe5 100644 --- a/gomock/matchers.go +++ b/gomock/matchers.go @@ -153,7 +153,6 @@ func (n notMatcher) Matches(x interface{}) bool { } func (n notMatcher) String() string { - // TODO: Improve this if we add a NotString method to the Matcher interface. return "not(" + n.m.String() + ")" } diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index f816b093..f7c82c5d 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -38,6 +38,7 @@ import ( "github.com/golang/mock/mockgen/model" + "golang.org/x/mod/modfile" toolsimports "golang.org/x/tools/imports" ) @@ -84,6 +85,7 @@ func main() { log.Fatal("Expected exactly two arguments") } packageName = flag.Arg(0) + interfaces := strings.Split(flag.Arg(1), ",") if packageName == "." { dir, err := os.Getwd() if err != nil { @@ -94,7 +96,7 @@ func main() { log.Fatalf("Parse package name failed: %v", err) } } - pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ",")) + pkg, err = reflectMode(packageName, interfaces) } if err != nil { log.Fatalf("Loading input failed: %v", err) @@ -394,11 +396,6 @@ func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePa g.p("}") g.p("") - // TODO: Re-enable this if we can import the interface reliably. - // g.p("// Verify that the mock satisfies the interface at compile time.") - // g.p("var _ %v = (*%v)(nil)", typeName, mockType) - // g.p("") - g.p("// New%v creates a new mock instance.", mockType) g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType) g.in() @@ -665,3 +662,44 @@ func printVersion() { printModuleVersion() } } + +// parseImportPackage get package import path via source file +// an alternative implementation is to use: +// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} +// pkgs, err := packages.Load(cfg, "file="+source) +// However, it will call "go list" and slow down the performance +func parsePackageImport(srcDir string) (string, error) { + moduleMode := os.Getenv("GO111MODULE") + // trying to find the module + if moduleMode != "off" { + currentDir := srcDir + for { + dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) + if os.IsNotExist(err) { + if currentDir == filepath.Dir(currentDir) { + // at the root + break + } + currentDir = filepath.Dir(currentDir) + continue + } else if err != nil { + return "", err + } + modulePath := modfile.ModulePath(dat) + return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil + } + } + // fall back to GOPATH mode + goPaths := os.Getenv("GOPATH") + if goPaths == "" { + return "", fmt.Errorf("GOPATH is not set") + } + goPathList := strings.Split(goPaths, string(os.PathListSeparator)) + for _, goPath := range goPathList { + sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator) + if strings.HasPrefix(srcDir, sourceRoot) { + return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil + } + } + return "", errOutsideGoPath +} diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 3c3eaae2..55566001 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -2,6 +2,9 @@ package main import ( "fmt" + "io/ioutil" + "os" + "path/filepath" "reflect" "regexp" "strings" @@ -364,3 +367,85 @@ func Test_createPackageMap(t *testing.T) { }) } } + +func TestParsePackageImport_FallbackGoPath(t *testing.T) { + goPath, err := ioutil.TempDir("", "gopath") + if err != nil { + t.Error(err) + } + defer func() { + if err = os.RemoveAll(goPath); err != nil { + t.Error(err) + } + }() + srcDir := filepath.Join(goPath, "src/example.com/foo") + err = os.MkdirAll(srcDir, 0755) + if err != nil { + t.Error(err) + } + key := "GOPATH" + value := goPath + if err := os.Setenv(key, value); err != nil { + t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) + } + key = "GO111MODULE" + value = "on" + if err := os.Setenv(key, value); err != nil { + t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) + } + pkgPath, err := parsePackageImport(srcDir) + expected := "example.com/foo" + if pkgPath != expected { + t.Errorf("expect %s, got %s", expected, pkgPath) + } +} + +func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) { + var goPathList []string + + // first gopath + goPath, err := ioutil.TempDir("", "gopath1") + if err != nil { + t.Error(err) + } + goPathList = append(goPathList, goPath) + defer func() { + if err = os.RemoveAll(goPath); err != nil { + t.Error(err) + } + }() + srcDir := filepath.Join(goPath, "src/example.com/foo") + err = os.MkdirAll(srcDir, 0755) + if err != nil { + t.Error(err) + } + + // second gopath + goPath, err = ioutil.TempDir("", "gopath2") + if err != nil { + t.Error(err) + } + goPathList = append(goPathList, goPath) + defer func() { + if err = os.RemoveAll(goPath); err != nil { + t.Error(err) + } + }() + + goPaths := strings.Join(goPathList, string(os.PathListSeparator)) + key := "GOPATH" + value := goPaths + if err := os.Setenv(key, value); err != nil { + t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) + } + key = "GO111MODULE" + value = "on" + if err := os.Setenv(key, value); err != nil { + t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) + } + pkgPath, err := parsePackageImport(srcDir) + expected := "example.com/foo" + if pkgPath != expected { + t.Errorf("expect %s, got %s", expected, pkgPath) + } +} diff --git a/mockgen/model/model.go b/mockgen/model/model.go index d06d5162..2c6a62ce 100644 --- a/mockgen/model/model.go +++ b/mockgen/model/model.go @@ -71,7 +71,7 @@ func (intf *Interface) addImports(im map[string]bool) { } } -// AddMethod adds a new method, deduplicating by method name. +// AddMethod adds a new method, de-duplicating by method name. func (intf *Interface) AddMethod(m *Method) { for _, me := range intf.Methods { if me.Name == m.Name { @@ -260,11 +260,10 @@ func (mt *MapType) addImports(im map[string]bool) { // NamedType is an exported type in a package. type NamedType struct { Package string // may be empty - Type string // TODO: should this be typed Type? + Type string } func (nt *NamedType) String(pm map[string]string, pkgOverride string) string { - // TODO: is this right? if pkgOverride == nt.Package { return nt.Type } diff --git a/mockgen/parse.go b/mockgen/parse.go index cdb82a34..c46221aa 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -26,14 +26,12 @@ import ( "go/token" "io/ioutil" "log" - "os" "path" "path/filepath" "strconv" "strings" "github.com/golang/mock/mockgen/model" - "golang.org/x/mod/modfile" ) var ( @@ -41,8 +39,6 @@ var ( auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") ) -// TODO: simplify error reporting - // sourceMode generates mocks via source file. func sourceMode(source string) (*model.Package, error) { srcDir, err := filepath.Abs(filepath.Dir(source)) @@ -76,10 +72,8 @@ func sourceMode(source string) (*model.Package, error) { eq := strings.Index(kv, "=") k, v := kv[:eq], kv[eq+1:] if k == "." { - // TODO: Catch dupes? dotImports[v] = true } else { - // TODO: Catch dupes? p.imports[k] = importedPkg{path: v} } } @@ -125,7 +119,7 @@ type duplicateImport struct { } func (d duplicateImport) Error() string { - return fmt.Sprintf("%q is ambigous because of duplicate imports: %v", d.name, d.duplicates) + return fmt.Sprintf("%q is ambiguous because of duplicate imports: %v", d.name, d.duplicates) } func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" } @@ -252,12 +246,12 @@ func (p *fileParser) parsePackage(path string) (*fileParser, error) { } func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { - intf := &model.Interface{Name: name} + iface := &model.Interface{Name: name} for _, field := range it.Methods.List { switch v := field.Type.(type) { case *ast.FuncType: if nn := len(field.Names); nn != 1 { - return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn) + return nil, fmt.Errorf("expected one name for interface %v, got %d", iface.Name, nn) } m := &model.Method{ Name: field.Names[0].String(), @@ -267,84 +261,84 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m if err != nil { return nil, err } - intf.AddMethod(m) + iface.AddMethod(m) case *ast.Ident: // Embedded interface in this package. - ei := p.auxInterfaces[pkg][v.String()] - if ei == nil { - ei = p.importedInterfaces[pkg][v.String()] + embeddedIfaceType := p.auxInterfaces[pkg][v.String()] + if embeddedIfaceType == nil { + embeddedIfaceType = p.importedInterfaces[pkg][v.String()] } - var eintf *model.Interface - if ei != nil { + var embeddedIface *model.Interface + if embeddedIfaceType != nil { var err error - eintf, err = p.parseInterface(v.String(), pkg, ei) + embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType) if err != nil { return nil, err } } else { // This is built-in error interface. if v.String() == model.ErrorInterface.Name { - eintf = &model.ErrorInterface + embeddedIface = &model.ErrorInterface } else { return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) } } // Copy the methods. - for _, m := range eintf.Methods { - intf.AddMethod(m) + for _, m := range embeddedIface.Methods { + iface.AddMethod(m) } case *ast.SelectorExpr: // Embedded interface in another package. - fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() - epkg, ok := p.imports[fpkg] + filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() + embeddedPkg, ok := p.imports[filePkg] if !ok { - return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg) + return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) } - var eintf *model.Interface + var embeddedIface *model.Interface var err error - ei := p.auxInterfaces[fpkg][sel] - if ei != nil { - eintf, err = p.parseInterface(sel, fpkg, ei) + embeddedIfaceType := p.auxInterfaces[filePkg][sel] + if embeddedIfaceType != nil { + embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) if err != nil { return nil, err } } else { - path := epkg.Path() - parser := epkg.Parser() + path := embeddedPkg.Path() + parser := embeddedPkg.Parser() if parser == nil { ip, err := p.parsePackage(path) if err != nil { return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) } parser = ip - p.imports[fpkg] = importedPkg{ - path: epkg.Path(), + p.imports[filePkg] = importedPkg{ + path: embeddedPkg.Path(), parser: parser, } } - if ei = parser.importedInterfaces[path][sel]; ei == nil { + if embeddedIfaceType = parser.importedInterfaces[path][sel]; embeddedIfaceType == nil { return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) } - eintf, err = parser.parseInterface(sel, path, ei) + embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) if err != nil { return nil, err } } // Copy the methods. // TODO: apply shadowing rules. - for _, m := range eintf.Methods { - intf.AddMethod(m) + for _, m := range embeddedIface.Methods { + iface.AddMethod(m) } default: return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) } } - return intf, nil + return iface, nil } -func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) { +func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { if f.Params != nil { regParams := f.Params.List if isVariadic(f) { @@ -357,13 +351,13 @@ func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Paramet } variadic = vp[0] } - in, err = p.parseFieldList(pkg, regParams) + inParam, err = p.parseFieldList(pkg, regParams) if err != nil { return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) } } if f.Results != nil { - out, err = p.parseFieldList(pkg, f.Results.List) + outParam, err = p.parseFieldList(pkg, f.Results.List) if err != nil { return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) } @@ -635,44 +629,3 @@ func packageNameOfDir(srcDir string) (string, error) { } var errOutsideGoPath = errors.New("Source directory is outside GOPATH") - -// parseImportPackage get package import path via source file -// an alternative implementation is to use: -// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} -// pkgs, err := packages.Load(cfg, "file="+source) -// However, it will call "go list" and slow down the performance -func parsePackageImport(srcDir string) (string, error) { - moduleMode := os.Getenv("GO111MODULE") - // trying to find the module - if moduleMode != "off" { - currentDir := srcDir - for { - dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) - if os.IsNotExist(err) { - if currentDir == filepath.Dir(currentDir) { - // at the root - break - } - currentDir = filepath.Dir(currentDir) - continue - } else if err != nil { - return "", err - } - modulePath := modfile.ModulePath(dat) - return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil - } - } - // fall back to GOPATH mode - goPaths := os.Getenv("GOPATH") - if goPaths == "" { - return "", fmt.Errorf("GOPATH is not set") - } - goPathList := strings.Split(goPaths, string(os.PathListSeparator)) - for _, goPath := range goPathList { - sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator) - if strings.HasPrefix(srcDir, sourceRoot) { - return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil - } - } - return "", errOutsideGoPath -} diff --git a/mockgen/parse_test.go b/mockgen/parse_test.go index 6230df58..a7ea9f82 100644 --- a/mockgen/parse_test.go +++ b/mockgen/parse_test.go @@ -4,10 +4,6 @@ import ( "go/ast" "go/parser" "go/token" - "io/ioutil" - "os" - "path/filepath" - "strings" "testing" ) @@ -118,152 +114,6 @@ func Benchmark_parseFile(b *testing.B) { } } -func TestParsePackageImport(t *testing.T) { - testRoot, err := ioutil.TempDir("", "test_root") - if err != nil { - t.Fatal("cannot create tempdir") - } - defer func() { - if err = os.RemoveAll(testRoot); err != nil { - t.Errorf("cannot clean up tempdir at %s: %v", testRoot, err) - } - }() - barDir := filepath.Join(testRoot, "gomod/bar") - if err = os.MkdirAll(barDir, 0755); err != nil { - t.Fatalf("error creating %s: %v", barDir, err) - } - if err = ioutil.WriteFile(filepath.Join(barDir, "bar.go"), []byte("package bar"), 0644); err != nil { - t.Fatalf("error creating gomod/bar/bar.go: %v", err) - } - if err = ioutil.WriteFile(filepath.Join(testRoot, "gomod/go.mod"), []byte("module github.com/golang/foo"), 0644); err != nil { - t.Fatalf("error creating gomod/go.mod: %v", err) - } - goPath := filepath.Join(testRoot, "gopath") - for _, testCase := range []struct { - name string - envs map[string]string - dir string - pkgPath string - err error - }{ - { - name: "go mod default", - envs: map[string]string{"GO111MODULE": ""}, - dir: barDir, - pkgPath: "github.com/golang/foo/bar", - }, - { - name: "go mod off", - envs: map[string]string{"GO111MODULE": "off", "GOPATH": goPath}, - dir: filepath.Join(testRoot, "gopath/src/example.com/foo"), - pkgPath: "example.com/foo", - }, - { - name: "outside GOPATH", - envs: map[string]string{"GO111MODULE": "off", "GOPATH": goPath}, - dir: "testdata", - err: errOutsideGoPath, - }, - } { - t.Run(testCase.name, func(t *testing.T) { - for key, value := range testCase.envs { - if err := os.Setenv(key, value); err != nil { - t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) - } - } - pkgPath, err := parsePackageImport(filepath.Clean(testCase.dir)) - if err != testCase.err { - t.Errorf("expect %v, got %v", testCase.err, err) - } - if pkgPath != testCase.pkgPath { - t.Errorf("expect %s, got %s", testCase.pkgPath, pkgPath) - } - }) - } -} - -func TestParsePackageImport_FallbackGoPath(t *testing.T) { - goPath, err := ioutil.TempDir("", "gopath") - if err != nil { - t.Error(err) - } - defer func() { - if err = os.RemoveAll(goPath); err != nil { - t.Error(err) - } - }() - srcDir := filepath.Join(goPath, "src/example.com/foo") - err = os.MkdirAll(srcDir, 0755) - if err != nil { - t.Error(err) - } - key := "GOPATH" - value := goPath - if err := os.Setenv(key, value); err != nil { - t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) - } - key = "GO111MODULE" - value = "on" - if err := os.Setenv(key, value); err != nil { - t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) - } - pkgPath, err := parsePackageImport(srcDir) - expected := "example.com/foo" - if pkgPath != expected { - t.Errorf("expect %s, got %s", expected, pkgPath) - } -} - -func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) { - var goPathList []string - - // first gopath - goPath, err := ioutil.TempDir("", "gopath1") - if err != nil { - t.Error(err) - } - goPathList = append(goPathList, goPath) - defer func() { - if err = os.RemoveAll(goPath); err != nil { - t.Error(err) - } - }() - srcDir := filepath.Join(goPath, "src/example.com/foo") - err = os.MkdirAll(srcDir, 0755) - if err != nil { - t.Error(err) - } - - // second gopath - goPath, err = ioutil.TempDir("", "gopath2") - if err != nil { - t.Error(err) - } - goPathList = append(goPathList, goPath) - defer func() { - if err = os.RemoveAll(goPath); err != nil { - t.Error(err) - } - }() - - goPaths := strings.Join(goPathList, string(os.PathListSeparator)) - key := "GOPATH" - value := goPaths - if err := os.Setenv(key, value); err != nil { - t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) - } - key = "GO111MODULE" - value = "on" - if err := os.Setenv(key, value); err != nil { - t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err) - } - pkgPath, err := parsePackageImport(srcDir) - expected := "example.com/foo" - if pkgPath != expected { - t.Errorf("expect %s, got %s", expected, pkgPath) - } -} - func TestParseArrayWithConstLength(t *testing.T) { fs := token.NewFileSet() diff --git a/mockgen/reflect.go b/mockgen/reflect.go index 55c24505..620b5f15 100644 --- a/mockgen/reflect.go +++ b/mockgen/reflect.go @@ -39,6 +39,43 @@ var ( buildFlags = flag.String("build_flags", "", "(reflect mode) Additional flags for go build.") ) +// reflectMode generates mocks via reflection on an interface. +func reflectMode(importPath string, symbols []string) (*model.Package, error) { + if *execOnly != "" { + return run(*execOnly) + } + + program, err := writeProgram(importPath, symbols) + if err != nil { + return nil, err + } + + if *progOnly { + if _, err := os.Stdout.Write(program); err != nil { + return nil, err + } + os.Exit(0) + } + + wd, _ := os.Getwd() + + // Try to run the reflection program in the current working directory. + if p, err := runInDir(program, wd); err == nil { + return p, nil + } + + // Try to run the program in the same directory as the input package. + if p, err := build.Import(importPath, wd, build.FindOnly); err == nil { + dir := p.Dir + if p, err := runInDir(program, dir); err == nil { + return p, nil + } + } + + // Try to run it in a standard temp directory. + return runInDir(program, "") +} + func writeProgram(importPath string, symbols []string) ([]byte, error) { var program bytes.Buffer data := reflectData{ @@ -133,45 +170,6 @@ func runInDir(program []byte, dir string) (*model.Package, error) { return run(filepath.Join(tmpDir, progBinary)) } -// reflectMode generates mocks via reflection on an interface. -func reflectMode(importPath string, symbols []string) (*model.Package, error) { - // TODO: sanity check arguments - - if *execOnly != "" { - return run(*execOnly) - } - - program, err := writeProgram(importPath, symbols) - if err != nil { - return nil, err - } - - if *progOnly { - if _, err := os.Stdout.Write(program); err != nil { - return nil, err - } - os.Exit(0) - } - - wd, _ := os.Getwd() - - // Try to run the reflection program in the current working directory. - if p, err := runInDir(program, wd); err == nil { - return p, nil - } - - // Try to run the program in the same directory as the input package. - if p, err := build.Import(importPath, wd, build.FindOnly); err == nil { - dir := p.Dir - if p, err := runInDir(program, dir); err == nil { - return p, nil - } - } - - // Try to run it in a standard temp directory. - return runInDir(program, "") -} - type reflectData struct { ImportPath string Symbols []string