Skip to content

Commit 09f1207

Browse files
authored
Fix overwrite by SetDefault for options that share Value (#23)
1 parent 5a56b57 commit 09f1207

File tree

3 files changed

+182
-17
lines changed

3 files changed

+182
-17
lines changed

command_test.go

+104-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,14 @@ func sampleCommand(t *testing.T) *serpent.Command {
4949
Use: "root [subcommand]",
5050
Options: serpent.OptionSet{
5151
serpent.Option{
52-
Name: "verbose",
53-
Flag: "verbose",
52+
Name: "verbose",
53+
Flag: "verbose",
54+
Default: "false",
55+
Value: serpent.BoolOf(&verbose),
56+
},
57+
serpent.Option{
58+
Name: "verbose-old",
59+
Flag: "verbode-old",
5460
Value: serpent.BoolOf(&verbose),
5561
},
5662
serpent.Option{
@@ -742,6 +748,12 @@ func TestCommand_DefaultsOverride(t *testing.T) {
742748
Value: serpent.StringOf(&got),
743749
YAML: "url",
744750
},
751+
{
752+
Name: "url-deprecated",
753+
Flag: "url-deprecated",
754+
Env: "URL_DEPRECATED",
755+
Value: serpent.StringOf(&got),
756+
},
745757
{
746758
Name: "config",
747759
Flag: "config",
@@ -790,6 +802,17 @@ func TestCommand_DefaultsOverride(t *testing.T) {
790802
inv.Args = []string{"--config", fi.Name(), "--url", "good.com"}
791803
})
792804

805+
test("EnvOverYAML", "good.com", func(t *testing.T, inv *serpent.Invocation) {
806+
fi, err := os.CreateTemp(t.TempDir(), "config.yaml")
807+
require.NoError(t, err)
808+
defer fi.Close()
809+
810+
_, err = fi.WriteString("url: bad.com")
811+
require.NoError(t, err)
812+
813+
inv.Environ.Set("URL", "good.com")
814+
})
815+
793816
test("YAMLOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) {
794817
fi, err := os.CreateTemp(t.TempDir(), "config.yaml")
795818
require.NoError(t, err)
@@ -800,4 +823,83 @@ func TestCommand_DefaultsOverride(t *testing.T) {
800823

801824
inv.Args = []string{"--config", fi.Name()}
802825
})
826+
827+
test("AltFlagOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) {
828+
inv.Args = []string{"--url-deprecated", "good.com"}
829+
})
830+
}
831+
832+
func TestCommand_OptionsWithSharedValue(t *testing.T) {
833+
t.Parallel()
834+
835+
var got string
836+
makeCmd := func(def, altDef string) *serpent.Command {
837+
got = ""
838+
return &serpent.Command{
839+
Options: serpent.OptionSet{
840+
{
841+
Name: "url",
842+
Flag: "url",
843+
Env: "URL",
844+
Default: def,
845+
Value: serpent.StringOf(&got),
846+
},
847+
{
848+
Name: "alt-url",
849+
Flag: "alt-url",
850+
Env: "ALT_URL",
851+
Default: altDef,
852+
Value: serpent.StringOf(&got),
853+
},
854+
},
855+
Handler: (func(i *serpent.Invocation) error {
856+
return nil
857+
}),
858+
}
859+
}
860+
861+
// Check proper value propagation.
862+
err := makeCmd("def.com", "def.com").Invoke().Run()
863+
require.NoError(t, err, "default values are same")
864+
require.Equal(t, "def.com", got)
865+
866+
err = makeCmd("def.com", "").Invoke().Run()
867+
require.NoError(t, err, "other default value is empty")
868+
require.Equal(t, "def.com", got)
869+
870+
err = makeCmd("def.com", "").Invoke("--url", "sup").Run()
871+
require.NoError(t, err)
872+
require.Equal(t, "sup", got)
873+
874+
err = makeCmd("def.com", "").Invoke("--alt-url", "hup").Run()
875+
require.NoError(t, err)
876+
require.Equal(t, "hup", got)
877+
878+
// Both flags are given, last wins.
879+
err = makeCmd("def.com", "").Invoke("--url", "sup", "--alt-url", "hup").Run()
880+
require.NoError(t, err)
881+
require.Equal(t, "hup", got)
882+
883+
// Both flags are given, last wins #2.
884+
err = makeCmd("", "def.com").Invoke("--alt-url", "hup", "--url", "sup").Run()
885+
require.NoError(t, err)
886+
require.Equal(t, "sup", got)
887+
888+
// Both flags are given, option type priority wins.
889+
inv := makeCmd("def.com", "").Invoke("--alt-url", "hup")
890+
inv.Environ.Set("URL", "sup")
891+
err = inv.Run()
892+
require.NoError(t, err)
893+
require.Equal(t, "hup", got)
894+
895+
// Both flags are given, option type priority wins #2.
896+
inv = makeCmd("", "def.com").Invoke("--url", "sup")
897+
inv.Environ.Set("ALT_URL", "hup")
898+
err = inv.Run()
899+
require.NoError(t, err)
900+
require.Equal(t, "sup", got)
901+
902+
// Catch invalid configuration.
903+
err = makeCmd("def.com", "alt-def.com").Invoke().Run()
904+
require.Error(t, err, "default values are different")
803905
}

option.go

+74-13
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/json"
66
"os"
7+
"slices"
78
"strings"
89

910
"github.com/hashicorp/go-multierror"
@@ -21,6 +22,14 @@ const (
2122
ValueSourceDefault ValueSource = "default"
2223
)
2324

25+
var valueSourcePriority = []ValueSource{
26+
ValueSourceFlag,
27+
ValueSourceEnv,
28+
ValueSourceYAML,
29+
ValueSourceDefault,
30+
ValueSourceNone,
31+
}
32+
2433
// Option is a configuration option for a CLI application.
2534
type Option struct {
2635
Name string `json:"name,omitempty"`
@@ -305,16 +314,12 @@ func (optSet *OptionSet) SetDefaults() error {
305314

306315
var merr *multierror.Error
307316

308-
for i, opt := range *optSet {
309-
// Skip values that may have already been set by the user.
310-
if opt.ValueSource != ValueSourceNone {
311-
continue
312-
}
313-
314-
if opt.Default == "" {
315-
continue
316-
}
317-
317+
// It's common to have multiple options with the same value to
318+
// handle deprecation. We group the options by value so that we
319+
// don't let other options overwrite user input.
320+
groupByValue := make(map[pflag.Value][]*Option)
321+
for i := range *optSet {
322+
opt := &(*optSet)[i]
318323
if opt.Value == nil {
319324
merr = multierror.Append(
320325
merr,
@@ -325,13 +330,69 @@ func (optSet *OptionSet) SetDefaults() error {
325330
)
326331
continue
327332
}
328-
(*optSet)[i].ValueSource = ValueSourceDefault
329-
if err := opt.Value.Set(opt.Default); err != nil {
333+
groupByValue[opt.Value] = append(groupByValue[opt.Value], opt)
334+
}
335+
336+
// Sorts by value source, then a default value being set.
337+
sortOptionByValueSourcePriorityOrDefault := func(a, b *Option) int {
338+
if a.ValueSource != b.ValueSource {
339+
return slices.Index(valueSourcePriority, a.ValueSource) - slices.Index(valueSourcePriority, b.ValueSource)
340+
}
341+
if a.Default != b.Default {
342+
if a.Default == "" {
343+
return 1
344+
}
345+
if b.Default == "" {
346+
return -1
347+
}
348+
}
349+
return 0
350+
}
351+
for _, opts := range groupByValue {
352+
// Sort the options by priority and whether or not a default is
353+
// set. This won't affect the value but represents correctness
354+
// from whence the value originated.
355+
slices.SortFunc(opts, sortOptionByValueSourcePriorityOrDefault)
356+
357+
// If the first option has a value source, then we don't need to
358+
// set the default, but mark the source for all options.
359+
if opts[0].ValueSource != ValueSourceNone {
360+
for _, opt := range opts[1:] {
361+
opt.ValueSource = opts[0].ValueSource
362+
}
363+
continue
364+
}
365+
366+
var optWithDefault *Option
367+
for _, opt := range opts {
368+
if opt.Default == "" {
369+
continue
370+
}
371+
if optWithDefault != nil && optWithDefault.Default != opt.Default {
372+
merr = multierror.Append(
373+
merr,
374+
xerrors.Errorf(
375+
"parse %q: multiple defaults set for the same value: %q and %q (%q)",
376+
opt.Name, opt.Default, optWithDefault.Default, optWithDefault.Name,
377+
),
378+
)
379+
continue
380+
}
381+
optWithDefault = opt
382+
}
383+
if optWithDefault == nil {
384+
continue
385+
}
386+
if err := optWithDefault.Value.Set(optWithDefault.Default); err != nil {
330387
merr = multierror.Append(
331-
merr, xerrors.Errorf("parse %q: %w", opt.Name, err),
388+
merr, xerrors.Errorf("parse %q: %w", optWithDefault.Name, err),
332389
)
333390
}
391+
for _, opt := range opts {
392+
opt.ValueSource = ValueSourceDefault
393+
}
334394
}
395+
335396
return merr.ErrorOrNil()
336397
}
337398

yaml.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,10 @@ func (o *Option) setFromYAMLNode(n *yaml.Node) error {
213213
// We treat empty values as nil for consistency with other option
214214
// mechanisms.
215215
if len(n.Content) == 0 {
216-
o.Value = nil
217-
return nil
216+
if o.Value == nil {
217+
return nil
218+
}
219+
return o.Value.Set("")
218220
}
219221
return n.Decode(o.Value)
220222
case yaml.MappingNode:

0 commit comments

Comments
 (0)