diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index 836664758..3c7808498 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -28,6 +28,7 @@ import ( "time" "k8s.io/klog/v2" + "k8s.io/utils/strings/slices" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils" @@ -98,7 +99,7 @@ func init() { // Use V(4) for general debug information logging // Use V(5) for GCE Cloud Provider Call informational logging // Use V(6) for extra repeated/polling information - enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment") + stringEnumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment") urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint") klog.InitFlags(flag.CommandLine) flag.Set("logtostderr", "true") @@ -175,23 +176,23 @@ func handle() { identityServer := driver.NewIdentityServer(gceDriver) // Initialize requisite zones - fallbackRequisiteZones := strings.Split(*fallbackRequisiteZonesFlag, ",") + fallbackRequisiteZones := parseCSVFlag(*fallbackRequisiteZonesFlag) // Initialize multi-zone disk types - multiZoneVolumeHandleDiskTypes := strings.Split(*multiZoneVolumeHandleDiskTypesFlag, ",") + multiZoneVolumeHandleDiskTypes := parseCSVFlag(*multiZoneVolumeHandleDiskTypesFlag) multiZoneVolumeHandleConfig := driver.MultiZoneVolumeHandleConfig{ Enable: *multiZoneVolumeHandleEnableFlag, DiskTypes: multiZoneVolumeHandleDiskTypes, } // Initialize waitForAttach config - useInstanceAPIOnWaitForAttachDiskTypes := strings.Split(*useInstanceAPIOnWaitForAttachDiskTypesFlag, ",") + useInstanceAPIOnWaitForAttachDiskTypes := parseCSVFlag(*useInstanceAPIOnWaitForAttachDiskTypesFlag) waitForAttachConfig := gce.WaitForAttachConfig{ UseInstancesAPIForDiskTypes: useInstanceAPIOnWaitForAttachDiskTypes, } // Initialize listVolumes config - instancesListFilters := strings.Split(*instancesListFiltersFlag, ",") + instancesListFilters := parseCSVFlag(*instancesListFiltersFlag) listInstancesConfig := gce.ListInstancesConfig{ Filters: instancesListFilters, } @@ -252,18 +253,48 @@ func handle() { gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing) } -func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) { +func notEmpty(v string) bool { + return v != "" +} + +func parseCSVFlag(list string) []string { + return slices.Filter(nil, strings.Split(list, ","), notEmpty) +} + +type enumConverter[T any] interface { + convert(v string) (T, error) + eq(a, b T) bool +} + +type stringConverter[T ~string] struct{} + +func (s stringConverter[T]) convert(v string) (T, error) { + return T(v), nil +} + +func (s stringConverter[T]) eq(a, b T) bool { + return a == b +} + +func stringEnumFlag[T ~string](target *T, name string, allowed []T, usage string) { + enumFlag(target, name, stringConverter[T]{}, allowed, usage) +} + +func enumFlag[T any](target *T, name string, converter enumConverter[T], allowed []T, usage string) { flag.Func(name, usage, func(flagValue string) error { - for _, allowedValue := range allowedComputeEnvironment { - if gce.Environment(flagValue) == allowedValue { - *target = gce.Environment(flagValue) + tValue, err := converter.convert(flagValue) + if err != nil { + return err + } + for _, allowedValue := range allowed { + if converter.eq(allowedValue, tValue) { + *target = tValue return nil } } errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment) return errors.New(errMsg) }) - } func urlFlag(target **url.URL, name string, usage string) {