Skip to content

Commit e0e26a3

Browse files
committed
Refactor and add testcase
1 parent beed208 commit e0e26a3

File tree

4 files changed

+308
-217
lines changed

4 files changed

+308
-217
lines changed

main.go

Lines changed: 1 addition & 217 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,11 @@ import (
88
"io"
99
"io/ioutil"
1010
"os"
11-
"os/exec"
1211
"path/filepath"
13-
"sort"
1412
"strings"
1513
)
1614

17-
const (
18-
// pkg type: standard, remote, local
19-
standard int = iota
20-
// 3rd-party packages
21-
remote
22-
local
2315

24-
dot = "."
25-
blank = " "
26-
indent = "\t"
27-
linebreak = "\n"
28-
commentFlag = "//"
29-
)
3016

3117
var (
3218
write = flag.Bool("w", false, "write result to (source) file instead of stdout")
@@ -35,13 +21,6 @@ var (
3521
localFlag string
3622

3723
exitCode = 0
38-
39-
importStartFlag = []byte(`
40-
import (
41-
`)
42-
importEndFlag = []byte(`
43-
)
44-
`)
4524
)
4625

4726
func report(err error) {
@@ -68,186 +47,6 @@ func usage() {
6847
os.Exit(2)
6948
}
7049

71-
type pkg struct {
72-
list map[int][]string
73-
comment map[string]string
74-
alias map[string]string
75-
}
76-
77-
func newPkg(data [][]byte) *pkg {
78-
listMap := make(map[int][]string)
79-
commentMap := make(map[string]string)
80-
aliasMap := make(map[string]string)
81-
p := &pkg{
82-
list: listMap,
83-
comment: commentMap,
84-
alias: aliasMap,
85-
}
86-
87-
formatData := make([]string, 0)
88-
// remove all empty lines
89-
for _, v := range data {
90-
if len(v) > 0 {
91-
formatData = append(formatData, strings.TrimSpace(string(v)))
92-
}
93-
}
94-
95-
for i := len(formatData) - 1; i >= 0; i-- {
96-
line := formatData[i]
97-
98-
// check commentFlag:
99-
// 1. one line commentFlag
100-
// 2. commentFlag after import path
101-
commentIndex := strings.Index(line, commentFlag)
102-
if commentIndex == 0 {
103-
pkg, _, _ := getPkgInfo(formatData[i+1])
104-
p.comment[pkg] = line
105-
continue
106-
} else if commentIndex > 0 {
107-
pkg, alias, comment := getPkgInfo(line)
108-
if alias != "" {
109-
p.alias[pkg] = alias
110-
}
111-
112-
p.comment[pkg] = comment
113-
pkgType := getPkgType(pkg)
114-
p.list[pkgType] = append(p.list[pkgType], pkg)
115-
continue
116-
}
117-
118-
pkg, alias, _ := getPkgInfo(line)
119-
120-
if alias != "" {
121-
p.alias[pkg] = alias
122-
}
123-
124-
pkgType := getPkgType(pkg)
125-
p.list[pkgType] = append(p.list[pkgType], pkg)
126-
}
127-
128-
return p
129-
}
130-
131-
// getPkgInfo assume line is a import path, and return (path, alias)
132-
func getPkgInfo(line string) (string, string, string) {
133-
pkgArray := strings.Split(line, blank)
134-
if len(pkgArray) > 1 {
135-
return pkgArray[1], pkgArray[0], strings.Join(pkgArray[2:], "")
136-
} else {
137-
return line, "", ""
138-
}
139-
}
140-
141-
// fmt format import pkgs as expected
142-
func (p *pkg) fmt() []byte {
143-
ret := make([]string, 0, 100)
144-
145-
for pkgType := range []int{standard, remote, local} {
146-
sort.Strings(p.list[pkgType])
147-
for _, s := range p.list[pkgType] {
148-
if p.comment[s] != "" {
149-
l := fmt.Sprintf("%s%s%s", indent, p.comment[s], linebreak)
150-
ret = append(ret, l)
151-
}
152-
153-
if p.alias[s] != "" {
154-
s = fmt.Sprintf("%s%s%s%s%s", indent, p.alias[s], blank, s, linebreak)
155-
} else {
156-
s = fmt.Sprintf("%s%s%s", indent, s, linebreak)
157-
}
158-
159-
ret = append(ret, s)
160-
}
161-
162-
if len(p.list[pkgType]) > 0 {
163-
ret = append(ret, linebreak)
164-
}
165-
}
166-
if ret[len(ret)-1] == linebreak {
167-
ret = ret[:len(ret)-1]
168-
}
169-
return []byte(strings.Join(ret, ""))
170-
}
171-
172-
func diff(b1, b2 []byte, filename string) (data []byte, err error) {
173-
f1, err := writeTempFile("", "gci", b1)
174-
if err != nil {
175-
return
176-
}
177-
defer os.Remove(f1)
178-
179-
f2, err := writeTempFile("", "gci", b2)
180-
if err != nil {
181-
return
182-
}
183-
defer os.Remove(f2)
184-
185-
cmd := "diff"
186-
187-
data, err = exec.Command(cmd, "-u", f1, f2).CombinedOutput()
188-
if len(data) > 0 {
189-
// diff exits with a non-zero status when the files don't match.
190-
// Ignore that failure as long as we get output.
191-
return replaceTempFilename(data, filename)
192-
}
193-
return
194-
}
195-
196-
func writeTempFile(dir, prefix string, data []byte) (string, error) {
197-
file, err := ioutil.TempFile(dir, prefix)
198-
if err != nil {
199-
return "", err
200-
}
201-
_, err = file.Write(data)
202-
if err1 := file.Close(); err == nil {
203-
err = err1
204-
}
205-
if err != nil {
206-
os.Remove(file.Name())
207-
return "", err
208-
}
209-
return file.Name(), nil
210-
}
211-
212-
// replaceTempFilename replaces temporary filenames in diff with actual one.
213-
//
214-
// --- /tmp/gofmt316145376 2017-02-03 19:13:00.280468375 -0500
215-
// +++ /tmp/gofmt617882815 2017-02-03 19:13:00.280468375 -0500
216-
// ...
217-
// ->
218-
// --- path/to/file.go.orig 2017-02-03 19:13:00.280468375 -0500
219-
// +++ path/to/file.go 2017-02-03 19:13:00.280468375 -0500
220-
// ...
221-
func replaceTempFilename(diff []byte, filename string) ([]byte, error) {
222-
bs := bytes.SplitN(diff, []byte{'\n'}, 3)
223-
if len(bs) < 3 {
224-
return nil, fmt.Errorf("got unexpected diff for %s", filename)
225-
}
226-
// Preserve timestamps.
227-
var t0, t1 []byte
228-
if i := bytes.LastIndexByte(bs[0], '\t'); i != -1 {
229-
t0 = bs[0][i:]
230-
}
231-
if i := bytes.LastIndexByte(bs[1], '\t'); i != -1 {
232-
t1 = bs[1][i:]
233-
}
234-
// Always print filepath with slash separator.
235-
f := filepath.ToSlash(filename)
236-
bs[0] = []byte(fmt.Sprintf("--- %s%s", f+".orig", t0))
237-
bs[1] = []byte(fmt.Sprintf("+++ %s%s", f, t1))
238-
return bytes.Join(bs, []byte{'\n'}), nil
239-
}
240-
241-
func getPkgType(pkg string) int {
242-
if !strings.Contains(pkg, dot) {
243-
return standard
244-
} else if strings.Contains(pkg, localFlag) {
245-
return local
246-
} else {
247-
return remote
248-
}
249-
}
250-
25150
func processFile(filename string, out io.Writer) error {
25251
var err error
25352

@@ -274,7 +73,7 @@ func processFile(filename string, out io.Writer) error {
27473

27574
ret := bytes.Split(src[start+len(importStartFlag):end], []byte(linebreak))
27675

277-
p := newPkg(ret)
76+
p := newPkg(ret, localFlag)
27877

27978
res := append(src[:start+len(importStartFlag)], append(p.fmt(), src[end+1:]...)...)
28079

@@ -310,21 +109,6 @@ func processFile(filename string, out io.Writer) error {
310109
}
311110

312111
return err
313-
314-
}
315-
316-
func visitFile(path string, f os.FileInfo, err error) error {
317-
if err == nil && isGoFile(f) {
318-
err = processFile(path, os.Stdout)
319-
}
320-
if err != nil {
321-
report(err)
322-
}
323-
return nil
324-
}
325-
326-
func walkDir(path string) {
327-
filepath.Walk(path, visitFile)
328112
}
329113

330114
func main() {

0 commit comments

Comments
 (0)