Skip to content

Commit afbc94c

Browse files
committed
Add checks for compute environment flags and url checks
1 parent c6d0ad4 commit afbc94c

File tree

3 files changed

+56
-43
lines changed

3 files changed

+56
-43
lines changed

cmd/gce-pd-csi-driver/main.go

+22-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ package main
1717

1818
import (
1919
"context"
20+
"errors"
2021
"flag"
22+
"fmt"
2123
"math/rand"
2224
"os"
2325
"runtime"
@@ -66,11 +68,11 @@ var (
6668

6769
maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations")
6870
formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount")
69-
computeEnvironment = flag.String("compute-environment", "prod", "Sets the compute environment")
71+
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
7072

71-
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
72-
73-
version string
73+
computeEnvironment gce.Environment = "production"
74+
version string
75+
allowedComputeEnvironment = []string{"staging", "production"}
7476
)
7577

7678
const (
@@ -83,6 +85,7 @@ func init() {
8385
// Use V(4) for general debug information logging
8486
// Use V(5) for GCE Cloud Provider Call informational logging
8587
// Use V(6) for extra repeated/polling information
88+
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
8689
klog.InitFlags(flag.CommandLine)
8790
flag.Set("logtostderr", "true")
8891
}
@@ -138,7 +141,7 @@ func handle() {
138141
// Initialize requirements for the controller service
139142
var controllerServer *driver.GCEControllerServer
140143
if *runControllerService {
141-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, *computeEnvironment)
144+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, computeEnvironment)
142145
if err != nil {
143146
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
144147
}
@@ -187,3 +190,17 @@ func handle() {
187190

188191
gceDriver.Run(*endpoint, *grpcLogCharCap)
189192
}
193+
194+
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []string, usage string) {
195+
flag.Func(name, usage, func(flagValue string) error {
196+
for _, allowedValue := range allowedComputeEnvironment {
197+
if flagValue == allowedValue {
198+
*target = gce.Environment(flagValue)
199+
return nil
200+
}
201+
}
202+
errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment)
203+
return errors.New(errMsg)
204+
})
205+
206+
}

pkg/gce-cloud-provider/compute/gce.go

+29-32
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ import (
3838
"k8s.io/klog/v2"
3939
)
4040

41+
type Environment string
42+
type Version string
43+
4144
const (
4245
TokenURL = "https://accounts.google.com/o/oauth2/token"
4346
diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}"
@@ -47,28 +50,13 @@ const (
4750

4851
regionURITemplate = "projects/%s/regions/%s"
4952

50-
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
51-
versionV1 = "v1"
52-
versionBeta = "beta"
53-
versionAlpha = "alpha"
54-
googleEnv = "googleapis"
53+
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
54+
versionV1 Version = "v1"
55+
versionBeta Version = "beta"
56+
versionAlpha Version = "alpha"
57+
environmentStaging Environment = "staging"
5558
)
5659

57-
var computeVersionMap = map[string]map[string]map[string]string{
58-
googleEnv: {
59-
"prod": {
60-
versionV1: "compute/v1/",
61-
versionBeta: "compute/beta/",
62-
versionAlpha: "compute/alpha/",
63-
},
64-
"staging": {
65-
versionV1: "compute/staging_v1/",
66-
versionBeta: "compute/staging_beta/",
67-
versionAlpha: "compute/staging_alpha/",
68-
},
69-
},
70-
}
71-
7260
type CloudProvider struct {
7361
service *compute.Service
7462
betaService *computebeta.Service
@@ -91,7 +79,7 @@ type ConfigGlobal struct {
9179
Zone string `gcfg:"zone"`
9280
}
9381

94-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment string) (*CloudProvider, error) {
82+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment Environment) (*CloudProvider, error) {
9583
configFile, err := readConfig(configPath)
9684
if err != nil {
9785
return nil, err
@@ -185,7 +173,7 @@ func readConfig(configPath string) (*ConfigFile, error) {
185173
return cfg, nil
186174
}
187175

188-
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*computealpha.Service, error) {
176+
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computealpha.Service, error) {
189177
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
190178
if err != nil {
191179
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -198,7 +186,7 @@ func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSou
198186
return service, nil
199187
}
200188

201-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*computebeta.Service, error) {
189+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computebeta.Service, error) {
202190
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
203191
if err != nil {
204192
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -211,7 +199,7 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
211199
return service, nil
212200
}
213201

214-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*compute.Service, error) {
202+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*compute.Service, error) {
215203
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
216204
if err != nil {
217205
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -224,20 +212,21 @@ func createCloudService(ctx context.Context, vendorVersion string, tokenSource o
224212
return service, nil
225213
}
226214

227-
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string, computeVersion string) ([]option.ClientOption, error) {
215+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
228216
client, err := newOauthClient(ctx, tokenSource)
229217
if err != nil {
230218
return nil, err
231219
}
232-
computeEnvironmentSuffix, ok := computeVersionMap[googleEnv][computeEnvironment][computeVersion]
233-
if !ok {
234-
return nil, errors.New("Unable to fetch compute endpoint")
235-
}
220+
computeEnvironmentSuffix := getPath(computeEnvironment, computeVersion)
236221
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
222+
237223
if computeEndpoint != "" {
238-
endpoint := fmt.Sprintf("%s%s", computeEndpoint, computeEnvironmentSuffix)
239-
klog.Infof("Got compute endpoint %s", endpoint)
240-
_, err := url.ParseRequestURI(endpoint)
224+
computeURL, err := url.ParseRequestURI(computeEndpoint)
225+
if err != nil {
226+
return nil, err
227+
}
228+
endpoint := computeURL.JoinPath(computeEnvironmentSuffix).String()
229+
_, err = url.ParseRequestURI(endpoint)
241230
if err != nil {
242231
klog.Fatalf("Error parsing compute endpoint %s", endpoint)
243232
}
@@ -246,6 +235,14 @@ func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, comp
246235
return computeOpts, nil
247236
}
248237

238+
func getPath(env Environment, version Version) string {
239+
prefix := ""
240+
if env == environmentStaging {
241+
prefix = fmt.Sprintf("%s_", env)
242+
}
243+
return fmt.Sprintf("compute/%s%s/", prefix, version)
244+
}
245+
249246
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {
250247
if err := wait.PollImmediate(5*time.Second, 30*time.Second, func() (bool, error) {
251248
if _, err := tokenSource.Token(); err != nil {

test/e2e/tests/single_zone_e2e_test.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -1296,13 +1296,12 @@ var _ = Describe("GCE PD CSI Driver", func() {
12961296
klog.Infof("Creating new driver and client for node %s\n", i.GetName())
12971297

12981298
// Create new driver and client w/ invalid endpoint
1299-
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string")
1300-
if err != nil {
1301-
klog.Fatalf("Failed to set up Test Context for instance %v: %w", i.GetName(), err)
1299+
computeEndpoint := "invalid-string"
1300+
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, computeEndpoint)
1301+
if tcInvalid != nil {
1302+
klog.Fatalf("Driver setup with incorrect compute %v: %v", i.GetName(), computeEndpoint)
13021303
}
1303-
1304-
_, err = tcInvalid.Client.ListVolumes()
1305-
Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url")
1304+
Expect(err.Error()).To(ContainSubstring("failed start driver"), "expected error when passed invalid compute url")
13061305

13071306
// Create new driver and client w/ valid, passed-in endpoint
13081307
tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com")

0 commit comments

Comments
 (0)