Skip to content

Commit

Permalink
fix: package type detection when local flag is empty. (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
ldez authored Dec 8, 2020
1 parent 599c94c commit 8f975be
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package main
import (
"flag"
"fmt"
"github.com/daixiang0/gci/pkg/gci"
"go/scanner"
"os"

"github.com/daixiang0/gci/pkg/gci"
)

var (
Expand Down
8 changes: 6 additions & 2 deletions pkg/gci/gci.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,15 @@ func getPkgInfo(line string, comment bool) (string, string, string) {

func getPkgType(line, localFlag string) int {
pkgName := strings.Trim(line, "\"\\`")
if strings.HasPrefix(pkgName, localFlag) {

if localFlag != "" && strings.HasPrefix(pkgName, localFlag) {
return local
} else if isStandardPackage(pkgName) {
}

if isStandardPackage(pkgName) {
return standard
}

return remote
}

Expand Down
38 changes: 24 additions & 14 deletions pkg/gci/gci_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,44 @@ import (
"testing"
)

func TestGetPkgType(b *testing.T) {
func TestGetPkgType(t *testing.T) {
testCases := []struct {
Line string
LocalFlag string
ExpectedResult int
}{
{Line: `"foo/pkg/bar"`, LocalFlag: "", ExpectedResult: remote},
{Line: `"foo/pkg/bar"`, LocalFlag: "foo", ExpectedResult: local},
{Line: `"github.com/foo/bar"`, LocalFlag: "foo", ExpectedResult: remote},
{Line: `"context"`, LocalFlag: "foo", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "foo", ExpectedResult: standard},
{Line: `"foo/pkg/bar"`, LocalFlag: "bar", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "bar", ExpectedResult: remote},
{Line: `"context"`, LocalFlag: "bar", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "bar", ExpectedResult: standard},
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: remote},

{Line: `"github.com/foo/bar"`, LocalFlag: "", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "foo", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "bar", ExpectedResult: remote},
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: local},

{Line: `"context"`, LocalFlag: "", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "context", ExpectedResult: local},
{Line: `"context"`, LocalFlag: "foo", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "bar", ExpectedResult: standard},
{Line: `"context"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard},

{Line: `"os/signal"`, LocalFlag: "", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "os/signal", ExpectedResult: local},
{Line: `"os/signal"`, LocalFlag: "foo", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "bar", ExpectedResult: standard},
{Line: `"os/signal"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard},
}

for _, _tCase := range testCases {
tCase := _tCase
testFn := func(t *testing.T) {
result := getPkgType(tCase.Line, tCase.LocalFlag)
if got, want := result, tCase.ExpectedResult; got != want {
for _, tc := range testCases {
tc := tc
t.Run(fmt.Sprintf("%s:%s", tc.Line, tc.LocalFlag), func(t *testing.T) {
t.Parallel()

result := getPkgType(tc.Line, tc.LocalFlag)
if got, want := result, tc.ExpectedResult; got != want {
t.Errorf("bad result: %d, expected: %d", got, want)
}
}
b.Run(fmt.Sprintf("%s:%s", tCase.LocalFlag, tCase.Line), testFn)
})
}
}

0 comments on commit 8f975be

Please sign in to comment.