Skip to content

Commit 15ad85c

Browse files
committed
review p2
1 parent aa6a8d3 commit 15ad85c

File tree

4 files changed

+63
-38
lines changed

4 files changed

+63
-38
lines changed

command.go

+48-37
Original file line numberDiff line numberDiff line change
@@ -407,31 +407,8 @@ func (inv *Invocation) run(state *runState) error {
407407
// Outputted completions are not filtered based on the word under the cursor, as every shell we support does this already.
408408
// We only look at the current word to figure out handler to run, or what directory to inspect.
409409
if inv.IsCompletionMode() {
410-
prev, cur := inv.curWords()
411-
inv.CurWord = cur
412-
// If the current word is a flag set using `=`, use it's handler
413-
if strings.HasPrefix(cur, "--") && strings.Contains(cur, "=") {
414-
if inv.equalsFlagHandler(cur) {
415-
return nil
416-
}
417-
}
418-
// If the previous word is a flag, then we're writing it's value
419-
// and we should check it's handler
420-
if strings.HasPrefix(prev, "--") {
421-
if inv.flagHandler(prev) {
422-
return nil
423-
}
424-
}
425-
// If the current word is the command, auto-complete it so the shell moves the cursor
426-
if inv.Command.Name() == inv.CurWord {
427-
fmt.Fprintf(inv.Stdout, "%s\n", inv.Command.Name())
428-
return nil
429-
}
430-
if inv.Command.CompletionHandler == nil {
431-
inv.Command.CompletionHandler = DefaultCompletionHandler
432-
}
433-
for _, e := range inv.Command.CompletionHandler(inv) {
434-
fmt.Fprintf(inv.Stdout, "%s\n", e)
410+
for _, e := range inv.doCompletions() {
411+
fmt.Fprintln(inv.Stdout, e)
435412
}
436413
return nil
437414
}
@@ -613,11 +590,42 @@ func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
613590
return &i2
614591
}
615592

616-
func (inv *Invocation) flagHandler(word string) bool {
617-
return inv.doFlagCompletion("", word)
593+
func (inv *Invocation) doCompletions() []string {
594+
prev, cur := inv.curWords()
595+
inv.CurWord = cur
596+
// If the current word is a flag set using `=`, use it's handler
597+
if strings.HasPrefix(cur, "--") && strings.Contains(cur, "=") {
598+
if out := inv.equalsFlagCompletions(cur); out != nil {
599+
return out
600+
}
601+
}
602+
// If the previous word is a flag, then we're writing it's value
603+
// and we should check it's handler
604+
if strings.HasPrefix(prev, "--") {
605+
if out := inv.flagCompletions(prev); out != nil {
606+
return out
607+
}
608+
}
609+
// If the current word is the command, auto-complete it so the shell moves the cursor
610+
if inv.Command.Name() == inv.CurWord {
611+
return []string{inv.Command.Name()}
612+
}
613+
var completions []string
614+
615+
if inv.Command.CompletionHandler != nil {
616+
completions = append(completions, inv.Command.CompletionHandler(inv)...)
617+
}
618+
619+
completions = append(completions, DefaultCompletionHandler(inv)...)
620+
621+
return completions
618622
}
619623

620-
func (inv *Invocation) equalsFlagHandler(word string) bool {
624+
func (inv *Invocation) flagCompletions(word string) []string {
625+
return inv.doFlagCompletions("", word)
626+
}
627+
628+
func (inv *Invocation) equalsFlagCompletions(word string) []string {
621629
words := strings.Split(word, "=")
622630
word = words[0]
623631
if len(words) > 1 {
@@ -626,29 +634,32 @@ func (inv *Invocation) equalsFlagHandler(word string) bool {
626634
inv.CurWord = ""
627635
}
628636
prefix := word + "="
629-
return inv.doFlagCompletion(prefix, word)
637+
return inv.doFlagCompletions(prefix, word)
630638
}
631639

632-
func (inv *Invocation) doFlagCompletion(prefix, word string) bool {
640+
func (inv *Invocation) doFlagCompletions(prefix, word string) []string {
633641
opt := inv.Command.Options.ByFlag(word[2:])
634642
if opt == nil {
635-
return false
643+
return nil
636644
}
637645
if opt.CompletionHandler != nil {
638646
completions := opt.CompletionHandler(inv)
647+
out := make([]string, 0, len(completions))
639648
for _, completion := range completions {
640-
fmt.Fprintf(inv.Stdout, "%s%s\n", prefix, completion)
649+
out = append(out, fmt.Sprintf("%s%s", prefix, completion))
641650
}
642-
return true
651+
return out
643652
}
644653
val, ok := opt.Value.(*Enum)
645654
if ok {
646-
for _, choice := range val.Choices {
647-
fmt.Fprintf(inv.Stdout, "%s%s\n", prefix, choice)
655+
completions := val.Choices
656+
out := make([]string, 0, len(completions))
657+
for _, choice := range completions {
658+
out = append(out, fmt.Sprintf("%s%s", prefix, choice))
648659
}
649-
return true
660+
return out
650661
}
651-
return false
662+
return nil
652663
}
653664

654665
// MiddlewareFunc returns the next handler in the chain,

command_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ func sampleCommand(t *testing.T) *serpent.Command {
161161
CompletionHandler: completion.FileListHandler(nil),
162162
},
163163
},
164+
CompletionHandler: func(i *serpent.Invocation) []string {
165+
return []string{"doesntexist.go"}
166+
},
164167
},
165168
},
166169
}

completion/handlers.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ func FileListHandler(filter func(info os.FileInfo) bool) serpent.CompletionHandl
3636
}
3737

3838
func listFiles(word string, filter func(info os.FileInfo) bool) []string {
39-
out := make([]string, 0, 32)
39+
// Avoid reallocating for each of the first few files we see.
40+
out := make([]string, 0, 16)
4041

4142
dir, _ := filepath.Split(word)
4243
if dir == "" {

completion_test.go

+10
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ func TestCompletion(t *testing.T) {
5555
require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-string\n", io.Stdout.String())
5656
})
5757

58+
t.Run("ListFlagsAfterArg", func(t *testing.T) {
59+
t.Parallel()
60+
i := cmd().Invoke("altfile", "")
61+
i.Environ.Set(serpent.CompletionModeEnv, "1")
62+
io := fakeIO(i)
63+
err := i.Run()
64+
require.NoError(t, err)
65+
require.Equal(t, "doesntexist.go\n--extra\n", io.Stdout.String())
66+
})
67+
5868
t.Run("FlagExhaustive", func(t *testing.T) {
5969
t.Parallel()
6070
i := cmd().Invoke("required-flag", "--req-bool", "--req-string", "foo bar", "--req-array", "asdf", "--req-array", "qwerty")

0 commit comments

Comments
 (0)