Skip to content

Commit 52239f2

Browse files
authored
Merge pull request #1614 from Sneha-at/automated-cherry-pick-of-#1586-#1609-upstream-release-1.13
Automated cherry pick of #1586: update driver to support staging compute #1609: fix pointer issue for GCE staging support
2 parents 1990033 + 9f5ca1b commit 52239f2

File tree

6 files changed

+164
-55
lines changed

6 files changed

+164
-55
lines changed

cmd/gce-pd-csi-driver/main.go

+39-7
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ package main
1717

1818
import (
1919
"context"
20+
"errors"
2021
"flag"
22+
"fmt"
2123
"math/rand"
24+
"net/url"
2225
"os"
2326
"runtime"
2427
"strings"
@@ -38,7 +41,6 @@ import (
3841
var (
3942
cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config")
4043
endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint")
41-
computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.")
4244
runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)")
4345
runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)")
4446
httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.")
@@ -67,12 +69,13 @@ var (
6769

6870
maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations")
6971
formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount")
72+
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
7073

71-
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")
72-
73-
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
74-
75-
version string
74+
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
75+
computeEnvironment gce.Environment = gce.EnvironmentProduction
76+
computeEndpoint *url.URL
77+
version string
78+
allowedComputeEnvironment = []gce.Environment{gce.EnvironmentStaging, gce.EnvironmentProduction}
7679
)
7780

7881
const (
@@ -85,13 +88,16 @@ func init() {
8588
// Use V(4) for general debug information logging
8689
// Use V(5) for GCE Cloud Provider Call informational logging
8790
// Use V(6) for extra repeated/polling information
91+
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
92+
urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint")
8893
klog.InitFlags(flag.CommandLine)
8994
flag.Set("logtostderr", "true")
9095
}
9196

9297
func main() {
9398
flag.Parse()
9499
rand.Seed(time.Now().UnixNano())
100+
klog.Infof("Operating compute environment set to: %s and computeEndpoint is set to: %v", computeEnvironment, computeEndpoint)
95101
handle()
96102
os.Exit(0)
97103
}
@@ -156,7 +162,7 @@ func handle() {
156162
// Initialize requirements for the controller service
157163
var controllerServer *driver.GCEControllerServer
158164
if *runControllerService {
159-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint)
165+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment)
160166
if err != nil {
161167
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
162168
}
@@ -205,3 +211,29 @@ func handle() {
205211

206212
gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing)
207213
}
214+
215+
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) {
216+
flag.Func(name, usage, func(flagValue string) error {
217+
for _, allowedValue := range allowedComputeEnvironment {
218+
if gce.Environment(flagValue) == allowedValue {
219+
*target = gce.Environment(flagValue)
220+
return nil
221+
}
222+
}
223+
errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment)
224+
return errors.New(errMsg)
225+
})
226+
227+
}
228+
229+
func urlFlag(target **url.URL, name string, usage string) {
230+
flag.Func(name, usage, func(flagValue string) error {
231+
computeURL, err := url.ParseRequestURI(flagValue)
232+
if err == nil {
233+
*target = computeURL
234+
return nil
235+
}
236+
klog.Infof("Error parsing endpoint compute endpoint %v", err)
237+
return err
238+
})
239+
}

pkg/gce-cloud-provider/compute/gce.go

+49-36
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"errors"
2020
"fmt"
2121
"net/http"
22+
"net/url"
2223
"os"
2324
"runtime"
2425
"time"
@@ -37,6 +38,9 @@ import (
3738
"k8s.io/klog/v2"
3839
)
3940

41+
type Environment string
42+
type Version string
43+
4044
const (
4145
TokenURL = "https://accounts.google.com/o/oauth2/token"
4246
diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}"
@@ -46,7 +50,12 @@ const (
4650

4751
regionURITemplate = "projects/%s/regions/%s"
4852

49-
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
53+
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
54+
versionV1 Version = "v1"
55+
versionBeta Version = "beta"
56+
versionAlpha Version = "alpha"
57+
EnvironmentStaging Environment = "staging"
58+
EnvironmentProduction Environment = "production"
5059
)
5160

