diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index 767cd190c..a27d2b314 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -36,6 +36,7 @@ import ( var ( cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config") endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint") + computeEndpoint = flag.String("compute-endpoint", "", "If set, used as the endpoint for the GCE API.") runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)") runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)") httpEndpoint = flag.String("http-endpoint", "", "The TCP network address where the prometheus metrics endpoint will listen (example: `:8080`). The default is empty string, which means metrics endpoint is disabled.") @@ -117,7 +118,7 @@ func handle() { //Initialize requirements for the controller service var controllerServer *driver.GCEControllerServer if *runControllerService { - cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath) + cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint) if err != nil { klog.Fatalf("Failed to get cloud provider: %v", err) } diff --git a/pkg/gce-cloud-provider/compute/gce.go b/pkg/gce-cloud-provider/compute/gce.go index 493989cb4..c921446a5 100644 --- a/pkg/gce-cloud-provider/compute/gce.go +++ b/pkg/gce-cloud-provider/compute/gce.go @@ -44,10 +44,6 @@ const ( regionURITemplate = "projects/%s/regions/%s" - GCEComputeAPIEndpoint = "https://www.googleapis.com/compute/v1/" - GCEComputeBetaAPIEndpoint = "https://www.googleapis.com/compute/beta/" - GCEComputeAlphaAPIEndpoint = "https://www.googleapis.com/compute/alpha/" - replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone} ) @@ -73,7 +69,7 @@ type ConfigGlobal struct { Zone string `gcfg:"zone"` } -func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string) (*CloudProvider, error) { +func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string) (*CloudProvider, error) { configFile, err := readConfig(configPath) if err != nil { return nil, err @@ -88,12 +84,12 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s return nil, err } - svc, err := createCloudService(ctx, vendorVersion, tokenSource) + svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint) if err != nil { return nil, err } - betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource) + betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint) if err != nil { return nil, err } @@ -159,12 +155,17 @@ func readConfig(configPath string) (*ConfigFile, error) { return cfg, nil } -func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource) (*computebeta.Service, error) { +func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computebeta.Service, error) { client, err := newOauthClient(ctx, tokenSource) if err != nil { return nil, err } - service, err := computebeta.NewService(ctx, option.WithHTTPClient(client)) + + computeOpts := []option.ClientOption{option.WithHTTPClient(client)} + if computeEndpoint != "" { + computeOpts = append(computeOpts, option.WithEndpoint(computeEndpoint)) + } + service, err := computebeta.NewService(ctx, computeOpts...) if err != nil { return nil, err } @@ -172,17 +173,22 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour return service, nil } -func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource) (*compute.Service, error) { - svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource) +func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) { + svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource, computeEndpoint) return svc, err } -func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource) (*compute.Service, error) { +func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) { client, err := newOauthClient(ctx, tokenSource) if err != nil { return nil, err } - service, err := compute.New(client) + + computeOpts := []option.ClientOption{option.WithHTTPClient(client)} + if computeEndpoint != "" { + computeOpts = append(computeOpts, option.WithEndpoint(computeEndpoint)) + } + service, err := compute.NewService(ctx, computeOpts...) if err != nil { return nil, err } diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go index 19ccec2a5..0c68d4107 100644 --- a/pkg/gce-pd-csi-driver/controller.go +++ b/pkg/gce-pd-csi-driver/controller.go @@ -18,8 +18,8 @@ import ( "context" "fmt" "math/rand" + "regexp" "sort" - "strings" "time" "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/meta" @@ -1526,9 +1526,8 @@ func generateCreateVolumeResponse(disk *gce.CloudDisk, zones []string) *csi.Crea } func cleanSelfLink(selfLink string) string { - temp := strings.TrimPrefix(selfLink, gce.GCEComputeAPIEndpoint) - temp = strings.TrimPrefix(temp, gce.GCEComputeBetaAPIEndpoint) - return strings.TrimPrefix(temp, gce.GCEComputeAlphaAPIEndpoint) + r, _ := regexp.Compile("https:\\/\\/www.*apis.com\\/.*(v1|beta|alpha)\\/") + return r.ReplaceAllString(selfLink, "") } func createRegionalDisk(ctx context.Context, cloudProvider gce.GCECompute, name string, zones []string, params common.DiskParameters, capacityRange *csi.CapacityRange, capBytes int64, snapshotID string, volumeContentSourceVolumeID string, multiWriter bool) (*gce.CloudDisk, error) { diff --git a/pkg/gce-pd-csi-driver/controller_test.go b/pkg/gce-pd-csi-driver/controller_test.go index 2e6fae83a..f7f68b268 100644 --- a/pkg/gce-pd-csi-driver/controller_test.go +++ b/pkg/gce-pd-csi-driver/controller_test.go @@ -2294,6 +2294,71 @@ func TestControllerPublishBackoffMissingInstance(t *testing.T) { }) } +func TestCleanSelfLink(t *testing.T) { + testCases := []struct { + name string + in string + want string + }{ + { + name: "v1 full standard w/ endpoint prefix", + in: "https://www.googleapis.com/compute/v1/projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + { + name: "beta full standard w/ endpoint prefix", + in: "https://www.googleapis.com/compute/beta/projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + { + name: "alpha full standard w/ endpoint prefix", + in: "https://www.googleapis.com/compute/alpha/projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + { + name: "no prefix", + in: "projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + + { + name: "no prefix + project omitted", + in: "zones/zone/disks/disk", + want: "zones/zone/disks/disk", + }, + { + name: "Compute prefix, google api", + in: "https://www.compute.googleapis.com/compute/v1/projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + { + name: "Compute prefix, partner api", + in: "https://www.compute.PARTNERapis.com/compute/v1/projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + { + name: "Partner beta api", + in: "https://www.PARTNERapis.com/compute/beta/projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + { + name: "Partner alpha api", + in: "https://www.partnerapis.com/compute/alpha/projects/project/zones/zone/disks/disk", + want: "projects/project/zones/zone/disks/disk", + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := cleanSelfLink(tc.in) + if got != tc.want { + t.Errorf("Expected cleaned self link: %v, got: %v", tc.want, got) + } + }) + } +} + func backoffTesterForPublish(t *testing.T, config *backoffTesterConfig) { readyToExecute := make(chan chan gce.Signal) cloudDisks := []*gce.CloudDisk{ diff --git a/test/e2e/tests/setup_e2e_test.go b/test/e2e/tests/setup_e2e_test.go index 9a7932ff4..dc048e5e2 100644 --- a/test/e2e/tests/setup_e2e_test.go +++ b/test/e2e/tests/setup_e2e_test.go @@ -116,7 +116,7 @@ var _ = BeforeSuite(func() { klog.Infof("Creating new driver and client for node %s\n", i.GetName()) // Create new driver and client - testContext, err := testutils.GCEClientAndDriverSetup(i) + testContext, err := testutils.GCEClientAndDriverSetup(i, "") if err != nil { klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err) } diff --git a/test/e2e/tests/single_zone_e2e_test.go b/test/e2e/tests/single_zone_e2e_test.go index abacc5645..522bea183 100644 --- a/test/e2e/tests/single_zone_e2e_test.go +++ b/test/e2e/tests/single_zone_e2e_test.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "path/filepath" + "regexp" "strings" "time" @@ -1123,6 +1124,40 @@ var _ = Describe("GCE PD CSI Driver", func() { Expect(gce.IsGCEError(err, "notFound")).To(BeTrue(), "Expected disk to not be found") }() }) + + It("Should pass/fail if valid/invalid compute endpoint is passed in", func() { + // gets instance set up w/o compute-endpoint set from test setup + _, err := getRandomTestContext().Client.ListVolumes() + Expect(err).To(BeNil(), "no error expected when passed valid compute url") + + zone := "us-central1-c" + nodeID := fmt.Sprintf("gce-pd-csi-e2e-%s", zone) + i, err := remote.SetupInstance(*project, *architecture, zone, nodeID, *machineType, *serviceAccount, *imageURL, computeService) + + if err != nil { + klog.Fatalf("Failed to setup instance %v: %v", nodeID, err) + } + + klog.Infof("Creating new driver and client for node %s\n", i.GetName()) + + // Create new driver and client w/ invalid endpoint + tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string") + if err != nil { + klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err) + } + + _, err = tcInvalid.Client.ListVolumes() + Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url") + + // Create new driver and client w/ valid, passed-in endpoint + tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com/compute/v1/") + if err != nil { + klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err) + } + _, err = tcValid.Client.ListVolumes() + + Expect(err).To(BeNil(), "no error expected when passed valid compute url") + }) }) func equalWithinEpsilon(a, b, epsiolon int64) bool { @@ -1204,7 +1239,6 @@ func createAndValidateUniqueZonalMultiWriterDisk(client *remote.CsiClient, proje } func cleanSelfLink(selfLink string) string { - temp := strings.TrimPrefix(selfLink, gce.GCEComputeAPIEndpoint) - temp = strings.TrimPrefix(temp, gce.GCEComputeBetaAPIEndpoint) - return strings.TrimPrefix(temp, gce.GCEComputeAlphaAPIEndpoint) + r, _ := regexp.Compile("https:\\/\\/www.*apis.com\\/.*(v1|beta|alpha)\\/") + return r.ReplaceAllString(selfLink, "") } diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 6d0198ebe..80d2da30d 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -43,7 +43,7 @@ var ( boskos, _ = boskosclient.NewClient(os.Getenv("JOB_NAME"), "http://boskos", "", "") ) -func GCEClientAndDriverSetup(instance *remote.InstanceInfo) (*remote.TestContext, error) { +func GCEClientAndDriverSetup(instance *remote.InstanceInfo, computeEndpoint string) (*remote.TestContext, error) { port := fmt.Sprintf("%v", 1024+rand.Intn(10000)) goPath, ok := os.LookupEnv("GOPATH") if !ok { @@ -53,10 +53,14 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo) (*remote.TestContext binPath := path.Join(pkgPath, "bin/gce-pd-csi-driver") endpoint := fmt.Sprintf("tcp://localhost:%s", port) + computeFlag := "" + if computeEndpoint != "" { + computeFlag = fmt.Sprintf("--compute-endpoint %s", computeEndpoint) + } workspace := remote.NewWorkspaceDir("gce-pd-e2e-") - driverRunCmd := fmt.Sprintf("sh -c '/usr/bin/nohup %s/gce-pd-csi-driver -v=4 --endpoint=%s --extra-labels=%s=%s 2> %s/prog.out < /dev/null > /dev/null &'", - workspace, endpoint, DiskLabelKey, DiskLabelValue, workspace) + driverRunCmd := fmt.Sprintf("sh -c '/usr/bin/nohup %s/gce-pd-csi-driver -v=4 --endpoint=%s %s --extra-labels=%s=%s 2> %s/prog.out < /dev/null > /dev/null &'", + workspace, endpoint, computeFlag, DiskLabelKey, DiskLabelValue, workspace) config := &remote.ClientConfig{ PkgPath: pkgPath,