From 7a7638b1b6990933d7be5483e059bb64cc725113 Mon Sep 17 00:00:00 2001 From: Sneha Aradhey Date: Tue, 20 Feb 2024 19:25:39 +0000 Subject: [PATCH 1/2] Update PDCSI driver to support staging environment --- cmd/gce-pd-csi-driver/main.go | 44 ++++++++++-- pkg/gce-cloud-provider/compute/gce.go | 84 ++++++++++++++++------ pkg/gce-cloud-provider/compute/gce_test.go | 74 +++++++++++++++++++ pkg/gce-pd-csi-driver/controller.go | 2 +- test/e2e/tests/single_zone_e2e_test.go | 11 +-- 5 files changed, 176 insertions(+), 39 deletions(-) diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index 74872b03f..9defcc778 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -17,8 +17,11 @@ package main import ( "context" + "errors" "flag" + "fmt" "math/rand" + "net/url" "os" "runtime" "strings" @@ -38,7 +41,6 @@ import ( var ( cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config") endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint") - computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.") runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)") runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)") httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.") @@ -66,10 +68,13 @@ var ( maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations") formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount") + fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk") - fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk") - - version string + enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools") + computeEnvironment gce.Environment = gce.EnvironmentProduction + computeEndpoint *url.URL + version string + allowedComputeEnvironment = []gce.Environment{gce.EnvironmentStaging, gce.EnvironmentProduction} ) const ( @@ -82,6 +87,8 @@ func init() { // Use V(4) for general debug information logging // Use V(5) for GCE Cloud Provider Call informational logging // Use V(6) for extra repeated/polling information + enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment") + urlFlag(&computeEndpoint, "compute-endpoint", "Compute endpoint") klog.InitFlags(flag.CommandLine) flag.Set("logtostderr", "true") } @@ -89,6 +96,7 @@ func init() { func main() { flag.Parse() rand.Seed(time.Now().UnixNano()) + klog.Infof("Operating compute environment set to: %s and computeEndpoint is set to: %v", computeEnvironment, computeEndpoint) handle() os.Exit(0) } @@ -137,7 +145,7 @@ func handle() { // Initialize requirements for the controller service var controllerServer *driver.GCEControllerServer if *runControllerService { - cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint) + cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment) if err != nil { klog.Fatalf("Failed to get cloud provider: %v", err.Error()) } @@ -186,3 +194,29 @@ func handle() { gceDriver.Run(*endpoint, *grpcLogCharCap) } + +func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []gce.Environment, usage string) { + flag.Func(name, usage, func(flagValue string) error { + for _, allowedValue := range allowedComputeEnvironment { + if gce.Environment(flagValue) == allowedValue { + *target = gce.Environment(flagValue) + return nil + } + } + errMsg := fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment) + return errors.New(errMsg) + }) + +} + +func urlFlag(target **url.URL, name string, usage string) { + flag.Func(name, usage, func(flagValue string) error { + computeURL, err := url.ParseRequestURI(flagValue) + if err == nil { + *target = computeURL + return nil + } + klog.Infof("Error parsing endpoint compute endpoint %v", err) + return err + }) +} diff --git a/pkg/gce-cloud-provider/compute/gce.go b/pkg/gce-cloud-provider/compute/gce.go index 1d9cb051d..03966de92 100644 --- a/pkg/gce-cloud-provider/compute/gce.go +++ b/pkg/gce-cloud-provider/compute/gce.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" "runtime" "time" @@ -29,6 +30,7 @@ import ( "cloud.google.com/go/compute/metadata" "golang.org/x/oauth2" + computealpha "google.golang.org/api/compute/v0.alpha" computebeta "google.golang.org/api/compute/v0.beta" "google.golang.org/api/compute/v1" "google.golang.org/api/googleapi" @@ -36,6 +38,9 @@ import ( "k8s.io/klog/v2" ) +type Environment string +type Version string + const ( TokenURL = "https://accounts.google.com/o/oauth2/token" diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}" @@ -45,7 +50,12 @@ const ( regionURITemplate = "projects/%s/regions/%s" - replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone} + replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone} + versionV1 Version = "v1" + versionBeta Version = "beta" + versionAlpha Version = "alpha" + EnvironmentStaging Environment = "staging" + EnvironmentProduction Environment = "production" ) type CloudProvider struct { @@ -70,7 +80,7 @@ type ConfigGlobal struct { Zone string `gcfg:"zone"` } -func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string) (*CloudProvider, error) { +func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint *url.URL, computeEnvironment Environment) (*CloudProvider, error) { configFile, err := readConfig(configPath) if err != nil { return nil, err @@ -85,15 +95,23 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s return nil, err } - svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint) + svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment) + if err != nil { + return nil, err + } + klog.Infof("Compute endpoint for V1 version: %s", svc.BasePath) + + betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment) if err != nil { return nil, err } + klog.Infof("Compute endpoint for Beta version: %s", betasvc.BasePath) - betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint) + alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment) if err != nil { return nil, err } + klog.Infof("Compute endpoint for Alpha version: %s", alphasvc.BasePath) project, zone, err := getProjectAndZone(configFile) if err != nil { @@ -156,16 +174,23 @@ func readConfig(configPath string) (*ConfigFile, error) { return cfg, nil } -func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computebeta.Service, error) { - client, err := newOauthClient(ctx, tokenSource) +func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computealpha.Service, error) { + computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha) + if err != nil { + klog.Errorf("Failed to get compute endpoint: %s", err) + } + service, err := computealpha.NewService(ctx, computeOpts...) if err != nil { return nil, err } + service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH) + return service, nil +} - computeOpts := []option.ClientOption{option.WithHTTPClient(client)} - if computeEndpoint != "" { - betaEndpoint := fmt.Sprintf("%s/compute/beta/", computeEndpoint) - computeOpts = append(computeOpts, option.WithEndpoint(betaEndpoint)) +func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computebeta.Service, error) { + computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta) + if err != nil { + klog.Errorf("Failed to get compute endpoint: %s", err) } service, err := computebeta.NewService(ctx, computeOpts...) if err != nil { @@ -175,28 +200,41 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour return service, nil } -func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) { - svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource, computeEndpoint) - return svc, err +func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*compute.Service, error) { + computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1) + if err != nil { + klog.Errorf("Failed to get compute endpoint: %s", err) + } + service, err := compute.NewService(ctx, computeOpts...) + if err != nil { + return nil, err + } + service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH) + return service, nil } -func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) { +func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) { client, err := newOauthClient(ctx, tokenSource) if err != nil { return nil, err } - computeOpts := []option.ClientOption{option.WithHTTPClient(client)} - if computeEndpoint != "" { - v1Endpoint := fmt.Sprintf("%s/compute/v1/", computeEndpoint) - computeOpts = append(computeOpts, option.WithEndpoint(v1Endpoint)) + + if computeEndpoint != nil { + computeEnvironmentSuffix := constructComputeEndpointPath(computeEnvironment, computeVersion) + computeEndpoint.Path = computeEnvironmentSuffix + endpoint := computeEndpoint.String() + computeOpts = append(computeOpts, option.WithEndpoint(endpoint)) } - service, err := compute.NewService(ctx, computeOpts...) - if err != nil { - return nil, err + return computeOpts, nil +} + +func constructComputeEndpointPath(env Environment, version Version) string { + prefix := "" + if env == EnvironmentStaging { + prefix = fmt.Sprintf("%s_", env) } - service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH) - return service, nil + return fmt.Sprintf("compute/%s%s/", prefix, version) } func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) { diff --git a/pkg/gce-cloud-provider/compute/gce_test.go b/pkg/gce-cloud-provider/compute/gce_test.go index 5bb2aed89..49f85221d 100644 --- a/pkg/gce-cloud-provider/compute/gce_test.go +++ b/pkg/gce-cloud-provider/compute/gce_test.go @@ -18,14 +18,30 @@ limitations under the License. package gcecloudprovider import ( + "context" "errors" "fmt" "net/http" + "net/url" "testing" + "time" + "golang.org/x/oauth2" + + "google.golang.org/api/compute/v1" "google.golang.org/api/googleapi" ) +type mockTokenSource struct{} + +func (*mockTokenSource) Token() (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "access", + TokenType: "Bearer", + RefreshToken: "refresh", + Expiry: time.Now().Add(1 * time.Hour), + }, nil +} func TestIsGCEError(t *testing.T) { testCases := []struct { name string @@ -84,3 +100,61 @@ func TestIsGCEError(t *testing.T) { } } } + +func TestGetComputeVersion(t *testing.T) { + testCases := []struct { + name string + computeEndpoint *url.URL + computeEnvironment Environment + computeVersion Version + expectedEndpoint string + expectError bool + }{ + + { + name: "check for production environment", + computeEndpoint: convertStringToURL("https://compute.googleapis.com"), + computeEnvironment: EnvironmentProduction, + computeVersion: versionBeta, + expectedEndpoint: "https://compute.googleapis.com/compute/beta/", + expectError: false, + }, + { + name: "check for staging environment", + computeEndpoint: convertStringToURL("https://compute.googleapis.com"), + computeEnvironment: EnvironmentStaging, + computeVersion: versionV1, + expectedEndpoint: "https://compute.googleapis.com/compute/staging_v1/", + expectError: false, + }, + { + name: "check for random string as endpoint", + computeEndpoint: convertStringToURL(""), + computeEnvironment: "prod", + computeVersion: "v1", + expectedEndpoint: "compute/v1/", + expectError: true, + }, + } + for _, tc := range testCases { + ctx := context.Background() + computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion) + service, _ := compute.NewService(ctx, computeOpts...) + gotEndpoint := service.BasePath + if err != nil && !tc.expectError { + t.Fatalf("Got error %v", err) + } + if gotEndpoint != tc.expectedEndpoint && !tc.expectError { + t.Fatalf("expected endpoint %s, got endpoint %s", tc.expectedEndpoint, gotEndpoint) + } + } + +} + +func convertStringToURL(urlString string) *url.URL { + parsedURL, err := url.ParseRequestURI(urlString) + if err != nil { + return nil + } + return parsedURL +} diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go index d570cf372..066e75d8f 100644 --- a/pkg/gce-pd-csi-driver/controller.go +++ b/pkg/gce-pd-csi-driver/controller.go @@ -156,7 +156,7 @@ const ( ) var ( - validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true} + validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true, "staging_v1": true, "staging_beta": true, "staging_alpha": true} ) func isDiskReady(disk *gce.CloudDisk) (bool, error) { diff --git a/test/e2e/tests/single_zone_e2e_test.go b/test/e2e/tests/single_zone_e2e_test.go index e337f1178..24094bccd 100644 --- a/test/e2e/tests/single_zone_e2e_test.go +++ b/test/e2e/tests/single_zone_e2e_test.go @@ -1280,7 +1280,7 @@ var _ = Describe("GCE PD CSI Driver", func() { }() }) - It("Should pass/fail if valid/invalid compute endpoint is passed in", func() { + It("Should pass if valid compute endpoint is passed in", func() { // gets instance set up w/o compute-endpoint set from test setup _, err := getRandomTestContext().Client.ListVolumes() Expect(err).To(BeNil(), "no error expected when passed valid compute url") @@ -1295,15 +1295,6 @@ var _ = Describe("GCE PD CSI Driver", func() { klog.Infof("Creating new driver and client for node %s\n", i.GetName()) - // Create new driver and client w/ invalid endpoint - tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string") - if err != nil { - klog.Fatalf("Failed to set up Test Context for instance %v: %w", i.GetName(), err) - } - - _, err = tcInvalid.Client.ListVolumes() - Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url") - // Create new driver and client w/ valid, passed-in endpoint tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com") if err != nil { From 532c5df45aa47d070a4e8d1f2d7ef31654d24bee Mon Sep 17 00:00:00 2001 From: Sneha Aradhey Date: Thu, 22 Feb 2024 23:09:34 +0000 Subject: [PATCH 2/2] Update to handle when no compute endpoint is passed --- cmd/gce-pd-csi-driver/main.go | 5 ++++- test/e2e/tests/single_zone_e2e_test.go | 10 ++++++++++ test/e2e/utils/utils.go | 5 ++--- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index 9defcc778..123f0103a 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -211,12 +211,15 @@ func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment [] func urlFlag(target **url.URL, name string, usage string) { flag.Func(name, usage, func(flagValue string) error { + if flagValue == "" { + return nil + } computeURL, err := url.ParseRequestURI(flagValue) if err == nil { *target = computeURL return nil } - klog.Infof("Error parsing endpoint compute endpoint %v", err) + klog.Errorf("Error parsing endpoint compute endpoint %v", err) return err }) } diff --git a/test/e2e/tests/single_zone_e2e_test.go b/test/e2e/tests/single_zone_e2e_test.go index 24094bccd..854876db5 100644 --- a/test/e2e/tests/single_zone_e2e_test.go +++ b/test/e2e/tests/single_zone_e2e_test.go @@ -1295,6 +1295,16 @@ var _ = Describe("GCE PD CSI Driver", func() { klog.Infof("Creating new driver and client for node %s\n", i.GetName()) + // Create new driver and client with valid, empty endpoint + klog.Infof("Setup driver with empty compute endpoint %s\n", i.GetName()) + tcEmpty, err := testutils.GCEClientAndDriverSetup(i, "") + if err != nil { + klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err) + } + _, err = tcEmpty.Client.ListVolumes() + + Expect(err).To(BeNil(), "no error expected when passed empty compute url") + // Create new driver and client w/ valid, passed-in endpoint tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com") if err != nil { diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 97364bed4..f10ca17d3 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -57,9 +57,8 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo, computeEndpoint stri fmt.Sprintf("--extra-labels=%s=%s", DiskLabelKey, DiskLabelValue), "--max-concurrent-format-and-mount=20", // otherwise the serialization times out the e2e test. } - if computeEndpoint != "" { - extra_flags = append(extra_flags, fmt.Sprintf("--compute-endpoint %s", computeEndpoint)) - } + extra_flags = append(extra_flags, fmt.Sprintf("--compute-endpoint=%s", computeEndpoint)) + workspace := remote.NewWorkspaceDir("gce-pd-e2e-") // Log at V(6) as the compute API calls are emitted at that level and it's // useful to see what's happening when debugging tests.