5261
type CloudProvider struct {
@@ -72,7 +81,7 @@ type ConfigGlobal struct {
7281
Zone string `gcfg:"zone"`
7382
}
7483

75-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string) (*CloudProvider, error) {
84+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint *url.URL, computeEnvironment Environment) (*CloudProvider, error) {
7685
configFile, err := readConfig(configPath)
7786
if err != nil {
7887
return nil, err
@@ -87,20 +96,23 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s
8796
return nil, err
8897
}
8998

90-
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
99+
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
91100
if err != nil {
92101
return nil, err
93102
}
103+
klog.Infof("Compute endpoint for V1 version: %s", svc.BasePath)
94104

95-
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
105+
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
96106
if err != nil {
97107
return nil, err
98108
}
109+
klog.Infof("Compute endpoint for Beta version: %s", betasvc.BasePath)
99110

100-
alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
111+
alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
101112
if err != nil {
102113
return nil, err
103114
}
115+
klog.Infof("Compute endpoint for Alpha version: %s", alphasvc.BasePath)
104116

105117
project, zone, err := getProjectAndZone(configFile)
106118
if err != nil {
@@ -164,16 +176,23 @@ func readConfig(configPath string) (*ConfigFile, error) {
164176
return cfg, nil
165177
}
166178

167-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computebeta.Service, error) {
168-
client, err := newOauthClient(ctx, tokenSource)
179+
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computealpha.Service, error) {
180+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
181+
if err != nil {
182+
klog.Errorf("Failed to get compute endpoint: %s", err)
183+
}
184+
service, err := computealpha.NewService(ctx, computeOpts...)
169185
if err != nil {
170186
return nil, err
171187
}
188+
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
189+
return service, nil
190+
}
172191

173-
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
174-
if computeEndpoint != "" {
175-
betaEndpoint := fmt.Sprintf("%s/compute/beta/", computeEndpoint)
176-
computeOpts = append(computeOpts, option.WithEndpoint(betaEndpoint))
192+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computebeta.Service, error) {
193+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
194+
if err != nil {
195+
klog.Errorf("Failed to get compute endpoint: %s", err)
177196
}
178197
service, err := computebeta.NewService(ctx, computeOpts...)
179198
if err != nil {
@@ -183,47 +202,41 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
183202
return service, nil
184203
}
185204

186-
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computealpha.Service, error) {
187-
client, err := newOauthClient(ctx, tokenSource)
205+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*compute.Service, error) {
206+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
188207
if err != nil {
189-
return nil, err
190-
}
191-
192-
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
193-
if computeEndpoint != "" {
194-
alphaEndpoint := fmt.Sprintf("%s/compute/alpha/", computeEndpoint)
195-
computeOpts = append(computeOpts, option.WithEndpoint(alphaEndpoint))
208+
klog.Errorf("Failed to get compute endpoint: %s", err)
196209
}
197-
service, err := computealpha.NewService(ctx, computeOpts...)
210+
service, err := compute.NewService(ctx, computeOpts...)
198211
if err != nil {
199212
return nil, err
200213
}
201214
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
202215
return service, nil
203216
}
204217

205-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
206-
svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource, computeEndpoint)
207-
return svc, err
208-
}
209-
210-
func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
218+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
211219
client, err := newOauthClient(ctx, tokenSource)
212220
if err != nil {
213221
return nil, err
214222
}
215-
216223
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
217-
if computeEndpoint != "" {
218-
v1Endpoint := fmt.Sprintf("%s/compute/v1/", computeEndpoint)
219-
computeOpts = append(computeOpts, option.WithEndpoint(v1Endpoint))
224+
225+
if computeEndpoint != nil {
226+
computeEnvironmentSuffix := constructComputeEndpointPath(computeEnvironment, computeVersion)
227+
computeEndpoint.Path = computeEnvironmentSuffix
228+
endpoint := computeEndpoint.String()
229+
computeOpts = append(computeOpts, option.WithEndpoint(endpoint))
220230
}
221-
service, err := compute.NewService(ctx, computeOpts...)
222-
if err != nil {
223-
return nil, err
231+
return computeOpts, nil
232+
}
233+
234+
func constructComputeEndpointPath(env Environment, version Version) string {
235+
prefix := ""
236+
if env == EnvironmentStaging {
237+
prefix = fmt.Sprintf("%s_", env)
224238
}
225-
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
226-
return service, nil
239+
return fmt.Sprintf("compute/%s%s/", prefix, version)
227240
}
228241

229242
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {

pkg/gce-cloud-provider/compute/gce_test.go

+74
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,30 @@ limitations under the License.
1818
package gcecloudprovider
1919

2020
import (
21+
"context"
2122
"errors"
2223
"fmt"
2324
"net/http"
25+
"net/url"
2426
"testing"
27+
"time"
2528

29+
"golang.org/x/oauth2"
30+
31+
"google.golang.org/api/compute/v1"
2632
"google.golang.org/api/googleapi"
2733
)
2834

35+
type mockTokenSource struct{}
36+
37+
func (*mockTokenSource) Token() (*oauth2.Token, error) {
38+
return &oauth2.Token{
39+
AccessToken: "access",
40+
TokenType: "Bearer",
41+
RefreshToken: "refresh",
42+
Expiry: time.Now().Add(1 * time.Hour),
43+
}, nil
44+
}
2945
func TestIsGCEError(t *testing.T) {
3046
testCases := []struct {
3147
name string
@@ -84,3 +100,61 @@ func TestIsGCEError(t *testing.T) {
84100
}
85101
}
86102
}
103+
104+
func TestGetComputeVersion(t *testing.T) {
105+
testCases := []struct {
106+
name string
107+
computeEndpoint *url.URL
108+
computeEnvironment Environment
109+
computeVersion Version
110+
expectedEndpoint string
111+
expectError bool
112+
}{
113+
114+
{
115+
name: "check for production environment",
116+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
117+
computeEnvironment: EnvironmentProduction,
118+
computeVersion: versionBeta,
119+
expectedEndpoint: "https://compute.googleapis.com/compute/beta/",
120+
expectError: false,
121+
},
122+
{
123+
name: "check for staging environment",
124+
computeEndpoint: convertStringToURL("https://compute.googleapis.com"),
125+
computeEnvironment: EnvironmentStaging,
126+
computeVersion: versionV1,
127+
expectedEndpoint: "https://compute.googleapis.com/compute/staging_v1/",
128+
expectError: false,
129+
},
130+
{
131+
name: "check for random string as endpoint",
132+
computeEndpoint: convertStringToURL(""),
133+
computeEnvironment: "prod",
134+
computeVersion: "v1",
135+
expectedEndpoint: "compute/v1/",
136+
expectError: true,
137+
},
138+
}
139+
for _, tc := range testCases {
140+
ctx := context.Background()
141+
computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
142+
service, _ := compute.NewService(ctx, computeOpts...)
143+
gotEndpoint := service.BasePath
144+
if err != nil && !tc.expectError {
145+
t.Fatalf("Got error %v", err)
146+
}
147+
if gotEndpoint != tc.expectedEndpoint && !tc.expectError {
148+
t.Fatalf("expected endpoint %s, got endpoint %s", tc.expectedEndpoint, gotEndpoint)
149+
}
150+
}
151+
152+
}
153+
154+
func convertStringToURL(urlString string) *url.URL {
155+
parsedURL, err := url.ParseRequestURI(urlString)
156+
if err != nil {
157+
return nil
158+
}
159+
return parsedURL
160+
}

pkg/gce-pd-csi-driver/controller.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ const (
159159
)
160160

161161
var (
162-
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true}
162+
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true, "staging_v1": true, "staging_beta": true, "staging_alpha": true}
163163
)
164164

