diff --git a/command.go b/command.go index 542dfff..9d10dfe 100644 --- a/command.go +++ b/command.go @@ -296,10 +296,16 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { } func (inv *Invocation) CurWords() (prev string, cur string) { - if len(inv.Args) == 1 { + switch len(inv.Args) { + // All the shells we support will supply at least one argument (empty string), + // but we don't want to panic. + case 0: + cur = "" + prev = "" + case 1: cur = inv.Args[0] prev = "" - } else { + default: cur = inv.Args[len(inv.Args)-1] prev = inv.Args[len(inv.Args)-2] } @@ -645,9 +651,13 @@ func (inv *Invocation) completeFlag(word string) []string { if opt.CompletionHandler != nil { return opt.CompletionHandler(inv) } - val, ok := opt.Value.(*Enum) + enum, ok := opt.Value.(*Enum) + if ok { + return enum.Choices + } + enumArr, ok := opt.Value.(*EnumArray) if ok { - return val.Choices + return enumArr.Choices } return nil } diff --git a/command_test.go b/command_test.go index de6c12d..41541f4 100644 --- a/command_test.go +++ b/command_test.go @@ -34,14 +34,15 @@ func fakeIO(i *serpent.Invocation) *ioBufs { func sampleCommand(t *testing.T) *serpent.Command { t.Helper() var ( - verbose bool - lower bool - prefix string - reqBool bool - reqStr string - reqArr []string - fileArr []string - enumStr string + verbose bool + lower bool + prefix string + reqBool bool + reqStr string + reqArr []string + reqEnumArr []string + fileArr []string + enumStr string ) enumChoices := []string{"foo", "bar", "qux"} return &serpent.Command{ @@ -94,6 +95,11 @@ func sampleCommand(t *testing.T) *serpent.Command { FlagShorthand: "a", Value: serpent.StringArrayOf(&reqArr), }, + serpent.Option{ + Name: "req-enum-array", + Flag: "req-enum-array", + Value: serpent.EnumArrayOf(&reqEnumArr, enumChoices...), + }, }, HelpHandler: func(i *serpent.Invocation) error { _, _ = i.Stdout.Write([]byte("help text.png")) diff --git a/completion.go b/completion.go index a0fb779..1da93a0 100644 --- a/completion.go +++ b/completion.go @@ -1,5 +1,9 @@ package serpent +import ( + "github.com/spf13/pflag" +) + // CompletionModeEnv is a special environment variable that is // set when the command is being run in completion mode. const CompletionModeEnv = "COMPLETION_MODE" @@ -10,15 +14,18 @@ func (inv *Invocation) IsCompletionMode() bool { return ok } -// DefaultCompletionHandler returns a handler that prints all -// known flags and subcommands that haven't already been set to valid values. +// DefaultCompletionHandler is a handler that prints all known flags and +// subcommands that haven't been exhaustively set. func DefaultCompletionHandler(inv *Invocation) []string { var allResps []string for _, cmd := range inv.Command.Children { allResps = append(allResps, cmd.Name()) } for _, opt := range inv.Command.Options { - if opt.ValueSource == ValueSourceNone || opt.ValueSource == ValueSourceDefault || opt.Value.Type() == "string-array" { + _, isSlice := opt.Value.(pflag.SliceValue) + if opt.ValueSource == ValueSourceNone || + opt.ValueSource == ValueSourceDefault || + isSlice { allResps = append(allResps, "--"+opt.Flag) } } diff --git a/completion/all.go b/completion/all.go index b20c254..ca6e2cf 100644 --- a/completion/all.go +++ b/completion/all.go @@ -1,60 +1,77 @@ package completion import ( + "bytes" + "errors" "fmt" "io" + "io/fs" "os" "os/user" "path/filepath" + "runtime" "strings" "text/template" "github.com/coder/serpent" + + "github.com/natefinch/atomic" ) const ( - BashShell string = "bash" - FishShell string = "fish" - ZShell string = "zsh" - Powershell string = "powershell" + completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============` + completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============` ) -var shellCompletionByName = map[string]func(io.Writer, string) error{ - BashShell: generateCompletion(bashCompletionTemplate), - FishShell: generateCompletion(fishCompletionTemplate), - ZShell: generateCompletion(zshCompletionTemplate), - Powershell: generateCompletion(pshCompletionTemplate), +type Shell interface { + Name() string + InstallPath() (string, error) + WriteCompletion(io.Writer) error + ProgramName() string } -func ShellOptions(choice *string) *serpent.Enum { - return serpent.EnumOf(choice, BashShell, FishShell, ZShell, Powershell) -} +const ( + ShellBash string = "bash" + ShellFish string = "fish" + ShellZsh string = "zsh" + ShellPowershell string = "powershell" +) -func WriteCompletion(writer io.Writer, shell string, cmdName string) error { - fn, ok := shellCompletionByName[shell] - if !ok { - return fmt.Errorf("unknown shell %q", shell) +func ShellByName(shell, programName string) (Shell, error) { + switch shell { + case ShellBash: + return Bash(runtime.GOOS, programName), nil + case ShellFish: + return Fish(runtime.GOOS, programName), nil + case ShellZsh: + return Zsh(runtime.GOOS, programName), nil + case ShellPowershell: + return Powershell(runtime.GOOS, programName), nil + default: + return nil, fmt.Errorf("unsupported shell %q", shell) } - fn(writer, cmdName) - return nil } -func DetectUserShell() (string, error) { +func ShellOptions(choice *string) *serpent.Enum { + return serpent.EnumOf(choice, ShellBash, ShellFish, ShellZsh, ShellPowershell) +} + +func DetectUserShell(programName string) (Shell, error) { // Attempt to get the SHELL environment variable first if shell := os.Getenv("SHELL"); shell != "" { - return filepath.Base(shell), nil + return ShellByName(filepath.Base(shell), "") } // Fallback: Look up the current user and parse /etc/passwd currentUser, err := user.Current() if err != nil { - return "", err + return nil, err } // Open and parse /etc/passwd passwdFile, err := os.ReadFile("/etc/passwd") if err != nil { - return "", err + return nil, err } lines := strings.Split(string(passwdFile), "\n") @@ -62,33 +79,120 @@ func DetectUserShell() (string, error) { if strings.HasPrefix(line, currentUser.Username+":") { parts := strings.Split(line, ":") if len(parts) > 6 { - return filepath.Base(parts[6]), nil // The shell is typically the 7th field + return ShellByName(filepath.Base(parts[6]), programName) // The shell is typically the 7th field } } } - return "", fmt.Errorf("default shell not found") + return nil, fmt.Errorf("default shell not found") } -func generateCompletion( - scriptTemplate string, -) func(io.Writer, string) error { - return func(w io.Writer, rootCmdName string) error { - tmpl, err := template.New("script").Parse(scriptTemplate) - if err != nil { - return fmt.Errorf("parse template: %w", err) - } +func writeConfig( + w io.Writer, + cfgTemplate string, + programName string, +) error { + tmpl, err := template.New("script").Parse(cfgTemplate) + if err != nil { + return fmt.Errorf("parse template: %w", err) + } - err = tmpl.Execute( - w, - map[string]string{ - "Name": rootCmdName, - }, - ) - if err != nil { - return fmt.Errorf("execute template: %w", err) - } + err = tmpl.Execute( + w, + map[string]string{ + "Name": programName, + }, + ) + if err != nil { + return fmt.Errorf("execute template: %w", err) + } + + return nil +} + +func InstallShellCompletion(shell Shell) error { + path, err := shell.InstallPath() + if err != nil { + return fmt.Errorf("get install path: %w", err) + } + var headerBuf bytes.Buffer + err = writeConfig(&headerBuf, completionStartTemplate, shell.ProgramName()) + if err != nil { + return fmt.Errorf("generate header: %w", err) + } + + var footerBytes bytes.Buffer + err = writeConfig(&footerBytes, completionEndTemplate, shell.ProgramName()) + if err != nil { + return fmt.Errorf("generate footer: %w", err) + } - return nil + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + return fmt.Errorf("create directories: %w", err) + } + + f, err := os.ReadFile(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("read ssh config failed: %w", err) + } + + before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f) + if err != nil { + return err + } + + outBuf := bytes.Buffer{} + _, _ = outBuf.Write(before) + if len(before) > 0 { + _, _ = outBuf.Write([]byte("\n")) + } + _, _ = outBuf.Write(headerBuf.Bytes()) + err = shell.WriteCompletion(&outBuf) + if err != nil { + return fmt.Errorf("generate completion: %w", err) + } + _, _ = outBuf.Write(footerBytes.Bytes()) + _, _ = outBuf.Write([]byte("\n")) + _, _ = outBuf.Write(after) + + err = atomic.WriteFile(path, &outBuf) + if err != nil { + return fmt.Errorf("write completion: %w", err) + } + + return nil +} + +func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) { + startCount := bytes.Count(data, header) + endCount := bytes.Count(data, footer) + if startCount > 1 || endCount > 1 { + return nil, nil, fmt.Errorf("Malformed config file: multiple config sections") + } + + startIndex := bytes.Index(data, header) + endIndex := bytes.Index(data, footer) + if startIndex == -1 && endIndex != -1 { + return data, nil, fmt.Errorf("Malformed config file: missing completion header") + } + if startIndex != -1 && endIndex == -1 { + return data, nil, fmt.Errorf("Malformed config file: missing completion footer") + } + if startIndex != -1 && endIndex != -1 { + if startIndex > endIndex { + return data, nil, fmt.Errorf("Malformed config file: completion header after footer") + } + // Include leading and trailing newline, if present + start := startIndex + if start > 0 { + start-- + } + end := endIndex + len(footer) + if end < len(data) { + end++ + } + return data[:start], data[end:], nil } + return data, nil, nil } diff --git a/completion/bash.go b/completion/bash.go index fad4069..14282eb 100644 --- a/completion/bash.go +++ b/completion/bash.go @@ -1,22 +1,62 @@ package completion +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type bash struct { + goos string + programName string +} + +var _ Shell = &bash{} + +func Bash(goos string, programName string) Shell { + return &bash{goos: goos, programName: programName} +} + +func (b *bash) Name() string { + return "bash" +} + +func (b *bash) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + if b.goos == "darwin" { + return filepath.Join(homeDir, ".bash_profile"), nil + } + return filepath.Join(homeDir, ".bashrc"), nil +} + +func (b *bash) WriteCompletion(w io.Writer) error { + return writeConfig(w, bashCompletionTemplate, b.programName) +} + +func (b *bash) ProgramName() string { + return b.programName +} + const bashCompletionTemplate = ` _generate_{{.Name}}_completions() { - # Capture the line excluding the command, and everything after the current word local args=("${COMP_WORDS[@]:1:COMP_CWORD}") - # Set COMPLETION_MODE and call the command with the arguments, capturing the output - local completions=$(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}") + declare -a output + mapfile -t output < <(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}") - # Use the command's output to generate completions for the current word - COMPREPLY=($(compgen -W "$completions" -- "${COMP_WORDS[COMP_CWORD]}")) + declare -a completions + mapfile -t completions < <( compgen -W "$(printf '%q ' "${output[@]}")" -- "$2" ) - # Ensure no files are shown, even if there are no matches - if [ ${#COMPREPLY[@]} -eq 0 ]; then - COMPREPLY=() - fi + local comp + COMPREPLY=() + for comp in "${completions[@]}"; do + COMPREPLY+=("$(printf "%q" "$comp")") + done } - # Setup Bash to use the function for completions for '{{.Name}}' complete -F _generate_{{.Name}}_completions {{.Name}} ` diff --git a/completion/fish.go b/completion/fish.go index f9b2793..7e5a21e 100644 --- a/completion/fish.go +++ b/completion/fish.go @@ -1,5 +1,43 @@ package completion +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type fish struct { + goos string + programName string +} + +var _ Shell = &fish{} + +func Fish(goos string, programName string) Shell { + return &fish{goos: goos, programName: programName} +} + +func (f *fish) Name() string { + return "fish" +} + +func (f *fish) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".config/fish/completions/", f.programName+".fish"), nil +} + +func (f *fish) WriteCompletion(w io.Writer) error { + return writeConfig(w, fishCompletionTemplate, f.programName) +} + +func (f *fish) ProgramName() string { + return f.programName +} + const fishCompletionTemplate = ` function _{{.Name}}_completions # Capture the full command line as an array diff --git a/completion/powershell.go b/completion/powershell.go index e30c61b..cc083bd 100644 --- a/completion/powershell.go +++ b/completion/powershell.go @@ -1,7 +1,52 @@ package completion -const pshCompletionTemplate = ` +import ( + "io" + "os/exec" + "strings" +) + +type powershell struct { + goos string + programName string +} + +var _ Shell = &powershell{} + +func (p *powershell) Name() string { + return "powershell" +} + +func Powershell(goos string, programName string) Shell { + return &powershell{goos: goos, programName: programName} +} +func (p *powershell) InstallPath() (string, error) { + var ( + path []byte + err error + ) + cmd := "$PROFILE.CurrentUserAllHosts" + if p.goos == "windows" { + path, err = exec.Command("powershell", cmd).CombinedOutput() + } else { + path, err = exec.Command("pwsh", "-Command", cmd).CombinedOutput() + } + if err != nil { + return "", err + } + return strings.TrimSpace(string(path)), nil +} + +func (p *powershell) WriteCompletion(w io.Writer) error { + return writeConfig(w, pshCompletionTemplate, p.programName) +} + +func (p *powershell) ProgramName() string { + return p.programName +} + +const pshCompletionTemplate = ` # Escaping output sourced from: # https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47 filter _{{.Name}}_escapeStringWithSpecialChars { @@ -35,8 +80,7 @@ $_{{.Name}}_completions = { Invoke-Expression $Command | Where-Object { $_ -like "$wordToComplete*" } | ForEach-Object { "$_" | _{{.Name}}_escapeStringWithSpecialChars } - rm env:COMPLETION_MODE + $env:COMPLETION_MODE = '' } - Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions ` diff --git a/completion/zsh.go b/completion/zsh.go index a8ee4a8..b2793b0 100644 --- a/completion/zsh.go +++ b/completion/zsh.go @@ -1,12 +1,49 @@ package completion +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type zsh struct { + goos string + programName string +} + +var _ Shell = &zsh{} + +func Zsh(goos string, programName string) Shell { + return &zsh{goos: goos, programName: programName} +} + +func (z *zsh) Name() string { + return "zsh" +} + +func (z *zsh) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".zshrc"), nil +} + +func (z *zsh) WriteCompletion(w io.Writer) error { + return writeConfig(w, zshCompletionTemplate, z.programName) +} + +func (z *zsh) ProgramName() string { + return z.programName +} + const zshCompletionTemplate = ` _{{.Name}}_completions() { local -a args completions args=("${words[@]:1:$#words}") - completions=($(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")) + completions=(${(f)"$(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")"}) compadd -a completions } - compdef _{{.Name}}_completions {{.Name}} ` diff --git a/completion_test.go b/completion_test.go index 5ca160a..f1527be 100644 --- a/completion_test.go +++ b/completion_test.go @@ -2,11 +2,14 @@ package serpent_test import ( "fmt" + "io" "os" + "path/filepath" "strings" "testing" serpent "github.com/coder/serpent" + "github.com/coder/serpent/completion" "github.com/stretchr/testify/require" ) @@ -52,7 +55,7 @@ func TestCompletion(t *testing.T) { io := fakeIO(i) err := i.Run() require.NoError(t, err) - require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-string\n", io.Stdout.String()) + require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-enum-array\n--req-string\n", io.Stdout.String()) }) t.Run("ListFlagsAfterArg", func(t *testing.T) { @@ -72,7 +75,7 @@ func TestCompletion(t *testing.T) { io := fakeIO(i) err := i.Run() require.NoError(t, err) - require.Equal(t, "--req-array\n--req-enum\n", io.Stdout.String()) + require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) }) t.Run("FlagShorthand", func(t *testing.T) { @@ -82,7 +85,7 @@ func TestCompletion(t *testing.T) { io := fakeIO(i) err := i.Run() require.NoError(t, err) - require.Equal(t, "--req-array\n--req-enum\n", io.Stdout.String()) + require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) }) t.Run("NoOptDefValueFlag", func(t *testing.T) { @@ -125,6 +128,36 @@ func TestCompletion(t *testing.T) { require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) }) + t.Run("EnumArrayOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "foo\nbar\nqux\n", io.Stdout.String()) + }) + + t.Run("EnumArrayEqualsOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array=") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) + }) + + t.Run("EnumArrayEqualsBeginQuotesOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array=\"") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) + }) + } func TestFileCompletion(t *testing.T) { @@ -201,3 +234,108 @@ func TestFileCompletion(t *testing.T) { }) } } + +func TestCompletionInstall(t *testing.T) { + t.Parallel() + + t.Run("InstallingNew", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "fake.sh") + shell := &fakeShell{baseInstallDir: dir, programName: "fake"} + + err := completion.InstallShellCompletion(shell) + require.NoError(t, err) + contents, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, "# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n", string(contents)) + }) + + cases := []struct { + name string + input []byte + expected []byte + errMsg string + }{ + { + name: "InstallingAppend", + input: []byte("FAKE_SCRIPT"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + }, + { + name: "InstallReplaceBeginning", + input: []byte("# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + expected: []byte("# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + }, + { + name: "InstallReplaceMiddle", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + }, + { + name: "InstallReplaceEnd", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + }, + { + name: "InstallNoFooter", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n"), + errMsg: "missing completion footer", + }, + { + name: "InstallNoHeader", + input: []byte("OLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + errMsg: "missing completion header", + }, + { + name: "InstallBadOrder", + input: []byte("# ============ END fake COMPLETION ==============\nFAKE_COMPLETION\n# ============ BEGIN fake COMPLETION =============="), + errMsg: "header after footer", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "fake.sh") + err := os.WriteFile(path, tc.input, 0o644) + require.NoError(t, err) + + shell := &fakeShell{baseInstallDir: dir, programName: "fake"} + err = completion.InstallShellCompletion(shell) + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } else { + require.NoError(t, err) + contents, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, tc.expected, contents) + } + }) + } +} + +type fakeShell struct { + baseInstallDir string + programName string +} + +func (f *fakeShell) ProgramName() string { + return f.programName +} + +var _ completion.Shell = &fakeShell{} + +func (f *fakeShell) InstallPath() (string, error) { + return filepath.Join(f.baseInstallDir, "fake.sh"), nil +} + +func (f *fakeShell) Name() string { + return "Fake" +} + +func (f *fakeShell) WriteCompletion(w io.Writer) error { + _, err := w.Write([]byte("\nFAKE_COMPLETION\n")) + return err +} diff --git a/example/completetest/main.go b/example/completetest/main.go index 920e705..add6d5c 100644 --- a/example/completetest/main.go +++ b/example/completetest/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "strings" @@ -11,24 +12,21 @@ import ( // installCommand returns a serpent command that helps // a user configure their shell to use serpent's completion. func installCommand() *serpent.Command { - defaultShell, err := completion.DetectUserShell() - if err != nil { - defaultShell = "bash" - } - var shell string return &serpent.Command{ - Use: "completion", + Use: "completion [--shell ]", Short: "Generate completion scripts for the given shell.", Handler: func(inv *serpent.Invocation) error { - completion.WriteCompletion(inv.Stdout, shell, inv.Command.Parent.Name()) - return nil + defaultShell, err := completion.DetectUserShell(inv.Command.Parent.Name()) + if err != nil { + return fmt.Errorf("Could not detect user shell, please specify a shell using `--shell`") + } + return defaultShell.WriteCompletion(inv.Stdout) }, Options: serpent.OptionSet{ { Flag: "shell", FlagShorthand: "s", - Default: defaultShell, Description: "The shell to generate a completion script for.", Value: completion.ShellOptions(&shell), }, @@ -42,6 +40,7 @@ func main() { upper bool fileType string fileArr []string + types []string ) cmd := serpent.Command{ Use: "completetest ", @@ -109,6 +108,11 @@ func main() { Description: "Extra files.", Value: serpent.StringArrayOf(&fileArr), }, + { + Name: "types", + Flag: "types", + Value: serpent.EnumArrayOf(&types, "binary", "text"), + }, }, CompletionHandler: completion.FileHandler(nil), Middleware: serpent.RequireNArgs(1), diff --git a/go.mod b/go.mod index 8bd432e..1c2880c 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,10 @@ require ( cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 github.com/hashicorp/go-multierror v1.1.1 + github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-wordwrap v1.0.1 github.com/muesli/termenv v0.15.2 + github.com/natefinch/atomic v1.0.1 github.com/pion/udp v0.1.4 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 63d175d..a1106fc 100644 --- a/go.sum +++ b/go.sum @@ -46,12 +46,16 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= +github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= +github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/transport/v2 v2.0.0 h1:bsMYyqHCbkvHwj+eNCFBuxtlKndKfyGI2vaQmM3fIE4= github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= diff --git a/help.go b/help.go index 3dc49a4..533be11 100644 --- a/help.go +++ b/help.go @@ -95,6 +95,8 @@ var defaultHelpTemplate = func() *template.Template { switch v := opt.Value.(type) { case *Enum: return strings.Join(v.Choices, "|") + case *EnumArray: + return fmt.Sprintf("[%s]", strings.Join(v.Choices, "|")) default: return v.Type() } diff --git a/option.go b/option.go index fccc67e..2780fc6 100644 --- a/option.go +++ b/option.go @@ -352,7 +352,7 @@ func (optSet OptionSet) ByFlag(flag string) *Option { } for i := range optSet { opt := &optSet[i] - if opt.Flag == flag || opt.FlagShorthand == flag { + if opt.Flag == flag { return opt } } diff --git a/values.go b/values.go index 25fc20f..79c8e2c 100644 --- a/values.go +++ b/values.go @@ -191,7 +191,10 @@ func (String) Type() string { return "string" } -var _ pflag.SliceValue = &StringArray{} +var ( + _ pflag.SliceValue = &StringArray{} + _ pflag.Value = &StringArray{} +) // StringArray is a slice of strings that implements pflag.Value and pflag.SliceValue. type StringArray []string @@ -527,7 +530,7 @@ func EnumOf(v *string, choices ...string) *Enum { func (e *Enum) Set(v string) error { for _, c := range e.Choices { - if v == c { + if strings.EqualFold(v, c) { *e.Value = v return nil } @@ -628,3 +631,76 @@ func (p *YAMLConfigPath) String() string { func (*YAMLConfigPath) Type() string { return "yaml-config-path" } + +var _ pflag.SliceValue = (*EnumArray)(nil) +var _ pflag.Value = (*EnumArray)(nil) + +type EnumArray struct { + Choices []string + Value *[]string +} + +func (e *EnumArray) Append(s string) error { + for _, c := range e.Choices { + if strings.EqualFold(s, c) { + *e.Value = append(*e.Value, s) + return nil + } + } + return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) +} + +func (e *EnumArray) GetSlice() []string { + return *e.Value +} + +func (e *EnumArray) Replace(ss []string) error { + for _, s := range ss { + found := false + for _, c := range e.Choices { + if strings.EqualFold(s, c) { + found = true + break + } + } + if !found { + return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) + } + } + *e.Value = ss + return nil +} + +func (e *EnumArray) Set(v string) error { + if v == "" { + *e.Value = nil + return nil + } + ss, err := readAsCSV(v) + if err != nil { + return err + } + for _, s := range ss { + err := e.Append(s) + if err != nil { + return err + } + } + return nil +} + +func (e *EnumArray) String() string { + return writeAsCSV(*e.Value) +} + +func (e *EnumArray) Type() string { + return fmt.Sprintf("enum-array[%v]", strings.Join(e.Choices, "\\|")) +} + +func EnumArrayOf(v *[]string, choices ...string) *EnumArray { + choices = append([]string{}, choices...) + return &EnumArray{ + Choices: choices, + Value: v, + } +}