Skip to content

Commit a2f75c5

Browse files
committed
Update PDCSI driver to support GCE staging
1 parent 016f84d commit a2f75c5

File tree

6 files changed

+186
-39
lines changed

6 files changed

+186
-39
lines changed

cmd/gce-pd-csi-driver/main.go

+41-5
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ package main
1717

1818
import (
1919
"context"
20+
"errors"
2021
"flag"
22+
"fmt"
2123
"math/rand"
24+
"net/url"
2225
"os"
2326
"runtime"
2427
"strings"
@@ -38,7 +41,6 @@ import (
3841
var (
3942
cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config")
4043
endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint")
41-
computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.")
4244
runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)")
4345
runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)")
4446
httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.")
@@ -66,10 +68,12 @@ var (
6668

6769
maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations")
6870
formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount")
71+
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
6972

70-
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
71-
72-
version string
73+
computeEnvironment gce.Environment = gce.EnvironmentProduction
74+
computeEndpoint *url.URL
75+
version string
76+
allowedComputeEnvironment = []gce.Environment{gce.EnvironmentStaging, gce.EnvironmentProduction}
7377
)
7478

7579
const (
@@ -82,13 +86,16 @@ func init() {
8286
// Use V(4) for general debug information logging
8387
// Use V(5) for GCE Cloud Provider Call informational logging
8488
// Use V(6) for extra repeated/polling information
89+
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
90+
urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint")
8591
klog.InitFlags(flag.CommandLine)
8692
flag.Set("logtostderr", "true")
8793
}
8894

8995
func main() {
9096
flag.Parse()
9197
rand.Seed(time.Now().UnixNano())
98+
klog.Infof("Operating compute environment set to: %s and computeEndpoint is set to: %v", computeEnvironment, computeEndpoint)
9299
handle()
93100
os.Exit(0)
94101
}
@@ -137,7 +144,7 @@ func handle() {
137144
// Initialize requirements for the controller service
138145
var controllerServer *driver.GCEControllerServer
139146
if *runControllerService {
140-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint)
147+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment)
141148
if err != nil {
142149
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
143150
}
@@ -186,3 +193,32 @@ func handle() {
186193

187194
gceDriver.Run(*endpoint, *grpcLogCharCap)
188195
}
196+
197+
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) {
198+
flag.Func(name, usage, func(flagValue string) error {
199+
for _, allowedValue := range allowedComputeEnvironment {
200+
if gce.Environment(flagValue) == allowedValue {
201+
*target = gce.Environment(flagValue)
202+
return nil
203+
}
204+
}
205+
errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment)
206+
return errors.New(errMsg)
207+
})
208+
209+
}
210+
211+
func urlFlag(target **url.URL, name string, usage string) {
212+
flag.Func(name, usage, func(flagValue string) error {
213+
if flagValue == "" {
214+
return nil
215+
}
216+
computeURL, err := url.ParseRequestURI(flagValue)
217+
if err == nil {
218+
*target = computeURL
219+
return nil
220+
}
221+
klog.Errorf("Error parsing endpoint compute endpoint %v", err)
222+
return err
223+
})
224+
}

pkg/gce-cloud-provider/compute/gce.go

+61-23
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"errors"
2020
"fmt"
2121
"net/http"
22+
"net/url"
2223
"os"
2324
"runtime"
2425
"time"
@@ -29,13 +30,17 @@ import (
2930

3031
"cloud.google.com/go/compute/metadata"
3132
"golang.org/x/oauth2"
33+
computealpha "google.golang.org/api/compute/v0.alpha"
3234
computebeta "google.golang.org/api/compute/v0.beta"
3335
"google.golang.org/api/compute/v1"
3436
"google.golang.org/api/googleapi"
3537
"k8s.io/apimachinery/pkg/util/wait"
3638
"k8s.io/klog/v2"
3739
)
3840

41+
type Environment string
42+
type Version string
43+
3944
const (
4045
TokenURL = "https://accounts.google.com/o/oauth2/token"
4146
diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}"
@@ -45,7 +50,12 @@ const (
4550

4651
regionURITemplate = "projects/%s/regions/%s"
4752

48-
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
53+
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
54+
versionV1 Version = "v1"
55+
versionBeta Version = "beta"
56+
versionAlpha Version = "alpha"
57+
EnvironmentStaging Environment = "staging"
58+
EnvironmentProduction Environment = "production"
4959
)
5060

