Skip to content

Update driver to support compute staging #1586

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 16, 2024
29 changes: 23 additions & 6 deletions cmd/gce-pd-csi-driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"math/rand"
"os"
"runtime"
Expand Down Expand Up @@ -67,12 +69,12 @@ var (

maxConcurrentFormatAndMount = flag.Int("max-concurrent-format-and-mount", 1, "If set then format and mount operations are serialized on each node. This is stronger than max-concurrent-format as it includes fsck and other mount operations")
formatAndMountTimeout = flag.Duration("format-and-mount-timeout", 1*time.Minute, "The maximum duration of a format and mount operation before another such operation will be started. Used only if --serialize-format-and-mount")
fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")

fallbackRequisiteZonesFlag = flag.String("fallback-requisite-zones", "", "Comma separated list of requisite zones that will be used if there are not sufficient zones present in requisite topologies when provisioning a disk")

enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")

version string
enableStoragePoolsFlag = flag.Bool("enable-storage-pools", false, "If set to true, the CSI Driver will allow volumes to be provisioned in Storage Pools")
computeEnvironment gce.Environment = "production"
version string
allowedComputeEnvironment = []string{"staging", "production"}
)

const (
Expand All @@ -85,6 +87,7 @@ func init() {
// Use V(4) for general debug information logging
// Use V(5) for GCE Cloud Provider Call informational logging
// Use V(6) for extra repeated/polling information
enumFlag(&computeEnvironment, "compute-environment", allowedComputeEnvironment, "Operating compute environment")
klog.InitFlags(flag.CommandLine)
flag.Set("logtostderr", "true")
}
Expand Down Expand Up @@ -156,7 +159,7 @@ func handle() {
// Initialize requirements for the controller service
var controllerServer *driver.GCEControllerServer
if *runControllerService {
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint)
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, *computeEndpoint, computeEnvironment)
if err != nil {
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
}
Expand Down Expand Up @@ -205,3 +208,17 @@ func handle() {

gceDriver.Run(*endpoint, *grpcLogCharCap, *enableOtelTracing)
}

// enumFlag registers a command-line flag called name, described by usage, that
// only accepts one of the strings in allowedComputeEnvironment.
//
// When the flag is parsed with an accepted value, that value is converted to a
// gce.Environment and stored in *target. Any other value makes flag parsing
// fail with an error that lists the accepted values; *target is left untouched
// in that case.
func enumFlag(target *gce.Environment, name string, allowedComputeEnvironment []string, usage string) {
	validate := func(flagValue string) error {
		accepted := false
		for i := range allowedComputeEnvironment {
			if allowedComputeEnvironment[i] == flagValue {
				accepted = true
				break
			}
		}
		if !accepted {
			return errors.New(fmt.Sprintf(`must be one of %v`, allowedComputeEnvironment))
		}
		*target = gce.Environment(flagValue)
		return nil
	}
	flag.Func(name, usage, validate)
}
89 changes: 54 additions & 35 deletions pkg/gce-cloud-provider/compute/gce.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"time"
Expand All @@ -37,6 +38,9 @@ import (
"k8s.io/klog/v2"
)

// Environment names the compute environment whose API the driver targets.
// It is passed to CreateCloudProvider and is used when building the compute
// endpoint path; the only value declared in this file is environmentStaging.
type Environment string

// Version names a Compute API version used when building the compute endpoint
// path. The values declared in this file are versionV1, versionBeta and
// versionAlpha.
type Version string

const (
TokenURL = "https://accounts.google.com/o/oauth2/token"
diskSourceURITemplateSingleZone = "projects/%s/zones/%s/disks/%s" // {gce.projectID}/zones/{disk.Zone}/disks/{disk.Name}"
Expand All @@ -46,7 +50,11 @@ const (

regionURITemplate = "projects/%s/regions/%s"

replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
replicaZoneURITemplateSingleZone = "projects/%s/zones/%s" // {gce.projectID}/zones/{disk.Zone}
versionV1 Version = "v1"
versionBeta Version = "beta"
versionAlpha Version = "alpha"
environmentStaging Environment = "staging"
)

type CloudProvider struct {
Expand All @@ -72,7 +80,7 @@ type ConfigGlobal struct {
Zone string `gcfg:"zone"`
}

func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string) (*CloudProvider, error) {
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint string, computeEnvironment Environment) (*CloudProvider, error) {
configFile, err := readConfig(configPath)
if err != nil {
return nil, err
Expand All @@ -87,20 +95,23 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s
return nil, err
}

svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
if err != nil {
return nil, err
}
klog.Infof("Compute endpoint for V1 version: %s", svc.BasePath)

betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
if err != nil {
return nil, err
}
klog.Infof("Compute endpoint for Beta version: %s", betasvc.BasePath)

alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint)
alphasvc, err := createAlphaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
if err != nil {
return nil, err
}
klog.Infof("Compute endpoint for Alpha version: %s", alphasvc.BasePath)

project, zone, err := getProjectAndZone(configFile)
if err != nil {
Expand Down Expand Up @@ -164,16 +175,23 @@ func readConfig(configPath string) (*ConfigFile, error) {
return cfg, nil
}

func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computebeta.Service, error) {
client, err := newOauthClient(ctx, tokenSource)
// createAlphaCloudService builds a client for the alpha Compute API.
//
// The client options are obtained from getComputeVersion for versionAlpha,
// using tokenSource, computeEndpoint and computeEnvironment. The returned
// service has its UserAgent set to a string that embeds vendorVersion and the
// runtime OS and architecture.
//
// An error from getComputeVersion is returned to the caller rather than only
// logged: getComputeVersion returns nil options alongside an error, and
// constructing the service from nil options would drop the supplied token
// source and any endpoint override.
func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computealpha.Service, error) {
	computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionAlpha)
	if err != nil {
		return nil, fmt.Errorf("getting alpha compute client options: %w", err)
	}
	service, err := computealpha.NewService(ctx, computeOpts...)
	if err != nil {
		return nil, err
	}
	service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
	return service, nil
}

computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
if computeEndpoint != "" {
betaEndpoint := fmt.Sprintf("%s/compute/beta/", computeEndpoint)
computeOpts = append(computeOpts, option.WithEndpoint(betaEndpoint))
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*computebeta.Service, error) {
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionBeta)
if err != nil {
klog.Errorf("Failed to get compute endpoint: %s", err)
}
service, err := computebeta.NewService(ctx, computeOpts...)
if err != nil {
Expand All @@ -183,47 +201,48 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
return service, nil
}

func createAlphaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*computealpha.Service, error) {
client, err := newOauthClient(ctx, tokenSource)
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment) (*compute.Service, error) {
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, versionV1)
if err != nil {
return nil, err
}

