Skip to content

Commit 4eab6b6

Browse files
committed
Add url check during initialization
1 parent b5c8138 commit 4eab6b6

File tree

3 files changed

+59
-44
lines changed

3 files changed

+59
-44
lines changed

cmd/gce-pd-csi-driver/main.go

+19-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"flag"
2222
"fmt"
2323
"math/rand"
24+
"net/url"
2425
"os"
2526
"runtime"
2627
"strings"
@@ -40,7 +41,6 @@ import (
4041
var (
4142
cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config")
4243
endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint")
43-
computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.")
4444
runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)")
4545
runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)")
4646
httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.")
@@ -72,9 +72,10 @@ var (
7272
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
7373

7474
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
75-
computeEnvironment gce.Environment = "production"
75+
computeEnvironment gce.Environment = gce.EnvironmentProduction
76+
computeEndpoint url.URL
7677
version string
77-
allowedComputeEnvironment = []string{"staging", "production"}
78+
allowedComputeEnvironment = []gce.Environment{gce.EnvironmentStaging, gce.EnvironmentProduction}
7879
)
7980

8081
const (
@@ -88,6 +89,7 @@ func init() {
8889
// Use V(5) for GCE Cloud Provider Call informational logging
8990
// Use V(6) for extra repeated/polling information
9091
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
92+
urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint")
9193
klog.InitFlags(flag.CommandLine)
9294
flag.Set("logtostderr", "true")
9395
}
@@ -159,7 +161,7 @@ func handle() {
159161
// Initialize requirements for the controller service
160162
var controllerServer *driver.GCEControllerServer
161163
if *runControllerService {
162-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, computeEnvironment)
164+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment)
163165
if err != nil {
164166
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
165167
}
@@ -209,10 +211,10 @@ func handle() {
209211
gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing)
210212
}
211213

