Skip to content

Add auto-completion #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 1, 2024
98 changes: 93 additions & 5 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ type Command struct {
Middleware MiddlewareFunc
Handler HandlerFunc
HelpHandler HandlerFunc
// CompletionHandler is called when the command is run is completion
// mode. If nil, only the default completion handler is used.
//
// Flag and option parsing is best-effort in this mode, so even if an Option
// is "required" it may not be set.
CompletionHandler CompletionHandlerFunc
}

// AddSubcommands adds the given subcommands, setting their
Expand Down Expand Up @@ -181,6 +187,7 @@ func (c *Command) Invoke(args ...string) *Invocation {
return &Invocation{
Command: c,
Args: args,
AllArgs: args,
Stdout: io.Discard,
Stderr: io.Discard,
Stdin: strings.NewReader(""),
Expand All @@ -193,15 +200,27 @@ type Invocation struct {
ctx context.Context
Command *Command
parsedFlags *pflag.FlagSet
Args []string

// Args is reduced into the remaining arguments after parsing flags
// during Run.
Args []string
// AllArgs is the original arguments passed to the command, including flags.
// When invoked `WithOS`, this includes argv[0], otherwise it is the same as Args.
AllArgs []string
// CurWord is the word the terminal cursor is currently in
CurWord string
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can CurWord be an index into AllArgs? This provides future compatibility with mid-line completions.

Copy link
Member

@ethanndickson ethanndickson Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To get auto-complete for equals flags (--flag=) I'm currently just setting this to empty string before we call any handlers. If we don't do this then anyone writing a flag completion handler would sometimes see --flag=<arg> as the current word, and other times just <arg>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what use-case for mid-line completions isn't handled by just truncating the line at the cursor? Everything I tried just works.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, ok. If we're truncating the line at the cursor then CurWord's only purpose is to distinguish the true flag value? That makes sense, I think it warrants a comment.


// Environ is a list of environment variables. Use EnvsWithPrefix to parse
// os.Environ.
Environ Environ
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader
Logger slog.Logger
Net Net

// Deprecated
Logger slog.Logger
// Deprecated
Net Net

// testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
Expand All @@ -214,6 +233,7 @@ func (inv *Invocation) WithOS() *Invocation {
i.Stdout = os.Stdout
i.Stderr = os.Stderr
i.Stdin = os.Stdin
i.AllArgs = os.Args
i.Args = os.Args[1:]
i.Environ = ParseEnviron(os.Environ(), "")
i.Net = osNet{}
Expand Down Expand Up @@ -282,6 +302,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
return fs2
}

func (inv *Invocation) GetCurWords() (prev string, cur string) {
if len(inv.AllArgs) == 1 {
cur = inv.AllArgs[0]
prev = ""
} else {
cur = inv.AllArgs[len(inv.AllArgs)-1]
prev = inv.AllArgs[len(inv.AllArgs)-2]
}
return
}

// run recursively executes the command and its children.
// allArgs is wired through the stack so that global flags can be accepted
// anywhere in the command invocation.
Expand Down Expand Up @@ -378,8 +409,10 @@ func (inv *Invocation) run(state *runState) error {
}
}

ignoreFlagParseErrors := inv.Command.RawArgs || inv.IsCompletionMode()

// Flag parse errors are irrelevant for raw args commands.
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
if !ignoreFlagParseErrors && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf(
"parsing flags (%v) for %q: %w",
state.allArgs,
Expand All @@ -396,7 +429,7 @@ func (inv *Invocation) run(state *runState) error {
}
}
// Don't error for missing flags if `--help` was supplied.
if len(missing) > 0 && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
if len(missing) > 0 && !inv.IsCompletionMode() && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", "))
}