computeOpts := []option.ClientOption{option.WithHTTPClient(client)}
if computeEndpoint != "" {
alphaEndpoint := fmt.Sprintf("%s/compute/alpha/", computeEndpoint)
computeOpts = append(computeOpts, option.WithEndpoint(alphaEndpoint))
klog.Errorf("Failed to get compute endpoint: %s", err)
}
service, err := computealpha.NewService(ctx, computeOpts...)
service, err := compute.NewService(ctx, computeOpts...)
if err != nil {
return nil, err
}
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
return service, nil
}

func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
svc, err := createCloudServiceWithDefaultServiceAccount(ctx, vendorVersion, tokenSource, computeEndpoint)
return svc, err
}

func createCloudServiceWithDefaultServiceAccount(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint string) (*compute.Service, error) {
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint string, computeEnvironment Environment, computeVersion Version) ([]option.ClientOption, error) {
client, err := newOauthClient(ctx, tokenSource)
if err != nil {
return nil, err
}

computeEnvironmentSuffix := getPath(computeEnvironment, computeVersion)
computeOpts := []option.ClientOption{option.WithHTTPClient(client)}

if computeEndpoint != "" {
v1Endpoint := fmt.Sprintf("%s/compute/v1/", computeEndpoint)
computeOpts = append(computeOpts, option.WithEndpoint(v1Endpoint))
computeURL, err := url.ParseRequestURI(computeEndpoint)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation of computeEndpoint could be done at the initialization layer. You could do something very similar to your enumFlag() call, where you validate the computeEndpoint string and set a computeURL value (pointer) with either the value, or nil if no computeURL is provided.

This allows you to avoid needing to check for err here, and lets your function focus on the construction of the URL and path.

if err != nil {
return nil, err
}
endpoint := computeURL.JoinPath(computeEnvironmentSuffix).String()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than JoinPath, I think we should ignore the path and just construct it ourselves. This will make the logic backwards compatible for any calling process that runs the driver with a --compute-endpoint string that includes a path.

computeURL.Path = computeEnvironmentSuffix
endpoint := computeURL.String()

You may need to update your getPath() function to include a leading / character when constructing the path.

_, err = url.ParseRequestURI(endpoint)
if err != nil {
klog.Fatalf("Error parsing compute endpoint %s", endpoint)
}
computeOpts = append(computeOpts, option.WithEndpoint(endpoint))
}
service, err := compute.NewService(ctx, computeOpts...)
if err != nil {
return nil, err
return computeOpts, nil
}

func getPath(env Environment, version Version) string {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would call this constructComputeEndpointPath, rather than getPath. The name getPath is very generic, and get implies we're either fetching a field from an object, or we're retrieving specific data from a higher level resource.

prefix := ""
if env == environmentStaging {
prefix = fmt.Sprintf("%s_", env)
}
service.UserAgent = fmt.Sprintf("GCE CSI Driver/%s (%s %s)", vendorVersion, runtime.GOOS, runtime.GOARCH)
return service, nil
return fmt.Sprintf("compute/%s%s/", prefix, version)
}

func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {
Expand Down
66 changes: 66 additions & 0 deletions pkg/gce-cloud-provider/compute/gce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,28 @@ limitations under the License.
package gcecloudprovider

import (
"context"
"errors"
"fmt"
"net/http"
"testing"
"time"

"golang.org/x/oauth2"

"google.golang.org/api/googleapi"
)

// mockTokenSource is a test stand-in that hands out a canned OAuth2 token.
type mockTokenSource struct{}

// Token returns a fixed bearer token whose expiry is one hour after the call.
// It never returns an error.
func (*mockTokenSource) Token() (*oauth2.Token, error) {
	tok := new(oauth2.Token)
	tok.AccessToken = "access"
	tok.TokenType = "Bearer"
	tok.RefreshToken = "refresh"
	tok.Expiry = time.Now().Add(time.Hour)
	return tok, nil
}
func TestIsGCEError(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -84,3 +98,55 @@ func TestIsGCEError(t *testing.T) {
}
}
}

// TestGetComputeVersion runs getComputeVersion over a table of endpoint,
// environment and version combinations and fails as soon as a case that is
// not marked expectError returns an error.
//
// NOTE(review): expectedEndpoint is only used in the failure message; the
// options returned by getComputeVersion are discarded, and a case marked
// expectError that returns no error is not reported as a failure.
func TestGetComputeVersion(t *testing.T) {
	type endpointCase struct {
		name               string
		computeEndpoint    string
		computeEnvironment Environment
		computeVersion     Version
		expectedEndpoint   string
		expectError        bool
	}

	cases := []endpointCase{
		{
			name:               "check for production environment",
			computeEndpoint:    "https://compute.googleapis.com",
			computeEnvironment: "production",
			computeVersion:     "v1",
			expectedEndpoint:   "https://compute.googleapis.com/compute/v1/",
			expectError:        false,
		},
		{
			name:               "check for incorrect endpoint",
			computeEndpoint:    "https://compute.googleapis",
			computeEnvironment: "prod",
			computeVersion:     "v1",
			expectError:        true,
		},
		{
			name:               "check for staging environment",
			computeEndpoint:    "https://compute.googleapis.com",
			computeEnvironment: environmentStaging,
			computeVersion:     "v1",
			expectedEndpoint:   "compute/staging_v1/",
			expectError:        false,
		},
		{
			name:               "check for random string as endpoint",
			computeEndpoint:    "compute-googleapis",
			computeEnvironment: "prod",
			computeVersion:     "v1",
			expectedEndpoint:   "compute/v1/",
			expectError:        true,
		},
	}

	ctx := context.Background()
	for _, c := range cases {
		_, err := getComputeVersion(ctx, &mockTokenSource{}, c.computeEndpoint, c.computeEnvironment, c.computeVersion)
		// Only an error on a case that did not anticipate one is a failure.
		if err == nil || c.expectError {
			continue
		}
		t.Fatalf("Got error %v, expected endpoint %s", err, c.expectedEndpoint)
	}
}
2 changes: 1 addition & 1 deletion pkg/gce-pd-csi-driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ const (
)

var (
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true}
validResourceApiVersions = map[string]bool{"v1": true, "alpha": true, "beta": true, "staging_v1": true, "staging_beta": true, "staging_alpha": true}
)

