Skip to content

Commit 59beec6

Browse files
authored
Allow a comma-separated list of local prefixes, like goimports (#33)
Signed-off-by: Luke Shumaker <[email protected]>
1 parent 9b479ee commit 59beec6

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

main.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ var (
1313
doWrite = flag.Bool("w", false, "doWrite result to (source) file instead of stdout")
1414
doDiff = flag.Bool("d", false, "display diffs instead of rewriting files")
1515

16-
localFlag string
16+
localFlag []string
1717

1818
exitCode = 0
1919
)
@@ -27,9 +27,11 @@ func report(err error) {
2727
}
2828

2929
func parseFlags() []string {
30-
flag.StringVar(&localFlag, "local", "", "put imports beginning with this string after 3rd-party packages, only support one string")
30+
var localFlagStr string
31+
flag.StringVar(&localFlagStr, "local", "", "put imports beginning with this string after 3rd-party packages; comma-separated list")
3132

3233
flag.Parse()
34+
localFlag = gci.ParseLocalFlag(localFlagStr)
3335
return flag.Args()
3436
}
3537

pkg/gci/gci.go

+15-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import (
3232
)
3333

3434
type FlagSet struct {
35-
LocalFlag string
35+
LocalFlag []string
3636
DoWrite, DoDiff *bool
3737
}
3838

@@ -42,7 +42,15 @@ type pkg struct {
4242
alias map[string]string
4343
}
4444

45-
func newPkg(data [][]byte, localFlag string) *pkg {
45+
// ParseLocalFlag takes a comma-separated list of
46+
// package-name-prefixes (as passed to the "-local" flag), and splits
47+
// it in to a list. This is different than strings.Split in that it
48+
// handles the empty string and empty entries in the list.
49+
func ParseLocalFlag(str string) []string {
50+
return strings.FieldsFunc(str, func(c rune) bool { return c == ',' })
51+
}
52+
53+
func newPkg(data [][]byte, localFlag []string) *pkg {
4654
listMap := make(map[int][]string)
4755
commentMap := make(map[string]string)
4856
aliasMap := make(map[string]string)
@@ -156,11 +164,13 @@ func getPkgInfo(line string, comment bool) (string, string, string) {
156164
}
157165
}
158166

159-
func getPkgType(line, localFlag string) int {
167+
func getPkgType(line string, localFlag []string) int {
160168
pkgName := strings.Trim(line, "\"\\`")
161169

162-
if localFlag != "" && strings.HasPrefix(pkgName, localFlag) {
163-
return local
170+
for _, localPkg := range localFlag {
171+
if strings.HasPrefix(pkgName, localPkg) {
172+
return local
173+
}
164174
}
165175

166176
if isStandardPackage(pkgName) {

pkg/gci/gci_test.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,47 @@ func TestGetPkgType(t *testing.T) {
1515
{Line: `"foo/pkg/bar"`, LocalFlag: "foo", ExpectedResult: local},
1616
{Line: `"foo/pkg/bar"`, LocalFlag: "bar", ExpectedResult: remote},
1717
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: remote},
18+
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo", ExpectedResult: remote},
19+
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/bar", ExpectedResult: remote},
20+
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: remote},
21+
{Line: `"foo/pkg/bar"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: remote},
1822

1923
{Line: `"github.com/foo/bar"`, LocalFlag: "", ExpectedResult: remote},
2024
{Line: `"github.com/foo/bar"`, LocalFlag: "foo", ExpectedResult: remote},
2125
{Line: `"github.com/foo/bar"`, LocalFlag: "bar", ExpectedResult: remote},
2226
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo/bar", ExpectedResult: local},
27+
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo", ExpectedResult: local},
28+
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/bar", ExpectedResult: remote},
29+
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: local},
30+
{Line: `"github.com/foo/bar"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: local},
2331

2432
{Line: `"context"`, LocalFlag: "", ExpectedResult: standard},
2533
{Line: `"context"`, LocalFlag: "context", ExpectedResult: local},
2634
{Line: `"context"`, LocalFlag: "foo", ExpectedResult: standard},
2735
{Line: `"context"`, LocalFlag: "bar", ExpectedResult: standard},
2836
{Line: `"context"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard},
37+
{Line: `"context"`, LocalFlag: "github.com/foo", ExpectedResult: standard},
38+
{Line: `"context"`, LocalFlag: "github.com/bar", ExpectedResult: standard},
39+
{Line: `"context"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: standard},
40+
{Line: `"context"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: standard},
2941

3042
{Line: `"os/signal"`, LocalFlag: "", ExpectedResult: standard},
3143
{Line: `"os/signal"`, LocalFlag: "os/signal", ExpectedResult: local},
3244
{Line: `"os/signal"`, LocalFlag: "foo", ExpectedResult: standard},
3345
{Line: `"os/signal"`, LocalFlag: "bar", ExpectedResult: standard},
3446
{Line: `"os/signal"`, LocalFlag: "github.com/foo/bar", ExpectedResult: standard},
47+
{Line: `"os/signal"`, LocalFlag: "github.com/foo", ExpectedResult: standard},
48+
{Line: `"os/signal"`, LocalFlag: "github.com/bar", ExpectedResult: standard},
49+
{Line: `"os/signal"`, LocalFlag: "github.com/foo,github.com/bar", ExpectedResult: standard},
50+
{Line: `"os/signal"`, LocalFlag: "github.com/foo,,github.com/bar", ExpectedResult: standard},
3551
}
3652

3753
for _, tc := range testCases {
3854
tc := tc
3955
t.Run(fmt.Sprintf("%s:%s", tc.Line, tc.LocalFlag), func(t *testing.T) {
4056
t.Parallel()
4157

42-
result := getPkgType(tc.Line, tc.LocalFlag)
58+
result := getPkgType(tc.Line, ParseLocalFlag(tc.LocalFlag))
4359
if got, want := result, tc.ExpectedResult; got != want {
4460
t.Errorf("bad result: %d, expected: %d", got, want)
4561
}

0 commit comments

Comments
 (0)