Skip to content

Commit d46fb20

Browse files
feat: add completion install api (#18)
1 parent 91966a2 commit d46fb20

15 files changed

+599
-87
lines changed

command.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,16 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
296296
}
297297

298298
func (inv *Invocation) CurWords() (prev string, cur string) {
299-
if len(inv.Args) == 1 {
299+
switch len(inv.Args) {
300+
// All the shells we support will supply at least one argument (empty string),
301+
// but we don't want to panic.
302+
case 0:
303+
cur = ""
304+
prev = ""
305+
case 1:
300306
cur = inv.Args[0]
301307
prev = ""
302-
} else {
308+
default:
303309
cur = inv.Args[len(inv.Args)-1]
304310
prev = inv.Args[len(inv.Args)-2]
305311
}
@@ -645,9 +651,13 @@ func (inv *Invocation) completeFlag(word string) []string {
645651
if opt.CompletionHandler != nil {
646652
return opt.CompletionHandler(inv)
647653
}
648-
val, ok := opt.Value.(*Enum)
654+
enum, ok := opt.Value.(*Enum)
655+
if ok {
656+
return enum.Choices
657+
}
658+
enumArr, ok := opt.Value.(*EnumArray)
649659
if ok {
650-
return val.Choices
660+
return enumArr.Choices
651661
}
652662
return nil
653663
}

command_test.go

+14-8
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ func fakeIO(i *serpent.Invocation) *ioBufs {
3434
func sampleCommand(t *testing.T) *serpent.Command {
3535
t.Helper()
3636
var (
37-
verbose bool
38-
lower bool
39-
prefix string
40-
reqBool bool
41-
reqStr string
42-
reqArr []string
43-
fileArr []string
44-
enumStr string
37+
verbose bool
38+
lower bool
39+
prefix string
40+
reqBool bool
41+
reqStr string
42+
reqArr []string
43+
reqEnumArr []string
44+
fileArr []string
45+
enumStr string
4546
)
4647
enumChoices := []string{"foo", "bar", "qux"}
4748
return &serpent.Command{
@@ -94,6 +95,11 @@ func sampleCommand(t *testing.T) *serpent.Command {
9495
FlagShorthand: "a",
9596
Value: serpent.StringArrayOf(&reqArr),
9697
},
98+
serpent.Option{
99+
Name: "req-enum-array",
100+
Flag: "req-enum-array",
101+
Value: serpent.EnumArrayOf(&reqEnumArr, enumChoices...),
102+
},
97103
},
98104
HelpHandler: func(i *serpent.Invocation) error {
99105
_, _ = i.Stdout.Write([]byte("help text.png"))

completion.go

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package serpent
22

3+
import (
4+
"github.com/spf13/pflag"
5+
)
6+
37
// CompletionModeEnv is a special environment variable that is
48
// set when the command is being run in completion mode.
59
const CompletionModeEnv = "COMPLETION_MODE"
@@ -10,15 +14,18 @@ func (inv *Invocation) IsCompletionMode() bool {
1014
return ok
1115
}
1216

13-
// DefaultCompletionHandler returns a handler that prints all
14-
// known flags and subcommands that haven't already been set to valid values.
17+
// DefaultCompletionHandler is a handler that prints all known flags and
18+
// subcommands that haven't been exhaustively set.
1519
func DefaultCompletionHandler(inv *Invocation) []string {
1620
var allResps []string
1721
for _, cmd := range inv.Command.Children {
1822
allResps = append(allResps, cmd.Name())
1923
}
2024
for _, opt := range inv.Command.Options {
21-
if opt.ValueSource == ValueSourceNone || opt.ValueSource == ValueSourceDefault || opt.Value.Type() == "string-array" {
25+
_, isSlice := opt.Value.(pflag.SliceValue)
26+
if opt.ValueSource == ValueSourceNone ||
27+
opt.ValueSource == ValueSourceDefault ||
28+
isSlice {
2229
allResps = append(allResps, "--"+opt.Flag)
2330
}
2431
}

completion/all.go

+146-42
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,198 @@
11
package completion
22

33
import (
4+
"bytes"
5+
"errors"
46
"fmt"
57
"io"
8+
"io/fs"
69
"os"
710
"os/user"
811
"path/filepath"
12+
"runtime"
913
"strings"
1014
"text/template"
1115

1216
"github.com/coder/serpent"
17+
18+
"github.com/natefinch/atomic"
1319
)
1420

1521
const (
16-
BashShell string = "bash"
17-
FishShell string = "fish"
18-
ZShell string = "zsh"
19-
Powershell string = "powershell"
22+
completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============`
23+
completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============`
2024
)
2125

22-
var shellCompletionByName = map[string]func(io.Writer, string) error{
23-
BashShell: generateCompletion(bashCompletionTemplate),
24-
FishShell: generateCompletion(fishCompletionTemplate),
25-
ZShell: generateCompletion(zshCompletionTemplate),
26-
Powershell: generateCompletion(pshCompletionTemplate),
26+
type Shell interface {
27+
Name() string
28+
InstallPath() (string, error)
29+
WriteCompletion(io.Writer) error
30+
ProgramName() string
2731
}
2832

29-
func ShellOptions(choice *string) *serpent.Enum {
30-
return serpent.EnumOf(choice, BashShell, FishShell, ZShell, Powershell)
31-
}
33+
const (
34+
ShellBash string = "bash"
35+
ShellFish string = "fish"
36+
ShellZsh string = "zsh"
37+
ShellPowershell string = "powershell"
38+
)
3239

33-
func WriteCompletion(writer io.Writer, shell string, cmdName string) error {
34-
fn, ok := shellCompletionByName[shell]
35-
if !ok {
36-
return fmt.Errorf("unknown shell %q", shell)
40+
func ShellByName(shell, programName string) (Shell, error) {
41+
switch shell {
42+
case ShellBash:
43+
return Bash(runtime.GOOS, programName), nil
44+
case ShellFish:
45+
return Fish(runtime.GOOS, programName), nil
46+
case ShellZsh:
47+
return Zsh(runtime.GOOS, programName), nil
48+
case ShellPowershell:
49+
return Powershell(runtime.GOOS, programName), nil
50+
default:
51+
return nil, fmt.Errorf("unsupported shell %q", shell)
3752
}
38-
fn(writer, cmdName)
39-
return nil
4053
}
4154

42-
func DetectUserShell() (string, error) {
55+
func ShellOptions(choice *string) *serpent.Enum {
56+
return serpent.EnumOf(choice, ShellBash, ShellFish, ShellZsh, ShellPowershell)
57+
}
58+
59+
func DetectUserShell(programName string) (Shell, error) {
4360
// Attempt to get the SHELL environment variable first
4461
if shell := os.Getenv("SHELL"); shell != "" {
45-
return filepath.Base(shell), nil
62+
return ShellByName(filepath.Base(shell), "")
4663
}
4764

4865
// Fallback: Look up the current user and parse /etc/passwd
4966
currentUser, err := user.Current()
5067
if err != nil {
51-
return "", err
68+
return nil, err
5269
}
5370

5471
// Open and parse /etc/passwd
5572
passwdFile, err := os.ReadFile("/etc/passwd")
5673
if err != nil {
57-
return "", err
74+
return nil, err
5875
}
5976

6077
lines := strings.Split(string(passwdFile), "\n")
6178
for _, line := range lines {
6279
if strings.HasPrefix(line, currentUser.Username+":") {
6380
parts := strings.Split(line, ":")
6481
if len(parts) > 6 {
65-
return filepath.Base(parts[6]), nil // The shell is typically the 7th field
82+
return ShellByName(filepath.Base(parts[6]), programName) // The shell is typically the 7th field
6683
}
6784
}
6885
}
6986

70-
return "", fmt.Errorf("default shell not found")
87+
return nil, fmt.Errorf("default shell not found")
7188
}
7289

73-
func generateCompletion(
74-
scriptTemplate string,
75-
) func(io.Writer, string) error {
76-
return func(w io.Writer, rootCmdName string) error {
77-
tmpl, err := template.New("script").Parse(scriptTemplate)
78-
if err != nil {
79-
return fmt.Errorf("parse template: %w", err)
80-
}
90+
func writeConfig(
91+
w io.Writer,
92+
cfgTemplate string,
93+
programName string,
94+
) error {
95+
tmpl, err := template.New("script").Parse(cfgTemplate)
96+
if err != nil {
97+
return fmt.Errorf("parse template: %w", err)
98+
}
8199

82-
err = tmpl.Execute(
83-
w,
84-
map[string]string{
85-
"Name": rootCmdName,
86-
},
87-
)
88-
if err != nil {
89-
return fmt.Errorf("execute template: %w", err)
90-
}
100+
err = tmpl.Execute(
101+
w,
102+
map[string]string{
103+
"Name": programName,
104+
},
105+
)
106+
if err != nil {
107+
return fmt.Errorf("execute template: %w", err)
108+
}
109+
110+
return nil
111+
}
112+
113+
func InstallShellCompletion(shell Shell) error {
114+
path, err := shell.InstallPath()
115+
if err != nil {
116+
return fmt.Errorf("get install path: %w", err)
117+
}
118+
var headerBuf bytes.Buffer
119+
err = writeConfig(&headerBuf, completionStartTemplate, shell.ProgramName())
120+
if err != nil {
121+
return fmt.Errorf("generate header: %w", err)
122+
}
123+
124+
var footerBytes bytes.Buffer
125+
err = writeConfig(&footerBytes, completionEndTemplate, shell.ProgramName())
126+
if err != nil {
127+
return fmt.Errorf("generate footer: %w", err)
128+
}
91129

92-
return nil
130+
err = os.MkdirAll(filepath.Dir(path), 0o755)
131+
if err != nil {
132+
return fmt.Errorf("create directories: %w", err)
133+
}
134+
135+
f, err := os.ReadFile(path)
136+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
137+
return fmt.Errorf("read ssh config failed: %w", err)
138+
}
139+
140+
before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f)
141+
if err != nil {
142+
return err
143+
}
144+
145+
outBuf := bytes.Buffer{}
146+
_, _ = outBuf.Write(before)
147+
if len(before) > 0 {
148+
_, _ = outBuf.Write([]byte("\n"))
149+
}
150+
_, _ = outBuf.Write(headerBuf.Bytes())
151+
err = shell.WriteCompletion(&outBuf)
152+
if err != nil {
153+
return fmt.Errorf("generate completion: %w", err)
154+
}
155+
_, _ = outBuf.Write(footerBytes.Bytes())
156+
_, _ = outBuf.Write([]byte("\n"))
157+
_, _ = outBuf.Write(after)
158+
159+
err = atomic.WriteFile(path, &outBuf)
160+
if err != nil {
161+
return fmt.Errorf("write completion: %w", err)
162+
}
163+
164+
return nil
165+
}
166+
167+
func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) {
168+
startCount := bytes.Count(data, header)
169+
endCount := bytes.Count(data, footer)
170+
if startCount > 1 || endCount > 1 {
171+
return nil, nil, fmt.Errorf("Malformed config file: multiple config sections")
172+
}
173+
174+
startIndex := bytes.Index(data, header)
175+
endIndex := bytes.Index(data, footer)
176+
if startIndex == -1 && endIndex != -1 {
177+
return data, nil, fmt.Errorf("Malformed config file: missing completion header")
178+
}
179+
if startIndex != -1 && endIndex == -1 {
180+
return data, nil, fmt.Errorf("Malformed config file: missing completion footer")
181+
}
182+
if startIndex != -1 && endIndex != -1 {
183+
if startIndex > endIndex {
184+
return data, nil, fmt.Errorf("Malformed config file: completion header after footer")
185+
}
186+
// Include leading and trailing newline, if present
187+
start := startIndex
188+
if start > 0 {
189+
start--
190+
}
191+
end := endIndex + len(footer)
192+
if end < len(data) {
193+
end++
194+
}
195+
return data[:start], data[end:], nil
93196
}
197+
return data, nil, nil
94198
}

0 commit comments

Comments
 (0)