212-
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []string, usage string) {
214+
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) {
213215
flag.Func(name, usage, func(flagValue string) error {
214216
for _, allowedValue := range allowedComputeEnvironment {
215-
if flagValue == allowedValue {
217+
if gce.Environment(flagValue) == allowedValue {
216218
*target = gce.Environment(flagValue)
217219
return nil
218220
}
@@ -222,3 +224,14 @@ func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []
222224
})
223225

224226
}
227+
228+
func urlFlag(target *url.URL, name string, usage string) {
229+
flag.Func(name, usage, func(flagValue string) error {
230+
computeURL, err := url.ParseRequestURI(flagValue)
231+
if err == nil {
232+
*target = *computeURL
233+
return nil
234+
}
235+
return err
236+
})
237+
}

pkg/gce-cloud-provider/compute/gce.go

+13-19
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ const (
5454
versionV1 Version = "v1"
5555
versionBeta Version = "beta"
5656
versionAlpha Version = "alpha"
57-
environmentStaging Environment = "staging"
57+
EnvironmentStaging Environment = "staging"
58+
EnvironmentProduction Environment = "production"
5859
)
5960

6061
type CloudProvider struct {
@@ -80,7 +81,7 @@ type ConfigGlobal struct {
8081
Zone string `gcfg:"zone"`
8182
}
8283

83-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment Environment) (*CloudProvider, error) {
84+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint url.URL, computeEnvironment Environment) (*CloudProvider, error) {
8485
configFile, err := readConfig(configPath)
8586
if err != nil {
8687
return nil, err
@@ -175,7 +176,7 @@ func readConfig(configPath string) (*ConfigFile, error) {
175176
return cfg, nil
176177
}
177178

178-
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computealpha.Service, error) {
179+
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment) (*computealpha.Service, error) {
179180
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
180181
if err != nil {
181182
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -188,7 +189,7 @@ func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSou
188189
return service, nil
189190
}
190191

191-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computebeta.Service, error) {
192+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment) (*computebeta.Service, error) {
192193
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
193194
if err != nil {
194195
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -201,7 +202,7 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
201202
return service, nil
202203
}
203204

204-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*compute.Service, error) {
205+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment) (*compute.Service, error) {
205206
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
206207
if err != nil {
207208
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -214,32 +215,25 @@ func createCloudService(ctx context.Context, vendorVersion string, tokenSource o
214215
return service, nil
215216
}
216217

217-
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
218+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
218219
client, err := newOauthClient(ctx, tokenSource)
219220
if err != nil {
220221
return nil, err
221222
}
222-
computeEnvironmentSuffix := getPath(computeEnvironment, computeVersion)
223223
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
224224

225-
if computeEndpoint != "" {
226-
computeURL, err := url.ParseRequestURI(computeEndpoint)
227-
if err != nil {
228-
return nil, err
229-
}
230-
endpoint := computeURL.JoinPath(computeEnvironmentSuffix).String()
231-
_, err = url.ParseRequestURI(endpoint)
232-
if err != nil {
233-
klog.Fatalf("Error parsing compute endpoint %s", endpoint)
234-
}
225+
if computeEndpoint.String() != "" {
226+
computeEnvironmentSuffix := constructComputeEndpointPath(computeEnvironment, computeVersion)
227+
computeEndpoint.Path = computeEnvironmentSuffix
228+
endpoint := computeEndpoint.String()
235229
computeOpts = append(computeOpts, option.WithEndpoint(endpoint))
236230
}
237231
return computeOpts, nil
238232
}
239233

240-
func getPath(env Environment, version Version) string {
234+
func constructComputeEndpointPath(env Environment, version Version) string {
241235
prefix := ""
242-
if env == environmentStaging {
236+
if env == EnvironmentStaging {
243237
prefix = fmt.Sprintf("%s_", env)
244238
}
245239
return fmt.Sprintf("compute/%s%s/", prefix, version)

pkg/gce-cloud-provider/compute/gce_test.go

+27-19
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ import (
2222
"errors"
2323
"fmt"
2424
"net/http"
25+
"net/url"
2526
"testing"
2627
"time"
2728

2829
"golang.org/x/oauth2"
2930

31+
"google.golang.org/api/compute/v1"
3032
"google.golang.org/api/googleapi"
3133
)
3234

@@ -102,7 +104,7 @@ func TestIsGCEError(t *testing.T) {
102104
func TestGetComputeVersion(t *testing.T) {
103105
testCases := []struct {
104106
name string
105-
computeEndpoint string
107+
computeEndpoint url.URL
106108
computeEnvironment Environment
107109
computeVersion Version
108110
expectedEndpoint string
@@ -111,30 +113,23 @@ func TestGetComputeVersion(t *testing.T) {
111113

112114
{
113115
name: "check for production environment",
114-
computeEndpoint: "https://compute.googleapis.com",
115-
computeEnvironment: "production",
116-
computeVersion: "v1",
117-
expectedEndpoint: "https://compute.googleapis.com/compute/v1/",
116+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
117+
computeEnvironment: EnvironmentProduction,
118+
computeVersion: versionBeta,
119+
expectedEndpoint: "https://compute.googleapis.com/compute/beta/",
118120
expectError: false,
119121
},
120-
{
121-
name: "check for incorrect endpoint",
122-
computeEndpoint: "https://compute.googleapis",
123-
computeEnvironment: "prod",
124-
computeVersion: "v1",
125-
expectError: true,
126-
},
127122
{
128123
name: "check for staging environment",
129-
computeEndpoint: "https://compute.googleapis.com",
130-
computeEnvironment: environmentStaging,
131-
computeVersion: "v1",
132-
expectedEndpoint: "compute/staging_v1/",
124+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
125+
computeEnvironment: EnvironmentStaging,
126+
computeVersion: versionV1,
127+
expectedEndpoint: "https://compute.googleapis.com/compute/staging_v1/",
133128
expectError: false,
134129
},
135130
{
136131
name: "check for random string as endpoint",
137-
computeEndpoint: "compute-googleapis",
132+
computeEndpoint: url.URL{},
138133
computeEnvironment: "prod",
139134
computeVersion: "v1",
140135
expectedEndpoint: "compute/v1/",
@@ -143,10 +138,23 @@ func TestGetComputeVersion(t *testing.T) {
143138
}
144139
for _, tc := range testCases {
145140
ctx := context.Background()
146-
_, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
141+
computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
142+
service, _ := compute.NewService(ctx, computeOpts...)
143+
gotEndpoint := service.BasePath
147144
if err != nil && !tc.expectError {
148-
t.Fatalf("Got error %v, expected endpoint %s", err, tc.expectedEndpoint)
145+
t.Fatalf("Got error %v", err)
146+
}
147+
if gotEndpoint != tc.expectedEndpoint && !tc.expectError {
148+
t.Fatalf("expected endpoint %s, got endpoint %s", tc.expectedEndpoint, gotEndpoint)
149149
}
150150
}
151151

152152
}
153+
154+
func convertStringToURL(urlString string) url.URL {
155+
parsedURL, err := url.ParseRequestURI(urlString)
156+
if err != nil {
157+
return url.URL{}
158+
}
159+
return *parsedURL
160+
}

0 commit comments

Comments
 (0)