Skip to content

feat: add completion install api #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions completion.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package serpent

import "strings"
import (
"github.com/spf13/pflag"
)

// CompletionModeEnv is a special environment variable that is
// set when the command is being run in completion mode.
Expand All @@ -12,17 +14,18 @@ func (inv *Invocation) IsCompletionMode() bool {
return ok
}

// DefaultCompletionHandler returns a handler that prints all
// known flags and subcommands that haven't already been set to valid values.
// DefaultCompletionHandler is a handler that prints all known flags and
// subcommands that haven't been exhaustively set.
func DefaultCompletionHandler(inv *Invocation) []string {
var allResps []string
for _, cmd := range inv.Command.Children {
allResps = append(allResps, cmd.Name())
}
for _, opt := range inv.Command.Options {
_, isSlice := opt.Value.(pflag.SliceValue)
if opt.ValueSource == ValueSourceNone ||
opt.ValueSource == ValueSourceDefault ||
strings.Contains(opt.Value.Type(), "array") {
isSlice {
allResps = append(allResps, "--"+opt.Flag)
}
}
Expand Down
162 changes: 134 additions & 28 deletions completion/all.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package completion

import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"os"
"os/user"
"path/filepath"
Expand All @@ -13,11 +16,16 @@ import (
"github.com/coder/serpent"
)

const (
completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============`
completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============`
)

type Shell interface {
Name() string
InstallPath() (string, error)
UsesOwnFile() bool
WriteCompletion(io.Writer) error
ProgramName() string
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of annoying, but we need the program name during InstallCompletion to figure out the exact header we're looking for - otherwise you obviously couldn't have multiple serpent completion scripts installed.

}

const (
Expand Down Expand Up @@ -77,56 +85,154 @@ func DetectUserShell(programName string) (Shell, error) {
return nil, fmt.Errorf("default shell not found")
}

func generateCompletion(
scriptTemplate string,
) func(io.Writer, string) error {
return func(w io.Writer, programName string) error {
tmpl, err := template.New("script").Parse(scriptTemplate)
if err != nil {
return fmt.Errorf("parse template: %w", err)
}

err = tmpl.Execute(
w,
map[string]string{
"Name": programName,
},
)
if err != nil {
return fmt.Errorf("execute template: %w", err)
}
func configTemplateWriter(
w io.Writer,
cfgTemplate string,
programName string,
) error {
tmpl, err := template.New("script").Parse(cfgTemplate)
if err != nil {
return fmt.Errorf("parse template: %w", err)
}

return nil
err = tmpl.Execute(
w,
map[string]string{
"Name": programName,
},
)
if err != nil {
return fmt.Errorf("execute template: %w", err)
}

return nil
}

func InstallShellCompletion(shell Shell) error {
path, err := shell.InstallPath()
if err != nil {
return fmt.Errorf("get install path: %w", err)
}
var headerBuf bytes.Buffer
err = configTemplateWriter(&headerBuf, completionStartTemplate, shell.ProgramName())
if err != nil {
return fmt.Errorf("generate header: %w", err)
}

var footerBytes bytes.Buffer
err = configTemplateWriter(&footerBytes, completionEndTemplate, shell.ProgramName())
if err != nil {
return fmt.Errorf("generate footer: %w", err)
}

err = os.MkdirAll(filepath.Dir(path), 0o755)
if err != nil {
return fmt.Errorf("create directories: %w", err)
}

if shell.UsesOwnFile() {
err := os.WriteFile(path, nil, 0o644)
f, err := os.ReadFile(path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("read ssh config failed: %w", err)
}

before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f)
if err != nil {
return err
}

outBuf := bytes.Buffer{}
_, _ = outBuf.Write(before)
if len(before) > 0 {
_, _ = outBuf.Write([]byte("\n"))
}
_, _ = outBuf.Write(headerBuf.Bytes())
err = shell.WriteCompletion(&outBuf)
if err != nil {
return fmt.Errorf("generate completion: %w", err)
}
_, _ = outBuf.Write(footerBytes.Bytes())
_, _ = outBuf.Write([]byte("\n"))
_, _ = outBuf.Write(after)

err = writeWithTempFileAndMove(path, &outBuf)
if err != nil {
return fmt.Errorf("write completion: %w", err)
}

return nil
}

func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) {
startCount := bytes.Count(data, header)
endCount := bytes.Count(data, footer)
if startCount > 1 || endCount > 1 {
return nil, nil, fmt.Errorf("Malformed config file: multiple config sections")
}

startIndex := bytes.Index(data, header)
endIndex := bytes.Index(data, footer)
if startIndex == -1 && endIndex != -1 {
return data, nil, fmt.Errorf("Malformed config file: missing completion header")
}
if startIndex != -1 && endIndex == -1 {
return data, nil, fmt.Errorf("Malformed config file: missing completion footer")
}
if startIndex != -1 && endIndex != -1 {
if startIndex > endIndex {
return data, nil, fmt.Errorf("Malformed config file: completion header after footer")
}
// Include leading and trailing newline, if present
start := startIndex
if start > 0 {
start--
}
end := endIndex + len(footer)
if end < len(data) {
end++
}
return data[:start], data[end:], nil
}
return data, nil, nil
}

// writeWithTempFileAndMove writes to a temporary file in the same
// directory as path and renames the temp file to the file provided in
// path. This ensure we avoid trashing the file we are writing due to
// unforeseen circumstance like filesystem full, command killed, etc.
func writeWithTempFileAndMove(path string, r io.Reader) (err error) {
dir := filepath.Dir(path)
name := filepath.Base(path)

if err = os.MkdirAll(dir, 0o700); err != nil {
return fmt.Errorf("create directory: %w", err)
}

// Create a tempfile in the same directory for ensuring write
// operation does not fail.
f, err := os.CreateTemp(dir, fmt.Sprintf(".%s.", name))
if err != nil {
return fmt.Errorf("create temp file failed: %w", err)
}
defer func() {
if err != nil {
return fmt.Errorf("create file: %w", err)
_ = os.Remove(f.Name()) // Cleanup in case a step failed.
}
}()

_, err = io.Copy(f, r)
if err != nil {
_ = f.Close()
return fmt.Errorf("write temp file failed: %w", err)
}

f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
err = f.Close()
if err != nil {
return fmt.Errorf("open file for appending: %w", err)
return fmt.Errorf("close temp file failed: %w", err)
}
defer f.Close()

err = shell.WriteCompletion(f)
err = os.Rename(f.Name(), path)
if err != nil {
return fmt.Errorf("write completion script: %w", err)
return fmt.Errorf("rename temp file failed: %w", err)
}

return nil
Expand Down
16 changes: 6 additions & 10 deletions completion/bash.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ func (b *bash) Name() string {
return "bash"
}

// UsesOwnFile implements Shell.
func (b *bash) UsesOwnFile() bool {
return false
}

// InstallPath implements Shell.
func (b *bash) InstallPath() (string, error) {
homeDir, err := home.Dir()
Expand All @@ -42,12 +37,15 @@ func (b *bash) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (b *bash) WriteCompletion(w io.Writer) error {
return generateCompletion(bashCompletionTemplate)(w, b.programName)
return configTemplateWriter(w, bashCompletionTemplate, b.programName)
}

const bashCompletionTemplate = `
// ProgramName implements Shell.
func (b *bash) ProgramName() string {
return b.programName
}

# === BEGIN {{.Name}} COMPLETION ===
const bashCompletionTemplate = `
_generate_{{.Name}}_completions() {
# Capture the line excluding the command, and everything after the current word
local args=("${COMP_WORDS[@]:1:COMP_CWORD}")
Expand All @@ -65,6 +63,4 @@ _generate_{{.Name}}_completions() {
}
# Setup Bash to use the function for completions for '{{.Name}}'
complete -F _generate_{{.Name}}_completions {{.Name}}
# === END {{.Name}} COMPLETION ===

`
12 changes: 6 additions & 6 deletions completion/fish.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ func Fish(goos string, programName string) Shell {
return &fish{goos: goos, programName: programName}
}

// UsesOwnFile implements Shell.
func (f *fish) UsesOwnFile() bool {
return true
}

// Name implements Shell.
func (f *fish) Name() string {
return "fish"
Expand All @@ -39,7 +34,12 @@ func (f *fish) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (f *fish) WriteCompletion(w io.Writer) error {
return generateCompletion(fishCompletionTemplate)(w, f.programName)
return configTemplateWriter(w, fishCompletionTemplate, f.programName)
}

// ProgramName implements Shell.
func (f *fish) ProgramName() string {
return f.programName
}

const fishCompletionTemplate = `
Expand Down
18 changes: 7 additions & 11 deletions completion/powershell.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ type powershell struct {
programName string
}

var _ Shell = &powershell{}

// Name implements Shell.
func (p *powershell) Name() string {
return "powershell"
Expand All @@ -20,11 +22,6 @@ func Powershell(goos string, programName string) Shell {
return &powershell{goos: goos, programName: programName}
}

// UsesOwnFile implements Shell.
func (p *powershell) UsesOwnFile() bool {
return false
}

// InstallPath implements Shell.
func (p *powershell) InstallPath() (string, error) {
var (
Expand All @@ -45,14 +42,15 @@ func (p *powershell) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (p *powershell) WriteCompletion(w io.Writer) error {
return generateCompletion(pshCompletionTemplate)(w, p.programName)
return configTemplateWriter(w, pshCompletionTemplate, p.programName)
}

var _ Shell = &powershell{}
// ProgramName implements Shell.
func (p *powershell) ProgramName() string {
return p.programName
}

const pshCompletionTemplate = `

# === BEGIN {{.Name}} COMPLETION ===
# Escaping output sourced from:
# https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47
filter _{{.Name}}_escapeStringWithSpecialChars {
Expand Down Expand Up @@ -89,6 +87,4 @@ $_{{.Name}}_completions = {
rm env:COMPLETION_MODE
}
Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions
# === END {{.Name}} COMPLETION ===

`
16 changes: 6 additions & 10 deletions completion/zsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ func (z *zsh) Name() string {
return "zsh"
}

// UsesOwnFile implements Shell.
func (z *zsh) UsesOwnFile() bool {
return false
}

// InstallPath implements Shell.
func (z *zsh) InstallPath() (string, error) {
homeDir, err := home.Dir()
Expand All @@ -39,19 +34,20 @@ func (z *zsh) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (z *zsh) WriteCompletion(w io.Writer) error {
return generateCompletion(zshCompletionTemplate)(w, z.programName)
return configTemplateWriter(w, zshCompletionTemplate, z.programName)
}

const zshCompletionTemplate = `
// ProgramName implements Shell.
func (z *zsh) ProgramName() string {
return z.programName
}

# === BEGIN {{.Name}} COMPLETION ===
const zshCompletionTemplate = `
_{{.Name}}_completions() {
local -a args completions
args=("${words[@]:1:$#words}")
completions=($(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}"))
compadd -a completions
}
compdef _{{.Name}}_completions {{.Name}}
# === END {{.Name}} COMPLETION ===

`
Loading
Loading