5161
type CloudProvider struct {
@@ -70,7 +80,7 @@ type ConfigGlobal struct {
7080
Zone string `gcfg:"zone"`
7181
}
7282

73-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string) (*CloudProvider, error) {
83+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint *url.URL, computeEnvironment Environment) (*CloudProvider, error) {
7484
configFile, err := readConfig(configPath)
7585
if err != nil {
7686
return nil, err
@@ -85,15 +95,23 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s
8595
return nil, err
8696
}
8797

88-
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
98+
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
99+
if err != nil {
100+
return nil, err
101+
}
102+
klog.Infof("Compute endpoint for V1 version: %s", svc.BasePath)
103+
104+
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
89105
if err != nil {
90106
return nil, err
91107
}
108+
klog.Infof("Compute endpoint for Beta version: %s", betasvc.BasePath)
92109

93-
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
110+
alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
94111
if err != nil {
95112
return nil, err
96113
}
114+
klog.Infof("Compute endpoint for Alpha version: %s", alphasvc.BasePath)
97115

98116
project, zone, err := getProjectAndZone(configFile)
99117
if err != nil {
@@ -156,16 +174,23 @@ func readConfig(configPath string) (*ConfigFile, error) {
156174
return cfg, nil
157175
}
158176

159-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computebeta.Service, error) {
160-
client, err := newOauthClient(ctx, tokenSource)
177+
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computealpha.Service, error) {
178+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
179+
if err != nil {
180+
klog.Errorf("Failed to get compute endpoint: %s", err)
181+
}
182+
service, err := computealpha.NewService(ctx, computeOpts...)
161183
if err != nil {
162184
return nil, err
163185
}
186+
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
187+
return service, nil
188+
}
164189

165-
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
166-
if computeEndpoint != "" {
167-
betaEndpoint := fmt.Sprintf("%s/compute/beta/", computeEndpoint)
168-
computeOpts = append(computeOpts, option.WithEndpoint(betaEndpoint))
190+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computebeta.Service, error) {
191+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
192+
if err != nil {
193+
klog.Errorf("Failed to get compute endpoint: %s", err)
169194
}
170195
service, err := computebeta.NewService(ctx, computeOpts...)
171196
if err != nil {
@@ -175,28 +200,41 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
175200
return service, nil
176201
}
177202

178-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
179-
svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource, computeEndpoint)
180-
return svc, err
203+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*compute.Service, error) {
204+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
205+
if err != nil {
206+
klog.Errorf("Failed to get compute endpoint: %s", err)
207+
}
208+
service, err := compute.NewService(ctx, computeOpts...)
209+
if err != nil {
210+
return nil, err
211+
}
212+
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
213+
return service, nil
181214
}
182215

