Skip to content

Commit fd460d3

Browse files
committed
Add checks for compute environment flags and url checks
1 parent 6aae7d2 commit fd460d3

File tree

3 files changed

+69
-42
lines changed

3 files changed

+69
-42
lines changed

cmd/gce-pd-csi-driver/main.go

+25-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ package main
1717

1818
import (
1919
"context"
20+
"errors"
2021
"flag"
22+
"fmt"
2123
"math/rand"
2224
"os"
2325
"runtime"
@@ -67,13 +69,12 @@ var (
6769

6870
maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations")
6971
formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount")
70-
computeEnvironment = flag.String("compute-environment", "prod", "Sets the compute environment")
72+
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
7173

72-
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
73-
74-
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
75-
76-
version string
74+
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
75+
computeEnvironment gce.Environment = "production"
76+
version string
77+
allowedComputeEnvironment = []string{"staging", "production"}
7778
)
7879

7980
const (
@@ -154,10 +155,13 @@ func handle() {
154155
// Initilaize requisite zones
155156
fallbackRequisiteZones := strings.Split(*fallbackRequisiteZonesFlag, ",")
156157

158+
enumFlag(&computeEnvironment, "computeEnvironment", allowedComputeEnvironment, "Operating compute environment")
159+
flag.Parse()
160+
157161
// Initialize requirements for the controller service
158162
var controllerServer *driver.GCEControllerServer
159163
if *runControllerService {
160-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, *computeEnvironment)
164+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, computeEnvironment)
161165
if err != nil {
162166
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
163167
}
@@ -206,3 +210,17 @@ func handle() {
206210

207211
gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing)
208212
}
213+
214+
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []string, usage string) {
215+
flag.Func(name, usage, func(flagValue string) error {
216+
for _, allowedValue := range allowedComputeEnvironment {
217+
if flagValue == allowedValue {
218+
*target = gce.Environment(flagValue)
219+
return nil
220+
}
221+
}
222+
errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment)
223+
return errors.New(errMsg)
224+
})
225+
226+
}

pkg/gce-cloud-provider/compute/gce.go

+39-29
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ import (
3838
"k8s.io/klog/v2"
3939
)
4040

41+
type Environment string
42+
type Version string
43+
4144
const (
4245
TokenURL = "https://accounts.google.com/o/oauth2/token"
4346
diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}"
@@ -47,25 +50,23 @@ const (
4750

4851
regionURITemplate = "projects/%s/regions/%s"
4952

50-
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
51-
versionV1 = "v1"
52-
versionBeta = "beta"
53-
versionAlpha = "alpha"
54-
googleEnv = "googleapis"
53+
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
54+
versionV1 Version = "v1"
55+
versionBeta Version = "beta"
56+
versionAlpha Version = "alpha"
57+
environmentStaging Environment = "staging"
5558
)
5659

57-
var computeVersionMap = map[string]map[string]map[string]string{
58-
googleEnv: {
59-
"prod": {
60-
versionV1: "compute/v1/",
61-
versionBeta: "compute/beta/",
62-
versionAlpha: "compute/alpha/",
63-
},
64-
"staging": {
65-
versionV1: "compute/staging_v1/",
66-
versionBeta: "compute/staging_beta/",
67-
versionAlpha: "compute/staging_alpha/",
68-
},
60+
var computeVersionMap = map[Environment]map[Version]string{
61+
"prod": {
62+
versionV1: "compute/v1/",
63+
versionBeta: "compute/beta/",
64+
versionAlpha: "compute/alpha/",
65+
},
66+
"staging": {
67+
versionV1: "compute/staging_v1/",
68+
versionBeta: "compute/staging_beta/",
69+
versionAlpha: "compute/staging_alpha/",
6970
},
7071
}
7172

@@ -92,7 +93,7 @@ type ConfigGlobal struct {
9293
Zone string `gcfg:"zone"`
9394
}
9495

95-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment string) (*CloudProvider, error) {
96+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment Environment) (*CloudProvider, error) {
9697
configFile, err := readConfig(configPath)
9798
if err != nil {
9899
return nil, err
@@ -187,7 +188,7 @@ func readConfig(configPath string) (*ConfigFile, error) {
187188
return cfg, nil
188189
}
189190

190-
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*computealpha.Service, error) {
191+
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computealpha.Service, error) {
191192
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
192193
if err != nil {
193194
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -200,7 +201,7 @@ func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSou
200201
return service, nil
201202
}
202203

203-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*computebeta.Service, error) {
204+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computebeta.Service, error) {
204205
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
205206
if err != nil {
206207
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -213,7 +214,7 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
213214
return service, nil
214215
}
215216

216-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*compute.Service, error) {
217+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*compute.Service, error) {
217218
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
218219
if err != nil {
219220
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -226,20 +227,21 @@ func createCloudService(ctx context.Context, vendorVersion string, tokenSource o
226227
return service, nil
227228
}
228229

229-
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string, computeVersion string) ([]option.ClientOption, error) {
230+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
230231
client, err := newOauthClient(ctx, tokenSource)
231232
if err != nil {
232233
return nil, err
233234
}
234-
computeEnvironmentSuffix, ok := computeVersionMap[googleEnv][computeEnvironment][computeVersion]
235-
if !ok {
236-
return nil, errors.New("Unable to fetch compute endpoint")
237-
}
235+
computeEnvironmentSuffix := getPath(computeEnvironment, computeVersion)
238236
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
237+
239238
if computeEndpoint != "" {
240-
endpoint := fmt.Sprintf("%s%s", computeEndpoint, computeEnvironmentSuffix)
241-
klog.Infof("Got compute endpoint %s", endpoint)
242-
_, err := url.ParseRequestURI(endpoint)
239+
computeURL, err := url.ParseRequestURI(computeEndpoint)
240+
if err != nil {
241+
return nil, err
242+
}
243+
endpoint := computeURL.JoinPath(computeEnvironmentSuffix).String()
244+
_, err = url.ParseRequestURI(endpoint)
243245
if err != nil {
244246
klog.Fatalf("Error parsing compute endpoint %s", endpoint)
245247
}
@@ -248,6 +250,14 @@ func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, comp
248250
return computeOpts, nil
249251
}
250252

253+
func getPath(env Environment, version Version) string {
254+
prefix := ""
255+
if env == environmentStaging {
256+
prefix = fmt.Sprintf("%s_", env)
257+
}
258+
return fmt.Sprintf("compute/%s%s/", prefix, version)
259+
}
260+
251261
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {
252262
if err := wait.PollImmediate(5*time.Second, 30*time.Second, func() (bool, error) {
253263
if _, err := tokenSource.Token(); err != nil {

test/e2e/tests/single_zone_e2e_test.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -1296,13 +1296,12 @@ var _ = Describe("GCE PD CSI Driver", func() {
12961296
klog.Infof("Creating new driver and client for node %s\n", i.GetName())
12971297

12981298
// Create new driver and client w/ invalid endpoint
1299-
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string")
1300-
if err != nil {
1301-
klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err)
1299+
computeEndpoint := "invalid-string"
1300+
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, computeEndpoint)
1301+
if tcInvalid != nil {
1302+
klog.Fatalf("Driver setup with incorrect compute %v: %v", i.GetName(), computeEndpoint)
13021303
}
1303-
1304-
_, err = tcInvalid.Client.ListVolumes()
1305-
Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url")
1304+
Expect(err.Error()).To(ContainSubstring("failed start driver"), "expected error when passed invalid compute url")
13061305

13071306
// Create new driver and client w/ valid, passed-in endpoint
13081307
tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com")

0 commit comments

Comments
 (0)