Skip to content

Commit 91fbd91

Browse files
committed
Add url check during initialization
1 parent d5ec522 commit 91fbd91

File tree

3 files changed

+58
-43
lines changed

3 files changed

+58
-43
lines changed

cmd/gce-pd-csi-driver/main.go

+18-5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"flag"
2222
"fmt"
2323
"math/rand"
24+
"net/url"
2425
"os"
2526
"runtime"
2627
"strings"
@@ -40,7 +41,6 @@ import (
4041
var (
4142
cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config")
4243
endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint")
43-
computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.")
4444
runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)")
4545
runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)")
4646
httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.")
@@ -71,8 +71,9 @@ var (
7171
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
7272

7373
computeEnvironment gce.Environment = "production"
74+
computeEndpoint url.URL
7475
version string
75-
allowedComputeEnvironment = []string{"staging", "production"}
76+
allowedComputeEnvironment = []gce.Environment{gce.EnvironmentStaging, gce.EnvironmentProduction}
7677
)
7778

7879
const (
@@ -86,6 +87,7 @@ func init() {
8687
// Use V(5) for GCE Cloud Provider Call informational logging
8788
// Use V(6) for extra repeated/polling information
8889
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
90+
urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint")
8991
klog.InitFlags(flag.CommandLine)
9092
flag.Set("logtostderr", "true")
9193
}
@@ -141,7 +143,7 @@ func handle() {
141143
// Initialize requirements for the controller service
142144
var controllerServer *driver.GCEControllerServer
143145
if *runControllerService {
144-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, computeEnvironment)
146+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment)
145147
if err != nil {
146148
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
147149
}
@@ -191,10 +193,10 @@ func handle() {
191193
gceDriver.Run(*endpoint, *grpcLogCharCap)
192194
}
193195

194-
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []string, usage string) {
196+
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) {
195197
flag.Func(name, usage, func(flagValue string) error {
196198
for _, allowedValue := range allowedComputeEnvironment {
197-
if flagValue == allowedValue {
199+
if gce.Environment(flagValue) == allowedValue {
198200
*target = gce.Environment(flagValue)
199201
return nil
200202
}
@@ -204,3 +206,14 @@ func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []
204206
})
205207

206208
}
209+
210+
func urlFlag(target *url.URL, name string, usage string) {
211+
flag.Func(name, usage, func(flagValue string) error {
212+
computeURL, err := url.ParseRequestURI(flagValue)
213+
if err == nil {
214+
*target = *computeURL
215+
return nil
216+
}
217+
return err
218+
})
219+
}

pkg/gce-cloud-provider/compute/gce.go

+13-19
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ const (
5454
versionV1 Version = "v1"
5555
versionBeta Version = "beta"
5656
versionAlpha Version = "alpha"
57-
environmentStaging Environment = "staging"
57+
EnvironmentStaging Environment = "staging"
58+
EnvironmentProduction Environment = "production"
5859
)
5960

6061
type CloudProvider struct {
@@ -79,7 +80,7 @@ type ConfigGlobal struct {
7980
Zone string `gcfg:"zone"`
8081
}
8182

82-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment Environment) (*CloudProvider, error) {
83+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint url.URL, computeEnvironment Environment) (*CloudProvider, error) {
8384
configFile, err := readConfig(configPath)
8485
if err != nil {
8586
return nil, err
@@ -173,7 +174,7 @@ func readConfig(configPath string) (*ConfigFile, error) {
173174
return cfg, nil
174175
}
175176

176-
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computealpha.Service, error) {
177+
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment) (*computealpha.Service, error) {
177178
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
178179
if err != nil {
179180
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -186,7 +187,7 @@ func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSou
186187
return service, nil
187188
}
188189

189-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computebeta.Service, error) {
190+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment) (*computebeta.Service, error) {
190191
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
191192
if err != nil {
192193
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -199,7 +200,7 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
199200
return service, nil
200201
}
201202

202-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*compute.Service, error) {
203+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment) (*compute.Service, error) {
203204
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
204205
if err != nil {
205206
klog.Errorf("Failed to get compute endpoint: %s", err)
@@ -212,32 +213,25 @@ func createCloudService(ctx context.Context, vendorVersion string, tokenSource o
212213
return service, nil
213214
}
214215

215-
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
216+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint url.URL, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
216217
client, err := newOauthClient(ctx, tokenSource)
217218
if err != nil {
218219
return nil, err
219220
}
220-
computeEnvironmentSuffix := getPath(computeEnvironment, computeVersion)
221221
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
222222

223-
if computeEndpoint != "" {
224-
computeURL, err := url.ParseRequestURI(computeEndpoint)
225-
if err != nil {
226-
return nil, err
227-
}
228-
endpoint := computeURL.JoinPath(computeEnvironmentSuffix).String()
229-
_, err = url.ParseRequestURI(endpoint)
230-
if err != nil {
231-
klog.Fatalf("Error parsing compute endpoint %s", endpoint)
232-
}
223+
if computeEndpoint.String() != "" {
224+
computeEnvironmentSuffix := constructComputeEndpointPath(computeEnvironment, computeVersion)
225+
computeEndpoint.Path = computeEnvironmentSuffix
226+
endpoint := computeEndpoint.String()
233227
computeOpts = append(computeOpts, option.WithEndpoint(endpoint))
234228
}
235229
return computeOpts, nil
236230
}
237231

238-
func getPath(env Environment, version Version) string {
232+
func constructComputeEndpointPath(env Environment, version Version) string {
239233
prefix := ""
240-
if env == environmentStaging {
234+
if env == EnvironmentStaging {
241235
prefix = fmt.Sprintf("%s_", env)
242236
}
243237
return fmt.Sprintf("compute/%s%s/", prefix, version)

pkg/gce-cloud-provider/compute/gce_test.go

+27-19
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ import (
2222
"errors"
2323
"fmt"
2424
"net/http"
25+
"net/url"
2526
"testing"
2627
"time"
2728

2829
"golang.org/x/oauth2"
2930

31+
"google.golang.org/api/compute/v1"
3032
"google.golang.org/api/googleapi"
3133
)
3234

@@ -102,7 +104,7 @@ func TestIsGCEError(t *testing.T) {
102104
func TestGetComputeVersion(t *testing.T) {
103105
testCases := []struct {
104106
name string
105-
computeEndpoint string
107+
computeEndpoint url.URL
106108
computeEnvironment Environment
107109
computeVersion Version
108110
expectedEndpoint string
@@ -111,30 +113,23 @@ func TestGetComputeVersion(t *testing.T) {
111113

112114
{
113115
name: "check for production environment",
114-
computeEndpoint: "https://compute.googleapis.com",
115-
computeEnvironment: "production",
116-
computeVersion: "v1",
117-
expectedEndpoint: "https://compute.googleapis.com/compute/v1/",
116+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
117+
computeEnvironment: EnvironmentProduction,
118+
computeVersion: versionBeta,
119+
expectedEndpoint: "https://compute.googleapis.com/compute/beta/",
118120
expectError: false,
119121
},
120-
{
121-
name: "check for incorrect endpoint",
122-
computeEndpoint: "https://compute.googleapis",
123-
computeEnvironment: "prod",
124-
computeVersion: "v1",
125-
expectError: true,
126-
},
127122
{
128123
name: "check for staging environment",
129-
computeEndpoint: "https://compute.googleapis.com",
130-
computeEnvironment: environmentStaging,
131-
computeVersion: "v1",
132-
expectedEndpoint: "compute/staging_v1/",
124+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
125+
computeEnvironment: EnvironmentStaging,
126+
computeVersion: versionV1,
127+
expectedEndpoint: "https://compute.googleapis.com/compute/staging_v1/",
133128
expectError: false,
134129
},
135130
{
136131
name: "check for random string as endpoint",
137-
computeEndpoint: "compute-googleapis",
132+
computeEndpoint: url.URL{},
138133
computeEnvironment: "prod",
139134
computeVersion: "v1",
140135
expectedEndpoint: "compute/v1/",
@@ -143,10 +138,23 @@ func TestGetComputeVersion(t *testing.T) {
143138
}
144139
for _, tc := range testCases {
145140
ctx := context.Background()
146-
_, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
141+
computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
142+
service, _ := compute.NewService(ctx, computeOpts...)
143+
gotEndpoint := service.BasePath
147144
if err != nil && !tc.expectError {
148-
t.Fatalf("Got error %v, expected endpoint %s", err, tc.expectedEndpoint)
145+
t.Fatalf("Got error %v", err)
146+
}
147+
if gotEndpoint != tc.expectedEndpoint && !tc.expectError {
148+
t.Fatalf("expected endpoint %s, got endpoint %s", tc.expectedEndpoint, gotEndpoint)
149149
}
150150
}
151151

152152
}
153+
154+
func convertStringToURL(urlString string) url.URL {
155+
parsedURL, err := url.ParseRequestURI(urlString)
156+
if err != nil {
157+
return url.URL{}
158+
}
159+
return *parsedURL
160+
}

0 commit comments

Comments
 (0)