diff --git a/command.go b/command.go index 8b99e26..2a089ba 100644 --- a/command.go +++ b/command.go @@ -59,6 +59,12 @@ type Command struct { Middleware MiddlewareFunc Handler HandlerFunc HelpHandler HandlerFunc + // CompletionHandler is called when the command is run in completion + // mode. If nil, only the default completion handler is used. + // + // Flag and option parsing is best-effort in this mode, so even if an Option + // is "required" it may not be set. + CompletionHandler CompletionHandlerFunc } // AddSubcommands adds the given subcommands, setting their @@ -193,15 +199,22 @@ type Invocation struct { ctx context.Context Command *Command parsedFlags *pflag.FlagSet - Args []string + + // Args is reduced into the remaining arguments after parsing flags + // during Run. + Args []string + // Environ is a list of environment variables. Use EnvsWithPrefix to parse // os.Environ. Environ Environ Stdout io.Writer Stderr io.Writer Stdin io.Reader - Logger slog.Logger - Net Net + + // Deprecated + Logger slog.Logger + // Deprecated + Net Net // testing signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) @@ -282,6 +295,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { return fs2 } +func (inv *Invocation) CurWords() (prev string, cur string) { + if len(inv.Args) == 1 { + cur = inv.Args[0] + prev = "" + } else { + cur = inv.Args[len(inv.Args)-1] + prev = inv.Args[len(inv.Args)-2] + } + return +} + // run recursively executes the command and its children. // allArgs is wired through the stack so that global flags can be accepted // anywhere in the command invocation. @@ -378,8 +402,19 @@ func (inv *Invocation) run(state *runState) error { } } + // Outputted completions are not filtered based on the word under the cursor, as every shell we support does this already. + // We only look at the current word to figure out handler to run, or what directory to inspect. + if inv.IsCompletionMode() { + for _, e := range inv.complete() { + fmt.Fprintln(inv.Stdout, e) + } + return nil + } + + ignoreFlagParseErrors := inv.Command.RawArgs + // Flag parse errors are irrelevant for raw args commands. - if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { + if !ignoreFlagParseErrors && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { return xerrors.Errorf( "parsing flags (%v) for %q: %w", state.allArgs, @@ -396,7 +431,7 @@ func (inv *Invocation) run(state *runState) error { } } // Don't error for missing flags if `--help` was supplied. - if len(missing) > 0 && !errors.Is(state.flagParseErr, pflag.ErrHelp) { + if len(missing) > 0 && !inv.IsCompletionMode() && !errors.Is(state.flagParseErr, pflag.ErrHelp) { return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", ")) } @@ -553,6 +588,65 @@ func (inv *Invocation) with(fn func(*Invocation)) *Invocation { return &i2 } +func (inv *Invocation) complete() []string { + prev, cur := inv.CurWords() + + // If the current word is a flag + if strings.HasPrefix(cur, "--") { + flagParts := strings.Split(cur, "=") + flagName := flagParts[0][2:] + // If it's an equals flag + if len(flagParts) == 2 { + if out := inv.completeFlag(flagName); out != nil { + for i, o := range out { + out[i] = fmt.Sprintf("--%s=%s", flagName, o) + } + return out + } + } else if out := inv.Command.Options.ByFlag(flagName); out != nil { + // If the current word is a valid flag, auto-complete it so the + // shell moves the cursor + return []string{cur} + } + } + // If the previous word is a flag, then we're writing it's value + // and we should check it's handler + if strings.HasPrefix(prev, "--") { + word := prev[2:] + if out := inv.completeFlag(word); out != nil { + return out + } + } + // If the current word is the command, move the shell cursor + if inv.Command.Name() == cur { + return []string{inv.Command.Name()} + } + var completions []string + + if inv.Command.CompletionHandler != nil { + completions = append(completions, inv.Command.CompletionHandler(inv)...) + } + + completions = append(completions, DefaultCompletionHandler(inv)...) + + return completions +} + +func (inv *Invocation) completeFlag(word string) []string { + opt := inv.Command.Options.ByFlag(word) + if opt == nil { + return nil + } + if opt.CompletionHandler != nil { + return opt.CompletionHandler(inv) + } + val, ok := opt.Value.(*Enum) + if ok { + return val.Choices + } + return nil +} + // MiddlewareFunc returns the next handler in the chain, // or nil if there are no more. type MiddlewareFunc func(next HandlerFunc) HandlerFunc @@ -637,3 +731,5 @@ func RequireRangeArgs(start, end int) MiddlewareFunc { // HandlerFunc handles an Invocation of a command. type HandlerFunc func(i *Invocation) error + +type CompletionHandlerFunc func(i *Invocation) []string diff --git a/command_test.go b/command_test.go index f6a20a2..de6c12d 100644 --- a/command_test.go +++ b/command_test.go @@ -12,6 +12,7 @@ import ( "golang.org/x/xerrors" serpent "github.com/coder/serpent" + "github.com/coder/serpent/completion" ) // ioBufs is the standard input, output, and error for a command. @@ -30,100 +31,147 @@ func fakeIO(i *serpent.Invocation) *ioBufs { return &b } -func TestCommand(t *testing.T) { - t.Parallel() - - cmd := func() *serpent.Command { - var ( - verbose bool - lower bool - prefix string - reqBool bool - reqStr string - ) - return &serpent.Command{ - Use: "root [subcommand]", - Options: serpent.OptionSet{ - serpent.Option{ - Name: "verbose", - Flag: "verbose", - Value: serpent.BoolOf(&verbose), - }, - serpent.Option{ - Name: "prefix", - Flag: "prefix", - Value: serpent.StringOf(&prefix), - }, +func sampleCommand(t *testing.T) *serpent.Command { + t.Helper() + var ( + verbose bool + lower bool + prefix string + reqBool bool + reqStr string + reqArr []string + fileArr []string + enumStr string + ) + enumChoices := []string{"foo", "bar", "qux"} + return &serpent.Command{ + Use: "root [subcommand]", + Options: serpent.OptionSet{ + serpent.Option{ + Name: "verbose", + Flag: "verbose", + Value: serpent.BoolOf(&verbose), }, - Children: []*serpent.Command{ - { - Use: "required-flag --req-bool=true --req-string=foo", - Short: "Example with required flags", - Options: serpent.OptionSet{ - serpent.Option{ - Name: "req-bool", - Flag: "req-bool", - Value: serpent.BoolOf(&reqBool), - Required: true, - }, - serpent.Option{ - Name: "req-string", - Flag: "req-string", - Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error { - ok := strings.Contains(value.String(), " ") - if !ok { - return xerrors.Errorf("string must contain a space") - } - return nil - }), - Required: true, - }, + serpent.Option{ + Name: "prefix", + Flag: "prefix", + Value: serpent.StringOf(&prefix), + }, + }, + Children: []*serpent.Command{ + { + Use: "required-flag --req-bool=true --req-string=foo", + Short: "Example with required flags", + Options: serpent.OptionSet{ + serpent.Option{ + Name: "req-bool", + Flag: "req-bool", + FlagShorthand: "b", + Value: serpent.BoolOf(&reqBool), + Required: true, }, - HelpHandler: func(i *serpent.Invocation) error { - _, _ = i.Stdout.Write([]byte("help text.png")) - return nil + serpent.Option{ + Name: "req-string", + Flag: "req-string", + FlagShorthand: "s", + Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error { + ok := strings.Contains(value.String(), " ") + if !ok { + return xerrors.Errorf("string must contain a space") + } + return nil + }), + Required: true, }, - Handler: func(i *serpent.Invocation) error { - _, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool))) - return nil + serpent.Option{ + Name: "req-enum", + Flag: "req-enum", + Value: serpent.EnumOf(&enumStr, enumChoices...), + }, + serpent.Option{ + Name: "req-array", + Flag: "req-array", + FlagShorthand: "a", + Value: serpent.StringArrayOf(&reqArr), }, }, - { - Use: "toupper [word]", - Short: "Converts a word to upper case", - Middleware: serpent.Chain( - serpent.RequireNArgs(1), - ), - Aliases: []string{"up"}, - Options: serpent.OptionSet{ - serpent.Option{ - Name: "lower", - Flag: "lower", - Value: serpent.BoolOf(&lower), - }, + HelpHandler: func(i *serpent.Invocation) error { + _, _ = i.Stdout.Write([]byte("help text.png")) + return nil + }, + Handler: func(i *serpent.Invocation) error { + _, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool))) + return nil + }, + }, + { + Use: "toupper [word]", + Short: "Converts a word to upper case", + Middleware: serpent.Chain( + serpent.RequireNArgs(1), + ), + Aliases: []string{"up"}, + Options: serpent.OptionSet{ + serpent.Option{ + Name: "lower", + Flag: "lower", + Value: serpent.BoolOf(&lower), }, - Handler: func(i *serpent.Invocation) error { - _, _ = i.Stdout.Write([]byte(prefix)) - w := i.Args[0] - if lower { - w = strings.ToLower(w) - } else { - w = strings.ToUpper(w) - } - _, _ = i.Stdout.Write( - []byte( - w, - ), - ) - if verbose { - _, _ = i.Stdout.Write([]byte("!!!")) - } - return nil + }, + Handler: func(i *serpent.Invocation) error { + _, _ = i.Stdout.Write([]byte(prefix)) + w := i.Args[0] + if lower { + w = strings.ToLower(w) + } else { + w = strings.ToUpper(w) + } + _, _ = i.Stdout.Write( + []byte( + w, + ), + ) + if verbose { + _, _ = i.Stdout.Write([]byte("!!!")) + } + return nil + }, + }, + { + Use: "file ", + Handler: func(inv *serpent.Invocation) error { + return nil + }, + CompletionHandler: completion.FileHandler(func(info os.FileInfo) bool { + return true + }), + Middleware: serpent.RequireNArgs(1), + }, + { + Use: "altfile", + Handler: func(inv *serpent.Invocation) error { + return nil + }, + Options: serpent.OptionSet{ + { + Name: "extra", + Flag: "extra", + Description: "Extra files.", + Value: serpent.StringArrayOf(&fileArr), }, }, + CompletionHandler: func(i *serpent.Invocation) []string { + return []string{"doesntexist.go"} + }, }, - } + }, } +} + +func TestCommand(t *testing.T) { + t.Parallel() + + cmd := func() *serpent.Command { return sampleCommand(t) } t.Run("SimpleOK", func(t *testing.T) { t.Parallel() diff --git a/completion.go b/completion.go new file mode 100644 index 0000000..a0fb779 --- /dev/null +++ b/completion.go @@ -0,0 +1,26 @@ +package serpent + +// CompletionModeEnv is a special environment variable that is +// set when the command is being run in completion mode. +const CompletionModeEnv = "COMPLETION_MODE" + +// IsCompletionMode returns true if the command is being run in completion mode. +func (inv *Invocation) IsCompletionMode() bool { + _, ok := inv.Environ.Lookup(CompletionModeEnv) + return ok +} + +// DefaultCompletionHandler returns a handler that prints all +// known flags and subcommands that haven't already been set to valid values. +func DefaultCompletionHandler(inv *Invocation) []string { + var allResps []string + for _, cmd := range inv.Command.Children { + allResps = append(allResps, cmd.Name()) + } + for _, opt := range inv.Command.Options { + if opt.ValueSource == ValueSourceNone || opt.ValueSource == ValueSourceDefault || opt.Value.Type() == "string-array" { + allResps = append(allResps, "--"+opt.Flag) + } + } + return allResps +} diff --git a/completion/README.md b/completion/README.md new file mode 100644 index 0000000..d7021ed --- /dev/null +++ b/completion/README.md @@ -0,0 +1,11 @@ +# completion + +The `completion` package extends `serpent` to allow applications to generate rich auto-completions. + + +## Protocol + +The completion scripts call out to the serpent command to generate +completions. The convention is to pass the exact args and flags (or +cmdline) of the in-progress command with a `COMPLETION_MODE=1` environment variable. That environment variable lets the command know to generate completions instead of running the command. +By default, completions will be generated based on available flags and subcommands. Additional completions can be added by supplying a `CompletionHandlerFunc` on an Option or Command. \ No newline at end of file diff --git a/completion/all.go b/completion/all.go new file mode 100644 index 0000000..b20c254 --- /dev/null +++ b/completion/all.go @@ -0,0 +1,94 @@ +package completion + +import ( + "fmt" + "io" + "os" + "os/user" + "path/filepath" + "strings" + "text/template" + + "github.com/coder/serpent" +) + +const ( + BashShell string = "bash" + FishShell string = "fish" + ZShell string = "zsh" + Powershell string = "powershell" +) + +var shellCompletionByName = map[string]func(io.Writer, string) error{ + BashShell: generateCompletion(bashCompletionTemplate), + FishShell: generateCompletion(fishCompletionTemplate), + ZShell: generateCompletion(zshCompletionTemplate), + Powershell: generateCompletion(pshCompletionTemplate), +} + +func ShellOptions(choice *string) *serpent.Enum { + return serpent.EnumOf(choice, BashShell, FishShell, ZShell, Powershell) +} + +func WriteCompletion(writer io.Writer, shell string, cmdName string) error { + fn, ok := shellCompletionByName[shell] + if !ok { + return fmt.Errorf("unknown shell %q", shell) + } + fn(writer, cmdName) + return nil +} + +func DetectUserShell() (string, error) { + // Attempt to get the SHELL environment variable first + if shell := os.Getenv("SHELL"); shell != "" { + return filepath.Base(shell), nil + } + + // Fallback: Look up the current user and parse /etc/passwd + currentUser, err := user.Current() + if err != nil { + return "", err + } + + // Open and parse /etc/passwd + passwdFile, err := os.ReadFile("/etc/passwd") + if err != nil { + return "", err + } + + lines := strings.Split(string(passwdFile), "\n") + for _, line := range lines { + if strings.HasPrefix(line, currentUser.Username+":") { + parts := strings.Split(line, ":") + if len(parts) > 6 { + return filepath.Base(parts[6]), nil // The shell is typically the 7th field + } + } + } + + return "", fmt.Errorf("default shell not found") +} + +func generateCompletion( + scriptTemplate string, +) func(io.Writer, string) error { + return func(w io.Writer, rootCmdName string) error { + tmpl, err := template.New("script").Parse(scriptTemplate) + if err != nil { + return fmt.Errorf("parse template: %w", err) + } + + err = tmpl.Execute( + w, + map[string]string{ + "Name": rootCmdName, + }, + ) + if err != nil { + return fmt.Errorf("execute template: %w", err) + } + + return nil + } +} diff --git a/completion/bash.go b/completion/bash.go new file mode 100644 index 0000000..fad4069 --- /dev/null +++ b/completion/bash.go @@ -0,0 +1,22 @@ +package completion + +const bashCompletionTemplate = ` +_generate_{{.Name}}_completions() { + # Capture the line excluding the command, and everything after the current word + local args=("${COMP_WORDS[@]:1:COMP_CWORD}") + + # Set COMPLETION_MODE and call the command with the arguments, capturing the output + local completions=$(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}") + + # Use the command's output to generate completions for the current word + COMPREPLY=($(compgen -W "$completions" -- "${COMP_WORDS[COMP_CWORD]}")) + + # Ensure no files are shown, even if there are no matches + if [ ${#COMPREPLY[@]} -eq 0 ]; then + COMPREPLY=() + fi +} + +# Setup Bash to use the function for completions for '{{.Name}}' +complete -F _generate_{{.Name}}_completions {{.Name}} +` diff --git a/completion/fish.go b/completion/fish.go new file mode 100644 index 0000000..f9b2793 --- /dev/null +++ b/completion/fish.go @@ -0,0 +1,13 @@ +package completion + +const fishCompletionTemplate = ` +function _{{.Name}}_completions + # Capture the full command line as an array + set -l args (commandline -opc) + set -l current (commandline -ct) + COMPLETION_MODE=1 $args $current +end + +# Setup Fish to use the function for completions for '{{.Name}}' +complete -c {{.Name}} -f -a '(_{{.Name}}_completions)' +` diff --git a/completion/handlers.go b/completion/handlers.go new file mode 100644 index 0000000..848bb06 --- /dev/null +++ b/completion/handlers.go @@ -0,0 +1,55 @@ +package completion + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/coder/serpent" +) + +// FileHandler returns a handler that completes file names, using the +// given filter func, which may be nil. +func FileHandler(filter func(info os.FileInfo) bool) serpent.CompletionHandlerFunc { + return func(inv *serpent.Invocation) []string { + var out []string + _, word := inv.CurWords() + + dir, _ := filepath.Split(word) + if dir == "" { + dir = "." + } + f, err := os.Open(dir) + if err != nil { + return out + } + defer f.Close() + if dir == "." { + dir = "" + } + + infos, err := f.Readdir(0) + if err != nil { + return out + } + + for _, info := range infos { + if filter != nil && !filter(info) { + continue + } + + var cur string + if info.IsDir() { + cur = fmt.Sprintf("%s%s%c", dir, info.Name(), os.PathSeparator) + } else { + cur = fmt.Sprintf("%s%s", dir, info.Name()) + } + + if strings.HasPrefix(cur, word) { + out = append(out, cur) + } + } + return out + } +} diff --git a/completion/powershell.go b/completion/powershell.go new file mode 100644 index 0000000..e30c61b --- /dev/null +++ b/completion/powershell.go @@ -0,0 +1,42 @@ +package completion + +const pshCompletionTemplate = ` + +# Escaping output sourced from: +# https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47 +filter _{{.Name}}_escapeStringWithSpecialChars { +` + " $_ -replace '\\s|#|@|\\$|;|,|''|\\{|\\}|\\(|\\)|\"|`|\\||<|>|&','`$&'" + ` +} + +$_{{.Name}}_completions = { + param( + $wordToComplete, + $commandAst, + $cursorPosition + ) + # Legacy space handling sourced from: + # https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L107 + if ($PSVersionTable.PsVersion -lt [version]'7.2.0' -or + ($PSVersionTable.PsVersion -lt [version]'7.3.0' -and -not [ExperimentalFeature]::IsEnabled("PSNativeCommandArgumentPassing")) -or + (($PSVersionTable.PsVersion -ge [version]'7.3.0' -or [ExperimentalFeature]::IsEnabled("PSNativeCommandArgumentPassing")) -and + $PSNativeCommandArgumentPassing -eq 'Legacy')) { + $Space =` + "' `\"`\"'" + ` + } else { + $Space = ' ""' + } + $Command = $commandAst.ToString().Substring(0, $cursorPosition - 1) + if ($wordToComplete -ne "" ) { + $wordToComplete = $Command.Split(" ")[-1] + } else { + $Command = $Command + $Space + } + # Get completions by calling the command with the COMPLETION_MODE environment variable set to 1 + $env:COMPLETION_MODE = 1 + Invoke-Expression $Command | Where-Object { $_ -like "$wordToComplete*" } | ForEach-Object { + "$_" | _{{.Name}}_escapeStringWithSpecialChars + } + rm env:COMPLETION_MODE +} + +Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions +` diff --git a/completion/zsh.go b/completion/zsh.go new file mode 100644 index 0000000..a8ee4a8 --- /dev/null +++ b/completion/zsh.go @@ -0,0 +1,12 @@ +package completion + +const zshCompletionTemplate = ` +_{{.Name}}_completions() { + local -a args completions + args=("${words[@]:1:$#words}") + completions=($(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")) + compadd -a completions +} + +compdef _{{.Name}}_completions {{.Name}} +` diff --git a/completion_test.go b/completion_test.go new file mode 100644 index 0000000..5ca160a --- /dev/null +++ b/completion_test.go @@ -0,0 +1,203 @@ +package serpent_test + +import ( + "fmt" + "os" + "strings" + "testing" + + serpent "github.com/coder/serpent" + "github.com/stretchr/testify/require" +) + +func TestCompletion(t *testing.T) { + t.Parallel() + + cmd := func() *serpent.Command { return sampleCommand(t) } + + t.Run("SubcommandList", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "altfile\nfile\nrequired-flag\ntoupper\n--prefix\n--verbose\n", io.Stdout.String()) + }) + + t.Run("SubcommandNoPartial", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("f") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "altfile\nfile\nrequired-flag\ntoupper\n--prefix\n--verbose\n", io.Stdout.String()) + }) + + t.Run("SubcommandComplete", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "required-flag\n", io.Stdout.String()) + }) + + t.Run("ListFlags", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-string\n", io.Stdout.String()) + }) + + t.Run("ListFlagsAfterArg", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("altfile", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "doesntexist.go\n--extra\n", io.Stdout.String()) + }) + + t.Run("FlagExhaustive", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-bool", "--req-string", "foo bar", "--req-array", "asdf", "--req-array", "qwerty") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-array\n--req-enum\n", io.Stdout.String()) + }) + + t.Run("FlagShorthand", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "-b", "-s", "foo bar", "-a", "asdf") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-array\n--req-enum\n", io.Stdout.String()) + }) + + t.Run("NoOptDefValueFlag", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("--verbose", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "altfile\nfile\nrequired-flag\ntoupper\n--prefix\n", io.Stdout.String()) + }) + + t.Run("EnumOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "foo\nbar\nqux\n", io.Stdout.String()) + }) + + t.Run("EnumEqualsOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum", "--req-enum=") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) + }) + + t.Run("EnumEqualsBeginQuotesOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum", "--req-enum=\"") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) + }) + +} + +func TestFileCompletion(t *testing.T) { + t.Parallel() + + cmd := func() *serpent.Command { return sampleCommand(t) } + + t.Run("DirOK", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + i := cmd().Invoke("file", tempDir) + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%s%c\n", tempDir, os.PathSeparator), io.Stdout.String()) + }) + + t.Run("EmptyDirOK", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + string(os.PathSeparator) + i := cmd().Invoke("file", tempDir) + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "", io.Stdout.String()) + }) + + cases := []struct { + name string + realPath string + paths []string + }{ + { + name: "CurDirOK", + realPath: ".", + paths: []string{"", "./", "././"}, + }, + { + name: "PrevDirOK", + realPath: "..", + paths: []string{"../", ".././"}, + }, + { + name: "RootOK", + realPath: "/", + paths: []string{"/", "/././"}, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + for _, path := range tc.paths { + i := cmd().Invoke("file", path) + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + output := strings.Split(io.Stdout.String(), "\n") + output = output[:len(output)-1] + for _, str := range output { + if strings.HasSuffix(str, string(os.PathSeparator)) { + require.DirExists(t, str) + } else { + require.FileExists(t, str) + } + } + files, err := os.ReadDir(tc.realPath) + require.NoError(t, err) + require.Equal(t, len(files), len(output)) + } + }) + } +} diff --git a/example/completetest/main.go b/example/completetest/main.go new file mode 100644 index 0000000..920e705 --- /dev/null +++ b/example/completetest/main.go @@ -0,0 +1,126 @@ +package main + +import ( + "os" + "strings" + + "github.com/coder/serpent" + "github.com/coder/serpent/completion" +) + +// installCommand returns a serpent command that helps +// a user configure their shell to use serpent's completion. +func installCommand() *serpent.Command { + defaultShell, err := completion.DetectUserShell() + if err != nil { + defaultShell = "bash" + } + + var shell string + return &serpent.Command{ + Use: "completion", + Short: "Generate completion scripts for the given shell.", + Handler: func(inv *serpent.Invocation) error { + completion.WriteCompletion(inv.Stdout, shell, inv.Command.Parent.Name()) + return nil + }, + Options: serpent.OptionSet{ + { + Flag: "shell", + FlagShorthand: "s", + Default: defaultShell, + Description: "The shell to generate a completion script for.", + Value: completion.ShellOptions(&shell), + }, + }, + } +} + +func main() { + var ( + print bool + upper bool + fileType string + fileArr []string + ) + cmd := serpent.Command{ + Use: "completetest ", + Short: "Prints the given text to the console.", + Options: serpent.OptionSet{ + { + Name: "different", + Value: serpent.BoolOf(&upper), + Flag: "different", + Description: "Do the command differently.", + }, + }, + Handler: func(inv *serpent.Invocation) error { + if len(inv.Args) == 0 { + inv.Stderr.Write([]byte("error: missing text\n")) + os.Exit(1) + } + + text := inv.Args[0] + if upper { + text = strings.ToUpper(text) + } + + inv.Stdout.Write([]byte(text)) + return nil + }, + Children: []*serpent.Command{ + { + Use: "sub", + Short: "A subcommand", + Handler: func(inv *serpent.Invocation) error { + inv.Stdout.Write([]byte("subcommand")) + return nil + }, + Options: serpent.OptionSet{ + { + Name: "upper", + Value: serpent.BoolOf(&upper), + Flag: "upper", + Description: "Prints the text in upper case.", + }, + }, + }, + { + Use: "file ", + Handler: func(inv *serpent.Invocation) error { + return nil + }, + Options: serpent.OptionSet{ + { + Name: "print", + Value: serpent.BoolOf(&print), + Flag: "print", + Description: "Print the file.", + }, + { + Name: "type", + Value: serpent.EnumOf(&fileType, "binary", "text"), + Flag: "type", + Description: "The type of file.", + }, + { + Name: "extra", + Flag: "extra", + Description: "Extra files.", + Value: serpent.StringArrayOf(&fileArr), + }, + }, + CompletionHandler: completion.FileHandler(nil), + Middleware: serpent.RequireNArgs(1), + }, + installCommand(), + }, + } + + inv := cmd.Invoke().WithOS() + + err := inv.Run() + if err != nil { + panic(err) + } +} diff --git a/option.go b/option.go index 5545d07..fccc67e 100644 --- a/option.go +++ b/option.go @@ -65,6 +65,8 @@ type Option struct { Hidden bool `json:"hidden,omitempty"` ValueSource ValueSource `json:"value_source,omitempty"` + + CompletionHandler CompletionHandlerFunc `json:"-"` } // optionNoMethods is just a wrapper around Option so we can defer to the @@ -335,10 +337,22 @@ func (optSet *OptionSet) SetDefaults() error { // ByName returns the Option with the given name, or nil if no such option // exists. -func (optSet *OptionSet) ByName(name string) *Option { - for i := range *optSet { - opt := &(*optSet)[i] - if opt.Name == name { +func (optSet OptionSet) ByName(name string) *Option { + for i := range optSet { + if optSet[i].Name == name { + return &optSet[i] + } + } + return nil +} + +func (optSet OptionSet) ByFlag(flag string) *Option { + if flag == "" { + return nil + } + for i := range optSet { + opt := &optSet[i] + if opt.Flag == flag || opt.FlagShorthand == flag { return opt } }