165165
func isDiskReady(disk *gce.CloudDisk) (bool, error) {

test/e2e/tests/single_zone_e2e_test.go

+1-10
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,7 @@ var _ = Describe("GCE PD CSI Driver", func() {
12801280
}()
12811281
})
12821282

1283-
It("Should pass/fail if valid/invalid compute endpoint is passed in", func() {
1283+
It("Should pass if valid compute endpoint is passed in", func() {
12841284
// gets instance set up w/o compute-endpoint set from test setup
12851285
_, err := getRandomTestContext().Client.ListVolumes()
12861286
Expect(err).To(BeNil(), "no error expected when passed valid compute url")
@@ -1295,15 +1295,6 @@ var _ = Describe("GCE PD CSI Driver", func() {
12951295

12961296
klog.Infof("Creating new driver and client for node %s\n", i.GetName())
12971297

1298-
// Create new driver and client w/ invalid endpoint
1299-
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string")
1300-
if err != nil {
1301-
klog.Fatalf("Failed to set up Test Context for instance %v: %w", i.GetName(), err)
1302-
}
1303-
1304-
_, err = tcInvalid.Client.ListVolumes()
1305-
Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url")
1306-
13071298
// Create new driver and client w/ valid, passed-in endpoint
13081299
tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com")
13091300
if err != nil {

test/e2e/utils/utils.go

-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo, computeEndpoint stri
6565
// useful to see what's happening when debugging tests.
6666
driverRunCmd := fmt.Sprintf("sh -c '/usr/bin/nohup %s/gce-pd-csi-driver -v=6 --endpoint=%s %s 2> %s/prog.out < /dev/null > /dev/null &'",
6767
workspace, endpoint, strings.Join(extra_flags, " "), workspace)
68-
6968
config := &remote.ClientConfig{
7069
PkgPath: pkgPath,
7170
BinPath: binPath,

0 commit comments

Comments
 (0)