Skip to content

Commit 34d0297

Browse files
authored
Merge pull request #1730 from pwschuurman/update-enum-flag-strings
Improve enum and string list flag parse handling
2 parents 4004cd7 + c6f1a4a commit 34d0297

File tree

1 file changed

+41
-10
lines changed

1 file changed

+41
-10
lines changed

cmd/gce-pd-csi-driver/main.go

+41-10
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"time"
2929

3030
"k8s.io/klog/v2"
31+
"k8s.io/utils/strings/slices"
3132

3233
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common"
3334
"sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils"
@@ -98,7 +99,7 @@ func init() {
9899
// Use V(4) for general debug information logging
99100
// Use V(5) for GCE Cloud Provider Call informational logging
100101
// Use V(6) for extra repeated/polling information
101-
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
102+
stringEnumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
102103
urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint")
103104
klog.InitFlags(flag.CommandLine)
104105
flag.Set("logtostderr", "true")
@@ -175,23 +176,23 @@ func handle() {
175176
identityServer := driver.NewIdentityServer(gceDriver)
176177

177178
// Initialize requisite zones
178-
fallbackRequisiteZones := strings.Split(*fallbackRequisiteZonesFlag, ",")
179+
fallbackRequisiteZones := parseCSVFlag(*fallbackRequisiteZonesFlag)
179180

180181
// Initialize multi-zone disk types
181-
multiZoneVolumeHandleDiskTypes := strings.Split(*multiZoneVolumeHandleDiskTypesFlag, ",")
182+
multiZoneVolumeHandleDiskTypes := parseCSVFlag(*multiZoneVolumeHandleDiskTypesFlag)
182183
multiZoneVolumeHandleConfig := driver.MultiZoneVolumeHandleConfig{
183184
Enable: *multiZoneVolumeHandleEnableFlag,
184185
DiskTypes: multiZoneVolumeHandleDiskTypes,
185186
}
186187

187188
// Initialize waitForAttach config
188-
useInstanceAPIOnWaitForAttachDiskTypes := strings.Split(*useInstanceAPIOnWaitForAttachDiskTypesFlag, ",")
189+
useInstanceAPIOnWaitForAttachDiskTypes := parseCSVFlag(*useInstanceAPIOnWaitForAttachDiskTypesFlag)
189190
waitForAttachConfig := gce.WaitForAttachConfig{
190191
UseInstancesAPIForDiskTypes: useInstanceAPIOnWaitForAttachDiskTypes,
191192
}
192193

193194
// Initialize listVolumes config
194-
instancesListFilters := strings.Split(*instancesListFiltersFlag, ",")
195+
instancesListFilters := parseCSVFlag(*instancesListFiltersFlag)
195196
listInstancesConfig := gce.ListInstancesConfig{
196197
Filters: instancesListFilters,
197198
}
@@ -252,18 +253,48 @@ func handle() {
252253
gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing)
253254
}
254255

255-
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) {
256+
func notEmpty(v string) bool {
257+
return v != ""
258+
}
259+
260+
func parseCSVFlag(list string) []string {
261+
return slices.Filter(nil, strings.Split(list, ","), notEmpty)
262+
}
263+
264+
type enumConverter[T any] interface {
265+
convert(v string) (T, error)
266+
eq(a, b T) bool
267+
}
268+
269+
type stringConverter[T ~string] struct{}
270+
271+
func (s stringConverter[T]) convert(v string) (T, error) {
272+
return T(v), nil
273+
}
274+
275+
func (s stringConverter[T]) eq(a, b T) bool {
276+
return a == b
277+
}
278+
279+
func stringEnumFlag[T ~string](target *T, name string, allowed []T, usage string) {
280+
enumFlag(target, name, stringConverter[T]{}, allowed, usage)
281+
}
282+
283+
func enumFlag[T any](target *T, name string, converter enumConverter[T], allowed []T, usage string) {
256284
flag.Func(name, usage, func(flagValue string) error {
257-
for _, allowedValue := range allowedComputeEnvironment {
258-
if gce.Environment(flagValue) == allowedValue {
259-
*target = gce.Environment(flagValue)
285+
tValue, err := converter.convert(flagValue)
286+
if err != nil {
287+
return err
288+
}
289+
for _, allowedValue := range allowed {
290+
if converter.eq(allowedValue, tValue) {
291+
*target = tValue
260292
return nil
261293
}
262294
}
263295
errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment)
264296
return errors.New(errMsg)
265297
})
266-
267298
}
268299

269300
func urlFlag(target **url.URL, name string, usage string) {

0 commit comments

Comments
 (0)