Expand Down Expand Up @@ -431,6 +464,36 @@ func (inv *Invocation) run(state *runState) error {
defer cancel()
inv = inv.WithContext(ctx)

if inv.IsCompletionMode() {
prev, cur := inv.GetCurWords()
inv.CurWord = cur
if prev != "" {
// If the previous word is a flag, we use it's handler
if strings.HasPrefix(prev, "--") {
opt := inv.Command.Options.ByFlag(prev[2:])
if opt != nil && opt.CompletionHandler != nil {
for _, e := range opt.CompletionHandler(inv) {
fmt.Fprintf(inv.Stdout, "%s\n", e)
}
return nil
}
}
}
if inv.Command.Name() == inv.CurWord {
fmt.Fprintf(inv.Stdout, "%s\n", inv.Command.Name())
return nil
}
if inv.Command.CompletionHandler != nil {
for _, e := range inv.Command.CompletionHandler(inv) {
fmt.Fprintf(inv.Stdout, "%s\n", e)
}
}
for _, e := range DefaultCompletionHandler(inv) {
fmt.Fprintf(inv.Stdout, "%s\n", e)
}
return nil
}

if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
if inv.Command.HelpHandler == nil {
return defaultHelpFn()(inv)
Expand Down Expand Up @@ -500,6 +563,27 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
return -1, xerrors.Errorf("arg %s not found", want)
}

// // findArgByPos returns the index of first full word before the given cursor position in the arguments
// // list. If the cursor is at the end of the line, the last word is returned.
// func findArgByPos(pos int, args []string) int {
// if pos == 0 {
// return -1
// }
// if len(args) == 0 {
// return -1
// }
// curChar := 0
// for i, arg := range args {
// next := curChar + len(arg)
// if pos <= next {
// return i
// }
// curChar = next + 1
// }
// // Otherwise, must be the last word
// return len(args)
// }

// Run executes the command.
// If two command share a flag name, the first command wins.
//
Expand Down Expand Up @@ -637,3 +721,7 @@ func RequireRangeArgs(start, end int) MiddlewareFunc {

// HandlerFunc handles an Invocation of a command.
type HandlerFunc func(i *Invocation) error

type CompletionHandlerFunc func(i *Invocation) []string

var NopHandler HandlerFunc = func(i *Invocation) error { return nil }
192 changes: 107 additions & 85 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/xerrors"

serpent "github.com/coder/serpent"
"github.com/coder/serpent/completion"
)

// ioBufs is the standard input, output, and error for a command.
Expand All @@ -30,100 +31,121 @@ func fakeIO(i *serpent.Invocation) *ioBufs {
return &b
}

func TestCommand(t *testing.T) {
t.Parallel()

cmd := func() *serpent.Command {
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
)
return &serpent.Command{
Use: "root [subcommand]",
Options: serpent.OptionSet{
serpent.Option{
Name: "verbose",
Flag: "verbose",
Value: serpent.BoolOf(&verbose),
},
serpent.Option{
Name: "prefix",
Flag: "prefix",
Value: serpent.StringOf(&prefix),
},
func SampleCommand(t *testing.T) *serpent.Command {
t.Helper()
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
enumStr string
)
enumChoices := []string{"foo", "bar", "qux"}
return &serpent.Command{
Use: "root [subcommand]",
Options: serpent.OptionSet{
serpent.Option{
Name: "verbose",
Flag: "verbose",
Value: serpent.BoolOf(&verbose),
},
Children: []*serpent.Command{
{
Use: "required-flag --req-bool=true --req-string=foo",
Short: "Example with required flags",
Options: serpent.OptionSet{
serpent.Option{
Name: "req-bool",
Flag: "req-bool",
Value: serpent.BoolOf(&reqBool),
Required: true,
},
serpent.Option{
Name: "req-string",
Flag: "req-string",
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
ok := strings.Contains(value.String(), " ")
if !ok {
return xerrors.Errorf("string must contain a space")
}
return nil
}),
Required: true,
},
serpent.Option{
Name: "prefix",
Flag: "prefix",
Value: serpent.StringOf(&prefix),
},
},
Children: []*serpent.Command{
{
Use: "required-flag --req-bool=true --req-string=foo",
Short: "Example with required flags",
Options: serpent.OptionSet{
serpent.Option{
Name: "req-bool",
Flag: "req-bool",
Value: serpent.BoolOf(&reqBool),
Required: true,
},
HelpHandler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte("help text.png"))
return nil
serpent.Option{
Name: "req-string",
Flag: "req-string",
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
ok := strings.Contains(value.String(), " ")
if !ok {
return xerrors.Errorf("string must contain a space")
}
return nil
}),
Required: true,
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
return nil
serpent.Option{
Name: "req-enum",
Flag: "req-enum",
Value: serpent.EnumOf(&enumStr, enumChoices...),
CompletionHandler: completion.EnumHandler(enumChoices...),
},
},
{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Aliases: []string{"up"},
Options: serpent.OptionSet{
serpent.Option{
Name: "lower",
Flag: "lower",
Value: serpent.BoolOf(&lower),
},
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(prefix))
w := i.Args[0]
if lower {
w = strings.ToLower(w)
} else {
w = strings.ToUpper(w)
}
_, _ = i.Stdout.Write(
[]byte(
w,
),
)
if verbose {
_, _ = i.Stdout.Write([]byte("!!!"))
}
return nil
HelpHandler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte("help text.png"))
return nil
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
return nil
},
},
{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Aliases: []string{"up"},
Options: serpent.OptionSet{
serpent.Option{
Name: "lower",
Flag: "lower",
Value: serpent.BoolOf(&lower),
},
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(prefix))
w := i.Args[0]
if lower {
w = strings.ToLower(w)
} else {
w = strings.ToUpper(w)
}
_, _ = i.Stdout.Write(
[]byte(
w,
),
)
if verbose {
_, _ = i.Stdout.Write([]byte("!!!"))
}
return nil
},
},
}
{
Use: "file <file>",
Handler: func(inv *serpent.Invocation) error {
return nil
},
CompletionHandler: completion.FileHandler(func(info os.FileInfo) bool {
return true
}),
Middleware: serpent.RequireNArgs(1),
},
},
}
}

func TestCommand(t *testing.T) {
t.Parallel()

cmd := func() *serpent.Command { return SampleCommand(t) }

t.Run("SimpleOK", func(t *testing.T) {
t.Parallel()
Expand Down
Loading
Loading