func isDiskReady(disk *gce.CloudDisk) (bool, error) {
Expand Down
11 changes: 1 addition & 10 deletions test/e2e/tests/single_zone_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ var _ = Describe("GCE PD CSI Driver", func() {
}()
})

It("Should pass/fail if valid/invalid compute endpoint is passed in", func() {
It("Should pass if valid compute endpoint is passed in", func() {
// gets instance set up w/o compute-endpoint set from test setup
_, err := getRandomTestContext().Client.ListVolumes()
Expect(err).To(BeNil(), "no error expected when passed valid compute url")
Expand All @@ -1295,15 +1295,6 @@ var _ = Describe("GCE PD CSI Driver", func() {

klog.Infof("Creating new driver and client for node %s\n", i.GetName())

// Create new driver and client w/ invalid endpoint
tcInvalid, err := testutils.GCEClientAndDriverSetup(i, "invalid-string")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we drop the invalid endpoint case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, IMO this should be treated as input validation for the program rather than the RPC layer. Passing an invalid endpoint should crash the program, rather than cause a delayed RPC error.

if err != nil {
klog.Fatalf("Failed to set up Test Context for instance %v: %v", i.GetName(), err)
}

_, err = tcInvalid.Client.ListVolumes()
Expect(err.Error()).To(ContainSubstring("no such host"), "expected error when passed invalid compute url")

// Create new driver and client w/ valid, passed-in endpoint
tcValid, err := testutils.GCEClientAndDriverSetup(i, "https://compute.googleapis.com")
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion test/e2e/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo, computeEndpoint stri
// useful to see what's happening when debugging tests.
driverRunCmd := fmt.Sprintf("sh -c '/usr/bin/nohup %s/gce-pd-csi-driver -v=6 --endpoint=%s %s 2> %s/prog.out < /dev/null > /dev/null &'",
workspace, endpoint, strings.Join(extra_flags, " "), workspace)

config := &remote.ClientConfig{
PkgPath: pkgPath,
BinPath: binPath,
Expand Down