183-
func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
216+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
184217
client, err := newOauthClient(ctx, tokenSource)
185218
if err != nil {
186219
return nil, err
187220
}
188-
189221
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
190-
if computeEndpoint != "" {
191-
v1Endpoint := fmt.Sprintf("%s/compute/v1/", computeEndpoint)
192-
computeOpts = append(computeOpts, option.WithEndpoint(v1Endpoint))
222+
223+
if computeEndpoint != nil {
224+
computeEnvironmentSuffix := constructComputeEndpointPath(computeEnvironment, computeVersion)
225+
computeEndpoint.Path = computeEnvironmentSuffix
226+
endpoint := computeEndpoint.String()
227+
computeOpts = append(computeOpts, option.WithEndpoint(endpoint))
193228
}
194-
service, err := compute.NewService(ctx, computeOpts...)
195-
if err != nil {
196-
return nil, err
229+
return computeOpts, nil
230+
}
231+
232+
func constructComputeEndpointPath(env Environment, version Version) string {
233+
prefix := ""
234+
if env == EnvironmentStaging {
235+
prefix = fmt.Sprintf("%s_", env)
197236
}
198-
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
199-
return service, nil
237+
return fmt.Sprintf("compute/%s%s/", prefix, version)
200238
}
201239

202240
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {

pkg/gce-cloud-provider/compute/gce_test.go

+74
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,30 @@ limitations under the License.
1818
package gcecloudprovider
1919

2020
import (
21+
"context"
2122
"errors"
2223
"fmt"
2324
"net/http"
25+
"net/url"
2426
"testing"
27+
"time"
2528

29+
"golang.org/x/oauth2"
30+
31+
"google.golang.org/api/compute/v1"
2632
"google.golang.org/api/googleapi"
2733
)
2834

35+
type mockTokenSource struct{}
36+
37+
func (*mockTokenSource) Token() (*oauth2.Token, error) {
38+
return &oauth2.Token{
39+
AccessToken: "access",
40+
TokenType: "Bearer",
41+
RefreshToken: "refresh",
42+
Expiry: time.Now().Add(1 * time.Hour),
43+
}, nil
44+
}
2945
func TestIsGCEError(t *testing.T) {
3046
testCases := []struct {
3147
name string
@@ -84,3 +100,61 @@ func TestIsGCEError(t *testing.T) {
84100
}
85101
}
86102
}
103+
104+
func TestGetComputeVersion(t *testing.T) {
105+
testCases := []struct {
106+
name string
107+
computeEndpoint *url.URL
108+
computeEnvironment Environment
109+
computeVersion Version
110+
expectedEndpoint string
111+
expectError bool
112+
}{
113+
114+
{
115+
name: "check for production environment",
116+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
117+
computeEnvironment: EnvironmentProduction,
118+
computeVersion: versionBeta,
119+
expectedEndpoint: "https://compute.googleapis.com/compute/beta/",
120+
expectError: false,
121+
},
122+
{
123+
name: "check for staging environment",
124+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
125+
computeEnvironment: EnvironmentStaging,
126+
computeVersion: versionV1,
127+
expectedEndpoint: "https://compute.googleapis.com/compute/staging_v1/",
128+
expectError: false,
129+
},
130+
{
131+
name: "check for random string as endpoint",
132+
computeEndpoint: convertStringToURL(""),
133+
computeEnvironment: "prod",
134+
computeVersion: "v1",
135+
expectedEndpoint: "compute/v1/",
136+
expectError: true,
137+
},
138+
}
139+
for _, tc := range testCases {
140+
ctx := context.Background()
141+
computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
142+
service, _ := compute.NewService(ctx, computeOpts...)
143+
gotEndpoint := service.BasePath
144+
if err != nil && !tc.expectError {
145+
t.Fatalf("Got error %v", err)
146+
}
147+
if gotEndpoint != tc.expectedEndpoint && !tc.expectError {
148+
t.Fatalf("expected endpoint %s, got endpoint %s", tc.expectedEndpoint, gotEndpoint)
149+
}
150+
}
151+
152+
}
153+
154+
func convertStringToURL(urlString string) *url.URL {
155+
parsedURL, err := url.ParseRequestURI(urlString)
156+
if err != nil {
157+
return nil
158+
}
159+
return parsedURL
160+
}

pkg/gce-pd-csi-driver/controller.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ const (
156156
)
157157

158158
var (
159-
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true}
159+
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true, "staging_v1": true, "staging_beta": true, "staging_alpha": true}
160160
)
161161

162162
func isDiskReady(disk *gce.CloudDisk) (bool, error) {

test/e2e/tests/single_zone_e2e_test.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,7 @@ var _ = Describe("GCE PD CSI Driver", func() {
12801280
}()
12811281
})
12821282

1283-
It("Should pass/fail if valid/invalid compute endpoint is passed in", func() {
1283+
It("Should pass if valid compute endpoint is passed in", func() {
12841284
// gets instance set up w/o compute-endpoint set from test setup
12851285
_, err := getRandomTestContext().Client.ListVolumes()
12861286
Expect(err).To(BeNil(), "no error expected when passed valid compute url")
@@ -1295,14 +1295,15 @@ var _ = Describe("GCE PD CSI Driver", func() {
12951295

12961296
klog.Infof("Creating new driver and client for node %s\n", i.GetName())
12971297

1298-
// Create new driver and client w/ invalid endpoint
1299-
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string")
1298+
// Create new driver and client with valid, empty endpoint
1299+
klog.Infof("Setup driver with empty compute endpoint %s\n", i.GetName())
1300+
tcEmpty, err := testutils.GCEClientAndDriverSetup(i, "")
13001301
if err != nil {
1301-
klog.Fatalf("Failed to set up Test Context for instance %v: %w", i.GetName(), err)
1302+
klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err)
13021303
}
1304+
_, err = tcEmpty.Client.ListVolumes()
13031305

1304-
_, err = tcInvalid.Client.ListVolumes()
1305-
Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url")
1306+
Expect(err).To(BeNil(), "no error expected when passed empty compute url")
13061307

13071308
// Create new driver and client w/ valid, passed-in endpoint
13081309
tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com")

0 commit comments

Comments
 (0)