Skip to content

feat: add completion install api #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,16 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
}

func (inv *Invocation) CurWords() (prev string, cur string) {
if len(inv.Args) == 1 {
switch len(inv.Args) {
// All the shells we support will supply at least one argument (empty string),
// but we don't want to panic.
case 0:
cur = ""
prev = ""
case 1:
cur = inv.Args[0]
prev = ""
} else {
default:
cur = inv.Args[len(inv.Args)-1]
prev = inv.Args[len(inv.Args)-2]
}
Expand Down Expand Up @@ -645,9 +651,13 @@ func (inv *Invocation) completeFlag(word string) []string {
if opt.CompletionHandler != nil {
return opt.CompletionHandler(inv)
}
val, ok := opt.Value.(*Enum)
enum, ok := opt.Value.(*Enum)
if ok {
return enum.Choices
}
enumArr, ok := opt.Value.(*EnumArray)
if ok {
return val.Choices
return enumArr.Choices
}
return nil
}
Expand Down
22 changes: 14 additions & 8 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ func fakeIO(i *serpent.Invocation) *ioBufs {
func sampleCommand(t *testing.T) *serpent.Command {
t.Helper()
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
reqArr []string
fileArr []string
enumStr string
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
reqArr []string
reqEnumArr []string
fileArr []string
enumStr string
)
enumChoices := []string{"foo", "bar", "qux"}
return &serpent.Command{
Expand Down Expand Up @@ -94,6 +95,11 @@ func sampleCommand(t *testing.T) *serpent.Command {
FlagShorthand: "a",
Value: serpent.StringArrayOf(&reqArr),
},
serpent.Option{
Name: "req-enum-array",
Flag: "req-enum-array",
Value: serpent.EnumArrayOf(&reqEnumArr, enumChoices...),
},
},
HelpHandler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte("help text.png"))
Expand Down
13 changes: 10 additions & 3 deletions completion.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package serpent

import (
"github.com/spf13/pflag"
)

// CompletionModeEnv is a special environment variable that is
// set when the command is being run in completion mode.
const CompletionModeEnv = "COMPLETION_MODE"
Expand All @@ -10,15 +14,18 @@ func (inv *Invocation) IsCompletionMode() bool {
return ok
}

// DefaultCompletionHandler returns a handler that prints all
// known flags and subcommands that haven't already been set to valid values.
// DefaultCompletionHandler is a handler that prints all known flags and
// subcommands that haven't been exhaustively set.
func DefaultCompletionHandler(inv *Invocation) []string {
var allResps []string
for _, cmd := range inv.Command.Children {
allResps = append(allResps, cmd.Name())
}
for _, opt := range inv.Command.Options {
if opt.ValueSource == ValueSourceNone || opt.ValueSource == ValueSourceDefault || opt.Value.Type() == "string-array" {
_, isSlice := opt.Value.(pflag.SliceValue)
if opt.ValueSource == ValueSourceNone ||
opt.ValueSource == ValueSourceDefault ||
isSlice {
allResps = append(allResps, "--"+opt.Flag)
}
}
Expand Down
188 changes: 146 additions & 42 deletions completion/all.go
Original file line number Diff line number Diff line change
@@ -1,94 +1,198 @@
package completion

import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"os"
"os/user"
"path/filepath"
"runtime"
"strings"
"text/template"

"github.com/coder/serpent"

"github.com/natefinch/atomic"
)

const (
BashShell string = "bash"
FishShell string = "fish"
ZShell string = "zsh"
Powershell string = "powershell"
completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============`
completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============`
)

var shellCompletionByName = map[string]func(io.Writer, string) error{
BashShell: generateCompletion(bashCompletionTemplate),
FishShell: generateCompletion(fishCompletionTemplate),
ZShell: generateCompletion(zshCompletionTemplate),
Powershell: generateCompletion(pshCompletionTemplate),
type Shell interface {
Name() string
InstallPath() (string, error)
WriteCompletion(io.Writer) error
ProgramName() string
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of annoying, but we need the program name during InstallCompletion to figure out the exact header we're looking for - otherwise you obviously couldn't have multiple serpent completion scripts installed.

}

func ShellOptions(choice *string) *serpent.Enum {
return serpent.EnumOf(choice, BashShell, FishShell, ZShell, Powershell)
}
const (
ShellBash string = "bash"
ShellFish string = "fish"
ShellZsh string = "zsh"
ShellPowershell string = "powershell"
)

func WriteCompletion(writer io.Writer, shell string, cmdName string) error {
fn, ok := shellCompletionByName[shell]
if !ok {
return fmt.Errorf("unknown shell %q", shell)
func ShellByName(shell, programName string) (Shell, error) {
switch shell {
case ShellBash:
return Bash(runtime.GOOS, programName), nil
case ShellFish:
return Fish(runtime.GOOS, programName), nil
case ShellZsh:
return Zsh(runtime.GOOS, programName), nil
case ShellPowershell:
return Powershell(runtime.GOOS, programName), nil
default:
return nil, fmt.Errorf("unsupported shell %q", shell)
}
fn(writer, cmdName)
return nil
}

func DetectUserShell() (string, error) {
func ShellOptions(choice *string) *serpent.Enum {
return serpent.EnumOf(choice, ShellBash, ShellFish, ShellZsh, ShellPowershell)
}

func DetectUserShell(programName string) (Shell, error) {
Copy link
Member Author

@ethanndickson ethanndickson Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels a bit awkward, but we need to provide it to use in the template itself, e.g. function _name, and potentially in the install path, e.g. name.fish, and it's much nicer for this function to return a Shell, instead of a string that you just have to pass back to ShellByName.

Alternatively, InstallPath and WriteCompletion could take it as an argument, but if you call one you're probably going to call the other, so I thought it'd be better to just store it.

// Attempt to get the SHELL environment variable first
if shell := os.Getenv("SHELL"); shell != "" {
return filepath.Base(shell), nil
return ShellByName(filepath.Base(shell), "")
}

// Fallback: Look up the current user and parse /etc/passwd
currentUser, err := user.Current()
if err != nil {
return "", err
return nil, err
}

// Open and parse /etc/passwd
passwdFile, err := os.ReadFile("/etc/passwd")
if err != nil {
return "", err
return nil, err
}

lines := strings.Split(string(passwdFile), "\n")
for _, line := range lines {
if strings.HasPrefix(line, currentUser.Username+":") {
parts := strings.Split(line, ":")
if len(parts) > 6 {
return filepath.Base(parts[6]), nil // The shell is typically the 7th field
return ShellByName(filepath.Base(parts[6]), programName) // The shell is typically the 7th field
}
}
}

return "", fmt.Errorf("default shell not found")
return nil, fmt.Errorf("default shell not found")
}

func generateCompletion(
scriptTemplate string,
) func(io.Writer, string) error {
return func(w io.Writer, rootCmdName string) error {
tmpl, err := template.New("script").Parse(scriptTemplate)
if err != nil {
return fmt.Errorf("parse template: %w", err)
}
func writeConfig(
w io.Writer,
cfgTemplate string,
programName string,
) error {
tmpl, err := template.New("script").Parse(cfgTemplate)
if err != nil {
return fmt.Errorf("parse template: %w", err)
}

err = tmpl.Execute(
w,
map[string]string{
"Name": rootCmdName,
},
)
if err != nil {
return fmt.Errorf("execute template: %w", err)
}
err = tmpl.Execute(
w,
map[string]string{
"Name": programName,
},
)
if err != nil {
return fmt.Errorf("execute template: %w", err)
}

return nil
}

func InstallShellCompletion(shell Shell) error {
path, err := shell.InstallPath()
if err != nil {
return fmt.Errorf("get install path: %w", err)
}
var headerBuf bytes.Buffer
err = writeConfig(&headerBuf, completionStartTemplate, shell.ProgramName())
if err != nil {
return fmt.Errorf("generate header: %w", err)
}

var footerBytes bytes.Buffer
err = writeConfig(&footerBytes, completionEndTemplate, shell.ProgramName())
if err != nil {
return fmt.Errorf("generate footer: %w", err)
}

return nil
err = os.MkdirAll(filepath.Dir(path), 0o755)
if err != nil {
return fmt.Errorf("create directories: %w", err)
}

f, err := os.ReadFile(path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("read ssh config failed: %w", err)
}

before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f)
if err != nil {
return err
}

outBuf := bytes.Buffer{}
_, _ = outBuf.Write(before)
if len(before) > 0 {
_, _ = outBuf.Write([]byte("\n"))
}
_, _ = outBuf.Write(headerBuf.Bytes())
err = shell.WriteCompletion(&outBuf)
if err != nil {
return fmt.Errorf("generate completion: %w", err)
}
_, _ = outBuf.Write(footerBytes.Bytes())
_, _ = outBuf.Write([]byte("\n"))
_, _ = outBuf.Write(after)

err = atomic.WriteFile(path, &outBuf)
if err != nil {
return fmt.Errorf("write completion: %w", err)
}

return nil
}

func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) {
startCount := bytes.Count(data, header)
endCount := bytes.Count(data, footer)
if startCount > 1 || endCount > 1 {
return nil, nil, fmt.Errorf("Malformed config file: multiple config sections")
}

startIndex := bytes.Index(data, header)
endIndex := bytes.Index(data, footer)
if startIndex == -1 && endIndex != -1 {
return data, nil, fmt.Errorf("Malformed config file: missing completion header")
}
if startIndex != -1 && endIndex == -1 {
return data, nil, fmt.Errorf("Malformed config file: missing completion footer")
}
if startIndex != -1 && endIndex != -1 {
if startIndex > endIndex {
return data, nil, fmt.Errorf("Malformed config file: completion header after footer")
}
// Include leading and trailing newline, if present
start := startIndex
if start > 0 {
start--
}
end := endIndex + len(footer)
if end < len(data) {
end++
}
return data[:start], data[end:], nil
}
return data, nil, nil
}
42 changes: 41 additions & 1 deletion completion/bash.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,46 @@
package completion

import (
"io"
"path/filepath"

home "github.com/mitchellh/go-homedir"
)

type bash struct {
goos string
programName string
}

var _ Shell = &bash{}

func Bash(goos string, programName string) Shell {
return &bash{goos: goos, programName: programName}
}

func (b *bash) Name() string {
return "bash"
}

func (b *bash) InstallPath() (string, error) {
homeDir, err := home.Dir()
if err != nil {
return "", err
}
if b.goos == "darwin" {
return filepath.Join(homeDir, ".bash_profile"), nil
}
return filepath.Join(homeDir, ".bashrc"), nil
}

func (b *bash) WriteCompletion(w io.Writer) error {
return writeConfig(w, bashCompletionTemplate, b.programName)
}

func (b *bash) ProgramName() string {
return b.programName
}

const bashCompletionTemplate = `
_generate_{{.Name}}_completions() {
# Capture the line excluding the command, and everything after the current word
Expand All @@ -16,7 +57,6 @@ _generate_{{.Name}}_completions() {
COMPREPLY=()
fi
}

# Setup Bash to use the function for completions for '{{.Name}}'
complete -F _generate_{{.Name}}_completions {{.Name}}
`
Loading
Loading