From 802129c6853c098d2e633dbe20aa97e4bc4bb9c7 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Fri, 9 Aug 2024 07:00:48 +0000 Subject: [PATCH 1/4] feat: add completion install api --- command.go | 18 ++++-- command_test.go | 22 +++++--- completion.go | 6 +- completion/all.go | 93 ++++++++++++++++++++++--------- completion/bash.go | 50 ++++++++++++++++- completion/fish.go | 42 ++++++++++++++ completion/powershell.go | 54 +++++++++++++++++- completion/zsh.go | 47 +++++++++++++++- completion_test.go | 103 ++++++++++++++++++++++++++++++++++- example/completetest/main.go | 22 +++++--- go.mod | 1 + go.sum | 2 + help.go | 2 + values.go | 48 ++++++++++++++++ 14 files changed, 455 insertions(+), 55 deletions(-) diff --git a/command.go b/command.go index 542dfff..9d10dfe 100644 --- a/command.go +++ b/command.go @@ -296,10 +296,16 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { } func (inv *Invocation) CurWords() (prev string, cur string) { - if len(inv.Args) == 1 { + switch len(inv.Args) { + // All the shells we support will supply at least one argument (empty string), + // but we don't want to panic. + case 0: + cur = "" + prev = "" + case 1: cur = inv.Args[0] prev = "" - } else { + default: cur = inv.Args[len(inv.Args)-1] prev = inv.Args[len(inv.Args)-2] } @@ -645,9 +651,13 @@ func (inv *Invocation) completeFlag(word string) []string { if opt.CompletionHandler != nil { return opt.CompletionHandler(inv) } - val, ok := opt.Value.(*Enum) + enum, ok := opt.Value.(*Enum) + if ok { + return enum.Choices + } + enumArr, ok := opt.Value.(*EnumArray) if ok { - return val.Choices + return enumArr.Choices } return nil } diff --git a/command_test.go b/command_test.go index de6c12d..41541f4 100644 --- a/command_test.go +++ b/command_test.go @@ -34,14 +34,15 @@ func fakeIO(i *serpent.Invocation) *ioBufs { func sampleCommand(t *testing.T) *serpent.Command { t.Helper() var ( - verbose bool - lower bool - prefix string - reqBool bool - reqStr string - reqArr []string - fileArr []string - enumStr string + verbose bool + lower bool + prefix string + reqBool bool + reqStr string + reqArr []string + reqEnumArr []string + fileArr []string + enumStr string ) enumChoices := []string{"foo", "bar", "qux"} return &serpent.Command{ @@ -94,6 +95,11 @@ func sampleCommand(t *testing.T) *serpent.Command { FlagShorthand: "a", Value: serpent.StringArrayOf(&reqArr), }, + serpent.Option{ + Name: "req-enum-array", + Flag: "req-enum-array", + Value: serpent.EnumArrayOf(&reqEnumArr, enumChoices...), + }, }, HelpHandler: func(i *serpent.Invocation) error { _, _ = i.Stdout.Write([]byte("help text.png")) diff --git a/completion.go b/completion.go index a0fb779..616775b 100644 --- a/completion.go +++ b/completion.go @@ -1,5 +1,7 @@ package serpent +import "strings" + // CompletionModeEnv is a special environment variable that is // set when the command is being run in completion mode. const CompletionModeEnv = "COMPLETION_MODE" @@ -18,7 +20,9 @@ func DefaultCompletionHandler(inv *Invocation) []string { allResps = append(allResps, cmd.Name()) } for _, opt := range inv.Command.Options { - if opt.ValueSource == ValueSourceNone || opt.ValueSource == ValueSourceDefault || opt.Value.Type() == "string-array" { + if opt.ValueSource == ValueSourceNone || + opt.ValueSource == ValueSourceDefault || + strings.Contains(opt.Value.Type(), "array") { allResps = append(allResps, "--"+opt.Flag) } } diff --git a/completion/all.go b/completion/all.go index b20c254..4a48aa8 100644 --- a/completion/all.go +++ b/completion/all.go @@ -6,55 +6,62 @@ import ( "os" "os/user" "path/filepath" + "runtime" "strings" "text/template" "github.com/coder/serpent" ) +type Shell interface { + Name() string + InstallPath() (string, error) + UsesOwnFile() bool + WriteCompletion(io.Writer) error +} + const ( - BashShell string = "bash" - FishShell string = "fish" - ZShell string = "zsh" - Powershell string = "powershell" + ShellBash string = "bash" + ShellFish string = "fish" + ShellZsh string = "zsh" + ShellPowershell string = "powershell" ) -var shellCompletionByName = map[string]func(io.Writer, string) error{ - BashShell: generateCompletion(bashCompletionTemplate), - FishShell: generateCompletion(fishCompletionTemplate), - ZShell: generateCompletion(zshCompletionTemplate), - Powershell: generateCompletion(pshCompletionTemplate), +func ShellByName(shell, programName string) (Shell, error) { + switch shell { + case ShellBash: + return Bash(runtime.GOOS, programName), nil + case ShellFish: + return Fish(runtime.GOOS, programName), nil + case ShellZsh: + return Zsh(runtime.GOOS, programName), nil + case ShellPowershell: + return Powershell(runtime.GOOS, programName), nil + default: + return nil, fmt.Errorf("unsupported shell %q", shell) + } } func ShellOptions(choice *string) *serpent.Enum { - return serpent.EnumOf(choice, BashShell, FishShell, ZShell, Powershell) -} - -func WriteCompletion(writer io.Writer, shell string, cmdName string) error { - fn, ok := shellCompletionByName[shell] - if !ok { - return fmt.Errorf("unknown shell %q", shell) - } - fn(writer, cmdName) - return nil + return serpent.EnumOf(choice, ShellBash, ShellFish, ShellZsh, ShellPowershell) } -func DetectUserShell() (string, error) { +func DetectUserShell(programName string) (Shell, error) { // Attempt to get the SHELL environment variable first if shell := os.Getenv("SHELL"); shell != "" { - return filepath.Base(shell), nil + return ShellByName(filepath.Base(shell), "") } // Fallback: Look up the current user and parse /etc/passwd currentUser, err := user.Current() if err != nil { - return "", err + return nil, err } // Open and parse /etc/passwd passwdFile, err := os.ReadFile("/etc/passwd") if err != nil { - return "", err + return nil, err } lines := strings.Split(string(passwdFile), "\n") @@ -62,18 +69,18 @@ func DetectUserShell() (string, error) { if strings.HasPrefix(line, currentUser.Username+":") { parts := strings.Split(line, ":") if len(parts) > 6 { - return filepath.Base(parts[6]), nil // The shell is typically the 7th field + return ShellByName(filepath.Base(parts[6]), programName) // The shell is typically the 7th field } } } - return "", fmt.Errorf("default shell not found") + return nil, fmt.Errorf("default shell not found") } func generateCompletion( scriptTemplate string, ) func(io.Writer, string) error { - return func(w io.Writer, rootCmdName string) error { + return func(w io.Writer, programName string) error { tmpl, err := template.New("script").Parse(scriptTemplate) if err != nil { return fmt.Errorf("parse template: %w", err) @@ -82,7 +89,7 @@ func generateCompletion( err = tmpl.Execute( w, map[string]string{ - "Name": rootCmdName, + "Name": programName, }, ) if err != nil { @@ -92,3 +99,35 @@ func generateCompletion( return nil } } + +func InstallShellCompletion(shell Shell) error { + path, err := shell.InstallPath() + if err != nil { + return fmt.Errorf("get install path: %w", err) + } + + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + return fmt.Errorf("create directories: %w", err) + } + + if shell.UsesOwnFile() { + err := os.WriteFile(path, nil, 0o644) + if err != nil { + return fmt.Errorf("create file: %w", err) + } + } + + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("open file for appending: %w", err) + } + defer f.Close() + + err = shell.WriteCompletion(f) + if err != nil { + return fmt.Errorf("write completion script: %w", err) + } + + return nil +} diff --git a/completion/bash.go b/completion/bash.go index fad4069..8b0d219 100644 --- a/completion/bash.go +++ b/completion/bash.go @@ -1,6 +1,53 @@ package completion +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type bash struct { + goos string + programName string +} + +var _ Shell = &bash{} + +func Bash(goos string, programName string) Shell { + return &bash{goos: goos, programName: programName} +} + +// Name implements Shell. +func (b *bash) Name() string { + return "bash" +} + +// UsesOwnFile implements Shell. +func (b *bash) UsesOwnFile() bool { + return false +} + +// InstallPath implements Shell. +func (b *bash) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + if b.goos == "darwin" { + return filepath.Join(homeDir, ".bash_profile"), nil + } + return filepath.Join(homeDir, ".bashrc"), nil +} + +// WriteCompletion implements Shell. +func (b *bash) WriteCompletion(w io.Writer) error { + return generateCompletion(bashCompletionTemplate)(w, b.programName) +} + const bashCompletionTemplate = ` + +# === BEGIN {{.Name}} COMPLETION === _generate_{{.Name}}_completions() { # Capture the line excluding the command, and everything after the current word local args=("${COMP_WORDS[@]:1:COMP_CWORD}") @@ -16,7 +63,8 @@ _generate_{{.Name}}_completions() { COMPREPLY=() fi } - # Setup Bash to use the function for completions for '{{.Name}}' complete -F _generate_{{.Name}}_completions {{.Name}} +# === END {{.Name}} COMPLETION === + ` diff --git a/completion/fish.go b/completion/fish.go index f9b2793..037b264 100644 --- a/completion/fish.go +++ b/completion/fish.go @@ -1,5 +1,47 @@ package completion +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type fish struct { + goos string + programName string +} + +var _ Shell = &fish{} + +func Fish(goos string, programName string) Shell { + return &fish{goos: goos, programName: programName} +} + +// UsesOwnFile implements Shell. +func (f *fish) UsesOwnFile() bool { + return true +} + +// Name implements Shell. +func (f *fish) Name() string { + return "fish" +} + +// InstallPath implements Shell. +func (f *fish) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".config/fish/completions/", f.programName+".fish"), nil +} + +// WriteCompletion implements Shell. +func (f *fish) WriteCompletion(w io.Writer) error { + return generateCompletion(fishCompletionTemplate)(w, f.programName) +} + const fishCompletionTemplate = ` function _{{.Name}}_completions # Capture the full command line as an array diff --git a/completion/powershell.go b/completion/powershell.go index e30c61b..1114fbb 100644 --- a/completion/powershell.go +++ b/completion/powershell.go @@ -1,7 +1,58 @@ package completion +import ( + "io" + "os/exec" + "strings" +) + +type powershell struct { + goos string + programName string +} + +// Name implements Shell. +func (p *powershell) Name() string { + return "powershell" +} + +func Powershell(goos string, programName string) Shell { + return &powershell{goos: goos, programName: programName} +} + +// UsesOwnFile implements Shell. +func (p *powershell) UsesOwnFile() bool { + return false +} + +// InstallPath implements Shell. +func (p *powershell) InstallPath() (string, error) { + var ( + path []byte + err error + ) + cmd := "$PROFILE.CurrentUserAllHosts" + if p.goos == "windows" { + path, err = exec.Command("powershell", cmd).CombinedOutput() + } else { + path, err = exec.Command("pwsh", "-Command", cmd).CombinedOutput() + } + if err != nil { + return "", err + } + return strings.TrimSpace(string(path)), nil +} + +// WriteCompletion implements Shell. +func (p *powershell) WriteCompletion(w io.Writer) error { + return generateCompletion(pshCompletionTemplate)(w, p.programName) +} + +var _ Shell = &powershell{} + const pshCompletionTemplate = ` +# === BEGIN {{.Name}} COMPLETION === # Escaping output sourced from: # https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47 filter _{{.Name}}_escapeStringWithSpecialChars { @@ -37,6 +88,7 @@ $_{{.Name}}_completions = { } rm env:COMPLETION_MODE } - Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions +# === END {{.Name}} COMPLETION === + ` diff --git a/completion/zsh.go b/completion/zsh.go index a8ee4a8..af6b553 100644 --- a/completion/zsh.go +++ b/completion/zsh.go @@ -1,12 +1,57 @@ package completion +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type zsh struct { + goos string + programName string +} + +var _ Shell = &zsh{} + +func Zsh(goos string, programName string) Shell { + return &zsh{goos: goos, programName: programName} +} + +// Name implements Shell. +func (z *zsh) Name() string { + return "zsh" +} + +// UsesOwnFile implements Shell. +func (z *zsh) UsesOwnFile() bool { + return false +} + +// InstallPath implements Shell. +func (z *zsh) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".zshrc"), nil +} + +// WriteCompletion implements Shell. +func (z *zsh) WriteCompletion(w io.Writer) error { + return generateCompletion(zshCompletionTemplate)(w, z.programName) +} + const zshCompletionTemplate = ` + +# === BEGIN {{.Name}} COMPLETION === _{{.Name}}_completions() { local -a args completions args=("${words[@]:1:$#words}") completions=($(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")) compadd -a completions } - compdef _{{.Name}}_completions {{.Name}} +# === END {{.Name}} COMPLETION === + ` diff --git a/completion_test.go b/completion_test.go index 5ca160a..1ac925c 100644 --- a/completion_test.go +++ b/completion_test.go @@ -2,11 +2,14 @@ package serpent_test import ( "fmt" + "io" "os" + "path/filepath" "strings" "testing" serpent "github.com/coder/serpent" + "github.com/coder/serpent/completion" "github.com/stretchr/testify/require" ) @@ -52,7 +55,7 @@ func TestCompletion(t *testing.T) { io := fakeIO(i) err := i.Run() require.NoError(t, err) - require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-string\n", io.Stdout.String()) + require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-enum-array\n--req-string\n", io.Stdout.String()) }) t.Run("ListFlagsAfterArg", func(t *testing.T) { @@ -72,7 +75,7 @@ func TestCompletion(t *testing.T) { io := fakeIO(i) err := i.Run() require.NoError(t, err) - require.Equal(t, "--req-array\n--req-enum\n", io.Stdout.String()) + require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) }) t.Run("FlagShorthand", func(t *testing.T) { @@ -82,7 +85,7 @@ func TestCompletion(t *testing.T) { io := fakeIO(i) err := i.Run() require.NoError(t, err) - require.Equal(t, "--req-array\n--req-enum\n", io.Stdout.String()) + require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) }) t.Run("NoOptDefValueFlag", func(t *testing.T) { @@ -125,6 +128,36 @@ func TestCompletion(t *testing.T) { require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) }) + t.Run("EnumArrayOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "foo\nbar\nqux\n", io.Stdout.String()) + }) + + t.Run("EnumArrayEqualsOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array=") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) + }) + + t.Run("EnumArrayEqualsBeginQuotesOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array=\"") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) + }) + } func TestFileCompletion(t *testing.T) { @@ -201,3 +234,67 @@ func TestFileCompletion(t *testing.T) { }) } } + +func TestCompletionInstall(t *testing.T) { + t.Parallel() + + t.Run("InstallingAppend", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "fake.sh") + f, err := os.Create(path) + require.NoError(t, err) + f.Write([]byte("FAKE_SCRIPT")) + f.Close() + + shell := &fakeShell{baseInstallDir: dir, useOwn: false} + err = completion.InstallShellCompletion(shell) + require.NoError(t, err) + contents, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, "FAKE_SCRIPTFAKE_COMPLETION", string(contents)) + }) + + t.Run("InstallReplace", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "fake.sh") + f, err := os.Create(path) + require.NoError(t, err) + f.Write([]byte("FAKE_SCRIPT")) + f.Close() + + shell := &fakeShell{baseInstallDir: dir, useOwn: true} + err = completion.InstallShellCompletion(shell) + require.NoError(t, err) + contents, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, "FAKE_COMPLETION", string(contents)) + }) +} + +type fakeShell struct { + baseInstallDir string + useOwn bool +} + +var _ completion.Shell = &fakeShell{} + +// InstallPath implements completion.Shell. +func (f *fakeShell) InstallPath() (string, error) { + return filepath.Join(f.baseInstallDir, "fake.sh"), nil +} + +// Name implements completion.Shell. +func (f *fakeShell) Name() string { + return "fake" +} + +// UsesOwnFile implements completion.Shell. +func (f *fakeShell) UsesOwnFile() bool { + return f.useOwn +} + +// WriteCompletion implements completion.Shell. +func (f *fakeShell) WriteCompletion(w io.Writer) error { + _, err := w.Write([]byte("FAKE_COMPLETION")) + return err +} diff --git a/example/completetest/main.go b/example/completetest/main.go index 920e705..add6d5c 100644 --- a/example/completetest/main.go +++ b/example/completetest/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "strings" @@ -11,24 +12,21 @@ import ( // installCommand returns a serpent command that helps // a user configure their shell to use serpent's completion. func installCommand() *serpent.Command { - defaultShell, err := completion.DetectUserShell() - if err != nil { - defaultShell = "bash" - } - var shell string return &serpent.Command{ - Use: "completion", + Use: "completion [--shell ]", Short: "Generate completion scripts for the given shell.", Handler: func(inv *serpent.Invocation) error { - completion.WriteCompletion(inv.Stdout, shell, inv.Command.Parent.Name()) - return nil + defaultShell, err := completion.DetectUserShell(inv.Command.Parent.Name()) + if err != nil { + return fmt.Errorf("Could not detect user shell, please specify a shell using `--shell`") + } + return defaultShell.WriteCompletion(inv.Stdout) }, Options: serpent.OptionSet{ { Flag: "shell", FlagShorthand: "s", - Default: defaultShell, Description: "The shell to generate a completion script for.", Value: completion.ShellOptions(&shell), }, @@ -42,6 +40,7 @@ func main() { upper bool fileType string fileArr []string + types []string ) cmd := serpent.Command{ Use: "completetest ", @@ -109,6 +108,11 @@ func main() { Description: "Extra files.", Value: serpent.StringArrayOf(&fileArr), }, + { + Name: "types", + Flag: "types", + Value: serpent.EnumArrayOf(&types, "binary", "text"), + }, }, CompletionHandler: completion.FileHandler(nil), Middleware: serpent.RequireNArgs(1), diff --git a/go.mod b/go.mod index 8bd432e..e70ab99 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 github.com/hashicorp/go-multierror v1.1.1 + github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-wordwrap v1.0.1 github.com/muesli/termenv v0.15.2 github.com/pion/udp v0.1.4 diff --git a/go.sum b/go.sum index 63d175d..011ecb6 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= diff --git a/help.go b/help.go index 3dc49a4..533be11 100644 --- a/help.go +++ b/help.go @@ -95,6 +95,8 @@ var defaultHelpTemplate = func() *template.Template { switch v := opt.Value.(type) { case *Enum: return strings.Join(v.Choices, "|") + case *EnumArray: + return fmt.Sprintf("[%s]", strings.Join(v.Choices, "|")) default: return v.Type() } diff --git a/values.go b/values.go index 25fc20f..9c0ed83 100644 --- a/values.go +++ b/values.go @@ -628,3 +628,51 @@ func (p *YAMLConfigPath) String() string { func (*YAMLConfigPath) Type() string { return "yaml-config-path" } + +var _ pflag.Value = (*EnumArray)(nil) + +type EnumArray struct { + Choices []string + Value *[]string +} + +func (e *EnumArray) Set(v string) error { + if v == "" { + *e.Value = nil + return nil + } + ss, err := readAsCSV(v) + if err != nil { + return err + } + for _, s := range ss { + found := false + for _, c := range e.Choices { + if s == c { + found = true + break + } + } + if !found { + return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) + } + } + *e.Value = append(*e.Value, ss...) + return nil +} + +func (e *EnumArray) String() string { + return writeAsCSV(*e.Value) +} + +func (e *EnumArray) Type() string { + return fmt.Sprintf("enum-array[%v]", strings.Join(e.Choices, "\\|")) +} + +func EnumArrayOf(v *[]string, choices ...string) *EnumArray { + choices = append([]string{}, choices...) + return &EnumArray{ + Choices: choices, + Value: v, + } +} From daea69684dec6d20ee129d5a1b9cedce38d885ae Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 13 Aug 2024 10:31:52 +0000 Subject: [PATCH 2/4] idempotent install --- completion.go | 11 ++- completion/all.go | 162 ++++++++++++++++++++++++++++++++------- completion/bash.go | 16 ++-- completion/fish.go | 12 +-- completion/powershell.go | 18 ++--- completion/zsh.go | 16 ++-- completion_test.go | 105 +++++++++++++++++-------- values.go | 46 ++++++++--- 8 files changed, 278 insertions(+), 108 deletions(-) diff --git a/completion.go b/completion.go index 616775b..1da93a0 100644 --- a/completion.go +++ b/completion.go @@ -1,6 +1,8 @@ package serpent -import "strings" +import ( + "github.com/spf13/pflag" +) // CompletionModeEnv is a special environment variable that is // set when the command is being run in completion mode. @@ -12,17 +14,18 @@ func (inv *Invocation) IsCompletionMode() bool { return ok } -// DefaultCompletionHandler returns a handler that prints all -// known flags and subcommands that haven't already been set to valid values. +// DefaultCompletionHandler is a handler that prints all known flags and +// subcommands that haven't been exhaustively set. func DefaultCompletionHandler(inv *Invocation) []string { var allResps []string for _, cmd := range inv.Command.Children { allResps = append(allResps, cmd.Name()) } for _, opt := range inv.Command.Options { + _, isSlice := opt.Value.(pflag.SliceValue) if opt.ValueSource == ValueSourceNone || opt.ValueSource == ValueSourceDefault || - strings.Contains(opt.Value.Type(), "array") { + isSlice { allResps = append(allResps, "--"+opt.Flag) } } diff --git a/completion/all.go b/completion/all.go index 4a48aa8..ae07732 100644 --- a/completion/all.go +++ b/completion/all.go @@ -1,8 +1,11 @@ package completion import ( + "bytes" + "errors" "fmt" "io" + "io/fs" "os" "os/user" "path/filepath" @@ -13,11 +16,16 @@ import ( "github.com/coder/serpent" ) +const ( + completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============` + completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============` +) + type Shell interface { Name() string InstallPath() (string, error) - UsesOwnFile() bool WriteCompletion(io.Writer) error + ProgramName() string } const ( @@ -77,27 +85,27 @@ func DetectUserShell(programName string) (Shell, error) { return nil, fmt.Errorf("default shell not found") } -func generateCompletion( - scriptTemplate string, -) func(io.Writer, string) error { - return func(w io.Writer, programName string) error { - tmpl, err := template.New("script").Parse(scriptTemplate) - if err != nil { - return fmt.Errorf("parse template: %w", err) - } - - err = tmpl.Execute( - w, - map[string]string{ - "Name": programName, - }, - ) - if err != nil { - return fmt.Errorf("execute template: %w", err) - } +func configTemplateWriter( + w io.Writer, + cfgTemplate string, + programName string, +) error { + tmpl, err := template.New("script").Parse(cfgTemplate) + if err != nil { + return fmt.Errorf("parse template: %w", err) + } - return nil + err = tmpl.Execute( + w, + map[string]string{ + "Name": programName, + }, + ) + if err != nil { + return fmt.Errorf("execute template: %w", err) } + + return nil } func InstallShellCompletion(shell Shell) error { @@ -105,28 +113,126 @@ func InstallShellCompletion(shell Shell) error { if err != nil { return fmt.Errorf("get install path: %w", err) } + var headerBuf bytes.Buffer + err = configTemplateWriter(&headerBuf, completionStartTemplate, shell.ProgramName()) + if err != nil { + return fmt.Errorf("generate header: %w", err) + } + + var footerBytes bytes.Buffer + err = configTemplateWriter(&footerBytes, completionEndTemplate, shell.ProgramName()) + if err != nil { + return fmt.Errorf("generate footer: %w", err) + } err = os.MkdirAll(filepath.Dir(path), 0o755) if err != nil { return fmt.Errorf("create directories: %w", err) } - if shell.UsesOwnFile() { - err := os.WriteFile(path, nil, 0o644) + f, err := os.ReadFile(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("read ssh config failed: %w", err) + } + + before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f) + if err != nil { + return err + } + + outBuf := bytes.Buffer{} + _, _ = outBuf.Write(before) + if len(before) > 0 { + _, _ = outBuf.Write([]byte("\n")) + } + _, _ = outBuf.Write(headerBuf.Bytes()) + err = shell.WriteCompletion(&outBuf) + if err != nil { + return fmt.Errorf("generate completion: %w", err) + } + _, _ = outBuf.Write(footerBytes.Bytes()) + _, _ = outBuf.Write([]byte("\n")) + _, _ = outBuf.Write(after) + + err = writeWithTempFileAndMove(path, &outBuf) + if err != nil { + return fmt.Errorf("write completion: %w", err) + } + + return nil +} + +func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) { + startCount := bytes.Count(data, header) + endCount := bytes.Count(data, footer) + if startCount > 1 || endCount > 1 { + return nil, nil, fmt.Errorf("Malformed config file: multiple config sections") + } + + startIndex := bytes.Index(data, header) + endIndex := bytes.Index(data, footer) + if startIndex == -1 && endIndex != -1 { + return data, nil, fmt.Errorf("Malformed config file: missing completion header") + } + if startIndex != -1 && endIndex == -1 { + return data, nil, fmt.Errorf("Malformed config file: missing completion footer") + } + if startIndex != -1 && endIndex != -1 { + if startIndex > endIndex { + return data, nil, fmt.Errorf("Malformed config file: completion header after footer") + } + // Include leading and trailing newline, if present + start := startIndex + if start > 0 { + start-- + } + end := endIndex + len(footer) + if end < len(data) { + end++ + } + return data[:start], data[end:], nil + } + return data, nil, nil +} + +// writeWithTempFileAndMove writes to a temporary file in the same +// directory as path and renames the temp file to the file provided in +// path. This ensure we avoid trashing the file we are writing due to +// unforeseen circumstance like filesystem full, command killed, etc. +func writeWithTempFileAndMove(path string, r io.Reader) (err error) { + dir := filepath.Dir(path) + name := filepath.Base(path) + + if err = os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("create directory: %w", err) + } + + // Create a tempfile in the same directory for ensuring write + // operation does not fail. + f, err := os.CreateTemp(dir, fmt.Sprintf(".%s.", name)) + if err != nil { + return fmt.Errorf("create temp file failed: %w", err) + } + defer func() { if err != nil { - return fmt.Errorf("create file: %w", err) + _ = os.Remove(f.Name()) // Cleanup in case a step failed. } + }() + + _, err = io.Copy(f, r) + if err != nil { + _ = f.Close() + return fmt.Errorf("write temp file failed: %w", err) } - f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + err = f.Close() if err != nil { - return fmt.Errorf("open file for appending: %w", err) + return fmt.Errorf("close temp file failed: %w", err) } - defer f.Close() - err = shell.WriteCompletion(f) + err = os.Rename(f.Name(), path) if err != nil { - return fmt.Errorf("write completion script: %w", err) + return fmt.Errorf("rename temp file failed: %w", err) } return nil diff --git a/completion/bash.go b/completion/bash.go index 8b0d219..47cf158 100644 --- a/completion/bash.go +++ b/completion/bash.go @@ -23,11 +23,6 @@ func (b *bash) Name() string { return "bash" } -// UsesOwnFile implements Shell. -func (b *bash) UsesOwnFile() bool { - return false -} - // InstallPath implements Shell. func (b *bash) InstallPath() (string, error) { homeDir, err := home.Dir() @@ -42,12 +37,15 @@ func (b *bash) InstallPath() (string, error) { // WriteCompletion implements Shell. func (b *bash) WriteCompletion(w io.Writer) error { - return generateCompletion(bashCompletionTemplate)(w, b.programName) + return configTemplateWriter(w, bashCompletionTemplate, b.programName) } -const bashCompletionTemplate = ` +// ProgramName implements Shell. +func (b *bash) ProgramName() string { + return b.programName +} -# === BEGIN {{.Name}} COMPLETION === +const bashCompletionTemplate = ` _generate_{{.Name}}_completions() { # Capture the line excluding the command, and everything after the current word local args=("${COMP_WORDS[@]:1:COMP_CWORD}") @@ -65,6 +63,4 @@ _generate_{{.Name}}_completions() { } # Setup Bash to use the function for completions for '{{.Name}}' complete -F _generate_{{.Name}}_completions {{.Name}} -# === END {{.Name}} COMPLETION === - ` diff --git a/completion/fish.go b/completion/fish.go index 037b264..dbff929 100644 --- a/completion/fish.go +++ b/completion/fish.go @@ -18,11 +18,6 @@ func Fish(goos string, programName string) Shell { return &fish{goos: goos, programName: programName} } -// UsesOwnFile implements Shell. -func (f *fish) UsesOwnFile() bool { - return true -} - // Name implements Shell. func (f *fish) Name() string { return "fish" @@ -39,7 +34,12 @@ func (f *fish) InstallPath() (string, error) { // WriteCompletion implements Shell. func (f *fish) WriteCompletion(w io.Writer) error { - return generateCompletion(fishCompletionTemplate)(w, f.programName) + return configTemplateWriter(w, fishCompletionTemplate, f.programName) +} + +// ProgramName implements Shell. +func (f *fish) ProgramName() string { + return f.programName } const fishCompletionTemplate = ` diff --git a/completion/powershell.go b/completion/powershell.go index 1114fbb..f9e7133 100644 --- a/completion/powershell.go +++ b/completion/powershell.go @@ -11,6 +11,8 @@ type powershell struct { programName string } +var _ Shell = &powershell{} + // Name implements Shell. func (p *powershell) Name() string { return "powershell" @@ -20,11 +22,6 @@ func Powershell(goos string, programName string) Shell { return &powershell{goos: goos, programName: programName} } -// UsesOwnFile implements Shell. -func (p *powershell) UsesOwnFile() bool { - return false -} - // InstallPath implements Shell. func (p *powershell) InstallPath() (string, error) { var ( @@ -45,14 +42,15 @@ func (p *powershell) InstallPath() (string, error) { // WriteCompletion implements Shell. func (p *powershell) WriteCompletion(w io.Writer) error { - return generateCompletion(pshCompletionTemplate)(w, p.programName) + return configTemplateWriter(w, pshCompletionTemplate, p.programName) } -var _ Shell = &powershell{} +// ProgramName implements Shell. +func (p *powershell) ProgramName() string { + return p.programName +} const pshCompletionTemplate = ` - -# === BEGIN {{.Name}} COMPLETION === # Escaping output sourced from: # https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47 filter _{{.Name}}_escapeStringWithSpecialChars { @@ -89,6 +87,4 @@ $_{{.Name}}_completions = { rm env:COMPLETION_MODE } Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions -# === END {{.Name}} COMPLETION === - ` diff --git a/completion/zsh.go b/completion/zsh.go index af6b553..5e0ddd7 100644 --- a/completion/zsh.go +++ b/completion/zsh.go @@ -23,11 +23,6 @@ func (z *zsh) Name() string { return "zsh" } -// UsesOwnFile implements Shell. -func (z *zsh) UsesOwnFile() bool { - return false -} - // InstallPath implements Shell. func (z *zsh) InstallPath() (string, error) { homeDir, err := home.Dir() @@ -39,12 +34,15 @@ func (z *zsh) InstallPath() (string, error) { // WriteCompletion implements Shell. func (z *zsh) WriteCompletion(w io.Writer) error { - return generateCompletion(zshCompletionTemplate)(w, z.programName) + return configTemplateWriter(w, zshCompletionTemplate, z.programName) } -const zshCompletionTemplate = ` +// ProgramName implements Shell. +func (z *zsh) ProgramName() string { + return z.programName +} -# === BEGIN {{.Name}} COMPLETION === +const zshCompletionTemplate = ` _{{.Name}}_completions() { local -a args completions args=("${words[@]:1:$#words}") @@ -52,6 +50,4 @@ _{{.Name}}_completions() { compadd -a completions } compdef _{{.Name}}_completions {{.Name}} -# === END {{.Name}} COMPLETION === - ` diff --git a/completion_test.go b/completion_test.go index 1ac925c..d90a42a 100644 --- a/completion_test.go +++ b/completion_test.go @@ -238,42 +238,92 @@ func TestFileCompletion(t *testing.T) { func TestCompletionInstall(t *testing.T) { t.Parallel() - t.Run("InstallingAppend", func(t *testing.T) { + t.Run("InstallingNew", func(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "fake.sh") - f, err := os.Create(path) - require.NoError(t, err) - f.Write([]byte("FAKE_SCRIPT")) - f.Close() + shell := &fakeShell{baseInstallDir: dir, programName: "fake"} - shell := &fakeShell{baseInstallDir: dir, useOwn: false} - err = completion.InstallShellCompletion(shell) + err := completion.InstallShellCompletion(shell) require.NoError(t, err) contents, err := os.ReadFile(path) require.NoError(t, err) - require.Equal(t, "FAKE_SCRIPTFAKE_COMPLETION", string(contents)) + require.Equal(t, "# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n", string(contents)) }) - t.Run("InstallReplace", func(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "fake.sh") - f, err := os.Create(path) - require.NoError(t, err) - f.Write([]byte("FAKE_SCRIPT")) - f.Close() + cases := []struct { + name string + input []byte + expected []byte + errMsg string + }{ + { + name: "InstallingAppend", + input: []byte("FAKE_SCRIPT"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + }, + { + name: "InstallReplaceBeginning", + input: []byte("# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + expected: []byte("# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + }, + { + name: "InstallReplaceMiddle", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + }, + { + name: "InstallReplaceEnd", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + }, + { + name: "InstallNoFooter", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n"), + errMsg: "missing completion footer", + }, + { + name: "InstallNoHeader", + input: []byte("OLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + errMsg: "missing completion header", + }, + { + name: "InstallBadOrder", + input: []byte("# ============ END fake COMPLETION ==============\nFAKE_COMPLETION\n# ============ BEGIN fake COMPLETION =============="), + errMsg: "header after footer", + }, + } - shell := &fakeShell{baseInstallDir: dir, useOwn: true} - err = completion.InstallShellCompletion(shell) - require.NoError(t, err) - contents, err := os.ReadFile(path) - require.NoError(t, err) - require.Equal(t, "FAKE_COMPLETION", string(contents)) - }) + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "fake.sh") + err := os.WriteFile(path, tc.input, 0o644) + require.NoError(t, err) + + shell := &fakeShell{baseInstallDir: dir, programName: "fake"} + err = completion.InstallShellCompletion(shell) + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } else { + require.NoError(t, err) + contents, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, tc.expected, contents) + } + }) + } } type fakeShell struct { baseInstallDir string - useOwn bool + programName string +} + +// ProgramName implements completion.Shell. +func (f *fakeShell) ProgramName() string { + return f.programName } var _ completion.Shell = &fakeShell{} @@ -285,16 +335,11 @@ func (f *fakeShell) InstallPath() (string, error) { // Name implements completion.Shell. func (f *fakeShell) Name() string { - return "fake" -} - -// UsesOwnFile implements completion.Shell. -func (f *fakeShell) UsesOwnFile() bool { - return f.useOwn + return "Fake" } // WriteCompletion implements completion.Shell. func (f *fakeShell) WriteCompletion(w io.Writer) error { - _, err := w.Write([]byte("FAKE_COMPLETION")) + _, err := w.Write([]byte("\nFAKE_COMPLETION\n")) return err } diff --git a/values.go b/values.go index 9c0ed83..7aaaeec 100644 --- a/values.go +++ b/values.go @@ -191,7 +191,10 @@ func (String) Type() string { return "string" } -var _ pflag.SliceValue = &StringArray{} +var ( + _ pflag.SliceValue = &StringArray{} + _ pflag.Value = &StringArray{} +) // StringArray is a slice of strings that implements pflag.Value and pflag.SliceValue. type StringArray []string @@ -629,6 +632,7 @@ func (*YAMLConfigPath) Type() string { return "yaml-config-path" } +var _ pflag.SliceValue = (*EnumArray)(nil) var _ pflag.Value = (*EnumArray)(nil) type EnumArray struct { @@ -636,15 +640,22 @@ type EnumArray struct { Value *[]string } -func (e *EnumArray) Set(v string) error { - if v == "" { - *e.Value = nil - return nil - } - ss, err := readAsCSV(v) - if err != nil { - return err +// Append implements pflag.SliceValue. +func (e *EnumArray) Append(s string) error { + for _, c := range e.Choices { + if s == c { + *e.Value = append(*e.Value, s) + return nil + } } + return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) +} + +func (e *EnumArray) GetSlice() []string { + return *e.Value +} + +func (e *EnumArray) Replace(ss []string) error { for _, s := range ss { found := false for _, c := range e.Choices { @@ -657,6 +668,23 @@ func (e *EnumArray) Set(v string) error { return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) } } + *e.Value = ss + return nil +} + +func (e *EnumArray) Set(v string) error { + if v == "" { + *e.Value = nil + return nil + } + ss, err := readAsCSV(v) + if err != nil { + return err + } + err = e.Replace(ss) + if err != nil { + return err + } *e.Value = append(*e.Value, ss...) return nil } From 64b9c1ee84a8474261c235d204cdd7a95f53f7a5 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Wed, 14 Aug 2024 04:35:49 +0000 Subject: [PATCH 3/4] add natefinch/atomic --- completion/all.go | 53 +++++----------------------------------- completion/bash.go | 6 +---- completion/fish.go | 6 +---- completion/powershell.go | 6 +---- completion/zsh.go | 6 +---- completion_test.go | 4 --- go.mod | 1 + go.sum | 2 ++ values.go | 1 - 9 files changed, 13 insertions(+), 72 deletions(-) diff --git a/completion/all.go b/completion/all.go index ae07732..ca6e2cf 100644 --- a/completion/all.go +++ b/completion/all.go @@ -14,6 +14,8 @@ import ( "text/template" "github.com/coder/serpent" + + "github.com/natefinch/atomic" ) const ( @@ -85,7 +87,7 @@ func DetectUserShell(programName string) (Shell, error) { return nil, fmt.Errorf("default shell not found") } -func configTemplateWriter( +func writeConfig( w io.Writer, cfgTemplate string, programName string, @@ -114,13 +116,13 @@ func InstallShellCompletion(shell Shell) error { return fmt.Errorf("get install path: %w", err) } var headerBuf bytes.Buffer - err = configTemplateWriter(&headerBuf, completionStartTemplate, shell.ProgramName()) + err = writeConfig(&headerBuf, completionStartTemplate, shell.ProgramName()) if err != nil { return fmt.Errorf("generate header: %w", err) } var footerBytes bytes.Buffer - err = configTemplateWriter(&footerBytes, completionEndTemplate, shell.ProgramName()) + err = writeConfig(&footerBytes, completionEndTemplate, shell.ProgramName()) if err != nil { return fmt.Errorf("generate footer: %w", err) } @@ -154,7 +156,7 @@ func InstallShellCompletion(shell Shell) error { _, _ = outBuf.Write([]byte("\n")) _, _ = outBuf.Write(after) - err = writeWithTempFileAndMove(path, &outBuf) + err = atomic.WriteFile(path, &outBuf) if err != nil { return fmt.Errorf("write completion: %w", err) } @@ -194,46 +196,3 @@ func templateConfigSplit(header, footer, data []byte) (before, after []byte, err } return data, nil, nil } - -// writeWithTempFileAndMove writes to a temporary file in the same -// directory as path and renames the temp file to the file provided in -// path. This ensure we avoid trashing the file we are writing due to -// unforeseen circumstance like filesystem full, command killed, etc. -func writeWithTempFileAndMove(path string, r io.Reader) (err error) { - dir := filepath.Dir(path) - name := filepath.Base(path) - - if err = os.MkdirAll(dir, 0o700); err != nil { - return fmt.Errorf("create directory: %w", err) - } - - // Create a tempfile in the same directory for ensuring write - // operation does not fail. - f, err := os.CreateTemp(dir, fmt.Sprintf(".%s.", name)) - if err != nil { - return fmt.Errorf("create temp file failed: %w", err) - } - defer func() { - if err != nil { - _ = os.Remove(f.Name()) // Cleanup in case a step failed. - } - }() - - _, err = io.Copy(f, r) - if err != nil { - _ = f.Close() - return fmt.Errorf("write temp file failed: %w", err) - } - - err = f.Close() - if err != nil { - return fmt.Errorf("close temp file failed: %w", err) - } - - err = os.Rename(f.Name(), path) - if err != nil { - return fmt.Errorf("rename temp file failed: %w", err) - } - - return nil -} diff --git a/completion/bash.go b/completion/bash.go index 47cf158..8e3a1b1 100644 --- a/completion/bash.go +++ b/completion/bash.go @@ -18,12 +18,10 @@ func Bash(goos string, programName string) Shell { return &bash{goos: goos, programName: programName} } -// Name implements Shell. func (b *bash) Name() string { return "bash" } -// InstallPath implements Shell. func (b *bash) InstallPath() (string, error) { homeDir, err := home.Dir() if err != nil { @@ -35,12 +33,10 @@ func (b *bash) InstallPath() (string, error) { return filepath.Join(homeDir, ".bashrc"), nil } -// WriteCompletion implements Shell. func (b *bash) WriteCompletion(w io.Writer) error { - return configTemplateWriter(w, bashCompletionTemplate, b.programName) + return writeConfig(w, bashCompletionTemplate, b.programName) } -// ProgramName implements Shell. func (b *bash) ProgramName() string { return b.programName } diff --git a/completion/fish.go b/completion/fish.go index dbff929..7e5a21e 100644 --- a/completion/fish.go +++ b/completion/fish.go @@ -18,12 +18,10 @@ func Fish(goos string, programName string) Shell { return &fish{goos: goos, programName: programName} } -// Name implements Shell. func (f *fish) Name() string { return "fish" } -// InstallPath implements Shell. func (f *fish) InstallPath() (string, error) { homeDir, err := home.Dir() if err != nil { @@ -32,12 +30,10 @@ func (f *fish) InstallPath() (string, error) { return filepath.Join(homeDir, ".config/fish/completions/", f.programName+".fish"), nil } -// WriteCompletion implements Shell. func (f *fish) WriteCompletion(w io.Writer) error { - return configTemplateWriter(w, fishCompletionTemplate, f.programName) + return writeConfig(w, fishCompletionTemplate, f.programName) } -// ProgramName implements Shell. func (f *fish) ProgramName() string { return f.programName } diff --git a/completion/powershell.go b/completion/powershell.go index f9e7133..b002065 100644 --- a/completion/powershell.go +++ b/completion/powershell.go @@ -13,7 +13,6 @@ type powershell struct { var _ Shell = &powershell{} -// Name implements Shell. func (p *powershell) Name() string { return "powershell" } @@ -22,7 +21,6 @@ func Powershell(goos string, programName string) Shell { return &powershell{goos: goos, programName: programName} } -// InstallPath implements Shell. func (p *powershell) InstallPath() (string, error) { var ( path []byte @@ -40,12 +38,10 @@ func (p *powershell) InstallPath() (string, error) { return strings.TrimSpace(string(path)), nil } -// WriteCompletion implements Shell. func (p *powershell) WriteCompletion(w io.Writer) error { - return configTemplateWriter(w, pshCompletionTemplate, p.programName) + return writeConfig(w, pshCompletionTemplate, p.programName) } -// ProgramName implements Shell. func (p *powershell) ProgramName() string { return p.programName } diff --git a/completion/zsh.go b/completion/zsh.go index 5e0ddd7..831dc65 100644 --- a/completion/zsh.go +++ b/completion/zsh.go @@ -18,12 +18,10 @@ func Zsh(goos string, programName string) Shell { return &zsh{goos: goos, programName: programName} } -// Name implements Shell. func (z *zsh) Name() string { return "zsh" } -// InstallPath implements Shell. func (z *zsh) InstallPath() (string, error) { homeDir, err := home.Dir() if err != nil { @@ -32,12 +30,10 @@ func (z *zsh) InstallPath() (string, error) { return filepath.Join(homeDir, ".zshrc"), nil } -// WriteCompletion implements Shell. func (z *zsh) WriteCompletion(w io.Writer) error { - return configTemplateWriter(w, zshCompletionTemplate, z.programName) + return writeConfig(w, zshCompletionTemplate, z.programName) } -// ProgramName implements Shell. func (z *zsh) ProgramName() string { return z.programName } diff --git a/completion_test.go b/completion_test.go index d90a42a..f1527be 100644 --- a/completion_test.go +++ b/completion_test.go @@ -321,24 +321,20 @@ type fakeShell struct { programName string } -// ProgramName implements completion.Shell. func (f *fakeShell) ProgramName() string { return f.programName } var _ completion.Shell = &fakeShell{} -// InstallPath implements completion.Shell. func (f *fakeShell) InstallPath() (string, error) { return filepath.Join(f.baseInstallDir, "fake.sh"), nil } -// Name implements completion.Shell. func (f *fakeShell) Name() string { return "Fake" } -// WriteCompletion implements completion.Shell. func (f *fakeShell) WriteCompletion(w io.Writer) error { _, err := w.Write([]byte("\nFAKE_COMPLETION\n")) return err diff --git a/go.mod b/go.mod index e70ab99..1c2880c 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-wordwrap v1.0.1 github.com/muesli/termenv v0.15.2 + github.com/natefinch/atomic v1.0.1 github.com/pion/udp v0.1.4 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 011ecb6..a1106fc 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= +github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= +github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/transport/v2 v2.0.0 h1:bsMYyqHCbkvHwj+eNCFBuxtlKndKfyGI2vaQmM3fIE4= github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= diff --git a/values.go b/values.go index 7aaaeec..0cd1993 100644 --- a/values.go +++ b/values.go @@ -640,7 +640,6 @@ type EnumArray struct { Value *[]string } -// Append implements pflag.SliceValue. func (e *EnumArray) Append(s string) error { for _, c := range e.Choices { if s == c { From 2b6ea89c9e01f390860e41d6e28d2c02dea9c632 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Thu, 15 Aug 2024 05:50:20 +0000 Subject: [PATCH 4/4] handle completions with spaces + case insensitive enums --- completion/bash.go | 18 +++++++++--------- completion/powershell.go | 2 +- completion/zsh.go | 2 +- option.go | 2 +- values.go | 15 ++++++++------- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/completion/bash.go b/completion/bash.go index 8e3a1b1..14282eb 100644 --- a/completion/bash.go +++ b/completion/bash.go @@ -43,19 +43,19 @@ func (b *bash) ProgramName() string { const bashCompletionTemplate = ` _generate_{{.Name}}_completions() { - # Capture the line excluding the command, and everything after the current word local args=("${COMP_WORDS[@]:1:COMP_CWORD}") - # Set COMPLETION_MODE and call the command with the arguments, capturing the output - local completions=$(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}") + declare -a output + mapfile -t output < <(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}") - # Use the command's output to generate completions for the current word - COMPREPLY=($(compgen -W "$completions" -- "${COMP_WORDS[COMP_CWORD]}")) + declare -a completions + mapfile -t completions < <( compgen -W "$(printf '%q ' "${output[@]}")" -- "$2" ) - # Ensure no files are shown, even if there are no matches - if [ ${#COMPREPLY[@]} -eq 0 ]; then - COMPREPLY=() - fi + local comp + COMPREPLY=() + for comp in "${completions[@]}"; do + COMPREPLY+=("$(printf "%q" "$comp")") + done } # Setup Bash to use the function for completions for '{{.Name}}' complete -F _generate_{{.Name}}_completions {{.Name}} diff --git a/completion/powershell.go b/completion/powershell.go index b002065..cc083bd 100644 --- a/completion/powershell.go +++ b/completion/powershell.go @@ -80,7 +80,7 @@ $_{{.Name}}_completions = { Invoke-Expression $Command | Where-Object { $_ -like "$wordToComplete*" } | ForEach-Object { "$_" | _{{.Name}}_escapeStringWithSpecialChars } - rm env:COMPLETION_MODE + $env:COMPLETION_MODE = '' } Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions ` diff --git a/completion/zsh.go b/completion/zsh.go index 831dc65..b2793b0 100644 --- a/completion/zsh.go +++ b/completion/zsh.go @@ -42,7 +42,7 @@ const zshCompletionTemplate = ` _{{.Name}}_completions() { local -a args completions args=("${words[@]:1:$#words}") - completions=($(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")) + completions=(${(f)"$(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")"}) compadd -a completions } compdef _{{.Name}}_completions {{.Name}} diff --git a/option.go b/option.go index fccc67e..2780fc6 100644 --- a/option.go +++ b/option.go @@ -352,7 +352,7 @@ func (optSet OptionSet) ByFlag(flag string) *Option { } for i := range optSet { opt := &optSet[i] - if opt.Flag == flag || opt.FlagShorthand == flag { + if opt.Flag == flag { return opt } } diff --git a/values.go b/values.go index 0cd1993..79c8e2c 100644 --- a/values.go +++ b/values.go @@ -530,7 +530,7 @@ func EnumOf(v *string, choices ...string) *Enum { func (e *Enum) Set(v string) error { for _, c := range e.Choices { - if v == c { + if strings.EqualFold(v, c) { *e.Value = v return nil } @@ -642,7 +642,7 @@ type EnumArray struct { func (e *EnumArray) Append(s string) error { for _, c := range e.Choices { - if s == c { + if strings.EqualFold(s, c) { *e.Value = append(*e.Value, s) return nil } @@ -658,7 +658,7 @@ func (e *EnumArray) Replace(ss []string) error { for _, s := range ss { found := false for _, c := range e.Choices { - if s == c { + if strings.EqualFold(s, c) { found = true break } @@ -680,11 +680,12 @@ func (e *EnumArray) Set(v string) error { if err != nil { return err } - err = e.Replace(ss) - if err != nil { - return err + for _, s := range ss { + err := e.Append(s) + if err != nil { + return err + } } - *e.Value = append(*e.Value, ss...) return nil }