Skip to content

Commit ce810c4

Browse files
committed
Add checks for compute environment flags and url checks
1 parent 6aae7d2 commit ce810c4

File tree

3 files changed

+57
-45
lines changed

3 files changed

+57
-45
lines changed

cmd/gce-pd-csi-driver/main.go

+23-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ package main
1717

1818
import (
1919
"context"
20+
"errors"
2021
"flag"
22+
"fmt"
2123
"math/rand"
2224
"os"
2325
"runtime"
@@ -67,13 +69,12 @@ var (
6769

6870
maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations")
6971
formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount")
70-
computeEnvironment = flag.String("compute-environment", "prod", "Sets the compute environment")
72+
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
7173

72-
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
73-
74-
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
75-
76-
version string
74+
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
75+
computeEnvironment gce.Environment = "production"
76+
version string
77+
allowedComputeEnvironment = []string{"staging", "production"}
7778
)
7879

7980
const (
@@ -86,6 +87,7 @@ func init() {
8687
// Use V(4) for general debug information logging
8788
// Use V(5) for GCE Cloud Provider Call informational logging
8889
// Use V(6) for extra repeated/polling information
90+
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
8991
klog.InitFlags(flag.CommandLine)
9092
flag.Set("logtostderr", "true")
9193
}
@@ -157,7 +159,7 @@ func handle() {
157159
// Initialize requirements for the controller service
158160
var controllerServer *driver.GCEControllerServer
159161
if *runControllerService {
160-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, *computeEnvironment)
162+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, computeEnvironment)
161163
if err != nil {
162164
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
163165
}
@@ -206,3 +208,17 @@ func handle() {
206208

207209
gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing)
208210
}
211+
212+
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []string, usage string) {
213+
flag.Func(name, usage, func(flagValue string) error {
214+
for _, allowedValue := range allowedComputeEnvironment {
215+
if flagValue == allowedValue {
216+
*target = gce.Environment(flagValue)
217+
return nil
218+
}
219+
}
220+
errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment)
221+
return errors.New(errMsg)
222+
})
223+
224+
}

pkg/gce-cloud-provider/compute/gce.go

+29-32
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ import (
3838
"k8s.io/klog/v2"
3939
)
4040

41+
type Environment string
42+
type Version string
43+
4144
const (
4245
TokenURL = "https://accounts.google.com/o/oauth2/token"
4346
diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}"
@@ -47,28 +50,13 @@ const (
4750

4851
regionURITemplate = "projects/%s/regions/%s"
4952

50-
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
51-
versionV1 = "v1"
52-
versionBeta = "beta"
53-
versionAlpha = "alpha"
54-
googleEnv = "googleapis"
53+
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
54+
versionV1 Version = "v1"
55+
versionBeta Version = "beta"
56+
versionAlpha Version = "alpha"
57+
environmentStaging Environment = "staging"
5558
)
5659

57-
var computeVersionMap = map[string]map[string]map[string]string{
58-
googleEnv: {
59-
"prod": {
60-
versionV1: "compute/v1/",
61-
versionBeta: "compute/beta/",
62-
versionAlpha: "compute/alpha/",
63-
},
64-
"staging": {
65-
versionV1: "compute/staging_v1/",
66-
versionBeta: "compute/staging_beta/",
67-
versionAlpha: "compute/staging_alpha/",
68-
},
69-
},
70-
}
71-
7260
type CloudProvider struct {
7361
service *compute.Service
7462
betaService *computebeta.Service
@@ -92,7 +80,7 @@ type ConfigGlobal struct {
9280
Zone string `gcfg:"zone"`
9381
}
9482

95-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment string) (*CloudProvider, error) {
83+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment Environment) (*CloudProvider, error) {
9684
configFile, err := readConfig(configPath)
9785
if err != nil {
9886
return nil, err
@@ -187,7 +175,7 @@ func readConfig(configPath string) (*ConfigFile, error) {
187175
return cfg, nil
188176
}
189177

190-
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*computealpha.Service, error) {
178+
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computealpha.Service, error) {
191179
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
192180
if err != nil {
193181
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -200,7 +188,7 @@ func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSou
200188
return service, nil
201189
}
202190

203-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*computebeta.Service, error) {
191+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computebeta.Service, error) {
204192
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
205193
if err != nil {
206194
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -213,7 +201,7 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
213201
return service, nil
214202
}
215203

216-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string) (*compute.Service, error) {
204+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*compute.Service, error) {
217205
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
218206
if err != nil {
219207
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -226,20 +214,21 @@ func createCloudService(ctx context.Context, vendorVersion string, tokenSource o
226214
return service, nil
227215
}
228216

229-
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment string, computeVersion string) ([]option.ClientOption, error) {
217+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
230218
client, err := newOauthClient(ctx, tokenSource)
231219
if err != nil {
232220
return nil, err
233221
}
234-
computeEnvironmentSuffix, ok := computeVersionMap[googleEnv][computeEnvironment][computeVersion]
235-
if !ok {
236-
return nil, errors.New("Unable to fetch compute endpoint")
237-
}
222+
computeEnvironmentSuffix := getPath(computeEnvironment, computeVersion)
238223
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
224+
239225
if computeEndpoint != "" {
240-
endpoint := fmt.Sprintf("%s%s", computeEndpoint, computeEnvironmentSuffix)
241-
klog.Infof("Got compute endpoint %s", endpoint)
242-
_, err := url.ParseRequestURI(endpoint)
226+
computeURL, err := url.ParseRequestURI(computeEndpoint)
227+
if err != nil {
228+
return nil, err
229+
}
230+
endpoint := computeURL.JoinPath(computeEnvironmentSuffix).String()
231+
_, err = url.ParseRequestURI(endpoint)
243232
if err != nil {
244233
klog.Fatalf("Error parsing compute endpoint %s", endpoint)
245234
}
@@ -248,6 +237,14 @@ func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, comp
248237
return computeOpts, nil
249238
}
250239

240+
func getPath(env Environment, version Version) string {
241+
prefix := ""
242+
if env == environmentStaging {
243+
prefix = fmt.Sprintf("%s_", env)
244+
}
245+
return fmt.Sprintf("compute/%s%s/", prefix, version)
246+
}
247+
251248
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {
252249
if err := wait.PollImmediate(5*time.Second, 30*time.Second, func() (bool, error) {
253250
if _, err := tokenSource.Token(); err != nil {

test/e2e/tests/single_zone_e2e_test.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -1296,13 +1296,12 @@ var _ = Describe("GCE PD CSI Driver", func() {
12961296
klog.Infof("Creating new driver and client for node %s\n", i.GetName())
12971297

12981298
// Create new driver and client w/ invalid endpoint
1299-
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string")
1300-
if err != nil {
1301-
klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err)
1299+
computeEndpoint := "invalid-string"
1300+
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, computeEndpoint)
1301+
if tcInvalid != nil {
1302+
klog.Fatalf("Driver setup with incorrect compute %v: %v", i.GetName(), computeEndpoint)
13021303
}
1303-
1304-
_, err = tcInvalid.Client.ListVolumes()
1305-
Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url")
1304+
Expect(err.Error()).To(ContainSubstring("failed start driver"), "expected error when passed invalid compute url")
13061305

13071306
// Create new driver and client w/ valid, passed-in endpoint
13081307
tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com")

0 commit comments

Comments
 (0)