Skip to content

Add volume level serialization for controller operations #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ func VolumeIDToKey(id string) (*meta.Key, error) {
}
}

// KeyToVolumeID is the inverse of VolumeIDToKey: it renders a volume key plus
// project into the canonical CSI volume ID string.
//
// Bug fix: the zonal branch formatted volKey.Zone twice, producing an ID that
// ended in the zone instead of the disk name, and the regional branch used the
// zonal format string (".../zones/...") with volKey.Zone as the final verb.
// Both branches must end with volKey.Name, the disk name, matching the
// argument order used by GenerateUnderspecifiedVolumeID below.
// NOTE(review): volIDRegionalFmt is assumed to be declared alongside
// volIDZonalFmt in this file — confirm the constant name.
func KeyToVolumeID(volKey *meta.Key, project string) (string, error) {
	switch volKey.Type() {
	case meta.Zonal:
		return fmt.Sprintf(volIDZonalFmt, project, volKey.Zone, volKey.Name), nil
	case meta.Regional:
		return fmt.Sprintf(volIDRegionalFmt, project, volKey.Region, volKey.Name), nil
	default:
		return "", fmt.Errorf("volume key %v neither zonal nor regional", volKey.Name)
	}
}

func GenerateUnderspecifiedVolumeID(diskName string, isZonal bool) string {
if isZonal {
return fmt.Sprintf(volIDZonalFmt, UnspecifiedValue, UnspecifiedValue, diskName)
Expand Down
58 changes: 58 additions & 0 deletions pkg/common/volume_lock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
Copyright 2019 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package common

import (
"sync"

"k8s.io/apimachinery/pkg/util/sets"
)

const (
	// VolumeOperationAlreadyExistsFmt is the message format returned (as an
	// Aborted error) when an operation is rejected because another operation
	// on the same volume ID is still in flight. The single %s verb is filled
	// with the volume ID.
	VolumeOperationAlreadyExistsFmt = "An operation with the given Volume ID %s already exists"
)

// VolumeLocks implements a map with atomic operations. It stores a set of all volume IDs
// with an ongoing operation.
type VolumeLocks struct {
	// locks holds the IDs of volumes that currently have an operation in flight.
	locks sets.String
	// mux guards every read and write of locks.
	mux sync.Mutex
}

// NewVolumeLocks returns an empty VolumeLocks ready for use; the zero value is
// not usable because the underlying set must be initialized.
func NewVolumeLocks() *VolumeLocks {
	vl := &VolumeLocks{}
	vl.locks = sets.NewString()
	return vl
}

// TryAcquire attempts to claim the lock for volumeID. It returns true when no
// other operation holds the volume, in which case the caller now owns it and
// must eventually call Release. It returns false when volumeID is already in
// the set, i.e. another operation on that volume is in flight.
func (vl *VolumeLocks) TryAcquire(volumeID string) bool {
	vl.mux.Lock()
	defer vl.mux.Unlock()
	if !vl.locks.Has(volumeID) {
		vl.locks.Insert(volumeID)
		return true
	}
	return false
}

// Release removes volumeID from the set of in-flight operations, allowing a
// subsequent TryAcquire for that volume to succeed. Releasing an ID that is
// not held is a no-op.
func (vl *VolumeLocks) Release(volumeID string) {
	vl.mux.Lock()
	defer vl.mux.Unlock()
	vl.locks.Delete(volumeID)
}
19 changes: 17 additions & 2 deletions pkg/gce-cloud-provider/compute/fake-gce.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type FakeCloudProvider struct {

var _ GCECompute = &FakeCloudProvider{}

func FakeCreateCloudProvider(project, zone string, cloudDisks []*CloudDisk) (*FakeCloudProvider, error) {
func CreateFakeCloudProvider(project, zone string, cloudDisks []*CloudDisk) (*FakeCloudProvider, error) {
fcp := &FakeCloudProvider{
project: project,
zone: zone,
Expand All @@ -61,7 +61,6 @@ func FakeCreateCloudProvider(project, zone string, cloudDisks []*CloudDisk) (*Fa
fcp.disks[d.GetName()] = d
}
return fcp, nil

}

func (cloud *FakeCloudProvider) RepairUnderspecifiedVolumeKey(ctx context.Context, volumeKey *meta.Key) (*meta.Key, error) {
Expand Down Expand Up @@ -422,6 +421,22 @@ func (cloud *FakeCloudProvider) getGlobalSnapshotURI(snapshotName string) string
snapshotName)
}

// FakeBlockingCloudProvider wraps FakeCloudProvider so tests can control when
// an operation is allowed to proceed. Each blocked call sends one signal
// channel on ReadyToExecute; the test unblocks that call by sending a value
// on the received channel.
type FakeBlockingCloudProvider struct {
	*FakeCloudProvider
	// ReadyToExecute receives one channel per in-flight blocked operation.
	ReadyToExecute chan chan struct{}
}

// CreateSnapshot lets a test finely control the order of execution of
// concurrent CreateSnapshot calls. It announces that the call has started by
// sending a fresh signal channel on ReadyToExecute, then blocks until the
// test sends on that channel, and finally delegates to the embedded
// FakeCloudProvider's real implementation.
func (cloud *FakeBlockingCloudProvider) CreateSnapshot(ctx context.Context, volKey *meta.Key, snapshotName string) (*compute.Snapshot, error) {
	proceed := make(chan struct{})
	cloud.ReadyToExecute <- proceed
	<-proceed
	return cloud.FakeCloudProvider.CreateSnapshot(ctx, volKey, snapshotName)
}

func notFoundError() *googleapi.Error {
return &googleapi.Error{
Errors: []googleapi.ErrorItem{
Expand Down
45 changes: 45 additions & 0 deletions pkg/gce-pd-csi-driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ type GCEControllerServer struct {
Driver *GCEDriver
CloudProvider gce.GCECompute
MetadataService metadataservice.MetadataService

// A map storing all volumes with ongoing operations so that additional operations
// for that same volume (as defined by Volume Key) return an Aborted error
volumeLocks *common.VolumeLocks
}

var _ csi.ControllerServer = &GCEControllerServer{}
Expand Down Expand Up @@ -139,6 +143,15 @@ func (gceCS *GCEControllerServer) CreateVolume(ctx context.Context, req *csi.Cre
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("CreateVolume replication type '%s' is not supported", replicationType))
}

volumeID, err := common.KeyToVolumeID(volKey, gceCS.MetadataService.GetProject())
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to convert volume key to volume ID: %v", err)
}
if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired {
return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID)
}
defer gceCS.volumeLocks.Release(volumeID)

// Validate if disk already exists
existingDisk, err := gceCS.CloudProvider.GetDisk(ctx, volKey)
if err != nil {
Expand Down Expand Up @@ -222,6 +235,11 @@ func (gceCS *GCEControllerServer) DeleteVolume(ctx context.Context, req *csi.Del
return nil, status.Error(codes.NotFound, fmt.Sprintf("Could not find volume with ID %v: %v", volumeID, err))
}

if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired {
return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID)
}
defer gceCS.volumeLocks.Release(volumeID)

err = gceCS.CloudProvider.DeleteDisk(ctx, volKey)
if err != nil {
return nil, status.Error(codes.Internal, fmt.Sprintf("unknown Delete disk error: %v", err))
Expand Down Expand Up @@ -258,6 +276,14 @@ func (gceCS *GCEControllerServer) ControllerPublishVolume(ctx context.Context, r
return nil, status.Error(codes.NotFound, fmt.Sprintf("Could not find volume with ID %v: %v", volumeID, err))
}

// Acquires the lock for the volume on that node only, because we need to support the ability
// to publish the same volume onto different nodes concurrently
lockingVolumeID := fmt.Sprintf("%s/%s", nodeID, volumeID)
if acquired := gceCS.volumeLocks.TryAcquire(lockingVolumeID); !acquired {
return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, lockingVolumeID)
}
defer gceCS.volumeLocks.Release(lockingVolumeID)

// TODO(#253): Check volume capability matches for ALREADY_EXISTS
if err = validateVolumeCapability(volumeCapability); err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("VolumeCapabilities is invalid: %v", err))
Expand Down Expand Up @@ -343,6 +369,14 @@ func (gceCS *GCEControllerServer) ControllerUnpublishVolume(ctx context.Context,
return nil, err
}

// Acquires the lock for the volume on that node only, because we need to support the ability
// to unpublish the same volume from different nodes concurrently
lockingVolumeID := fmt.Sprintf("%s/%s", nodeID, volumeID)
if acquired := gceCS.volumeLocks.TryAcquire(lockingVolumeID); !acquired {
return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, lockingVolumeID)
}
defer gceCS.volumeLocks.Release(lockingVolumeID)

instanceZone, instanceName, err := common.NodeIDToZoneAndName(nodeID)
if err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("could not split nodeID: %v", err))
Expand Down Expand Up @@ -389,6 +423,12 @@ func (gceCS *GCEControllerServer) ValidateVolumeCapabilities(ctx context.Context
if err != nil {
return nil, status.Error(codes.NotFound, fmt.Sprintf("Volume ID is of improper format, got %v", volumeID))
}

if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired {
return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID)
}
defer gceCS.volumeLocks.Release(volumeID)

_, err = gceCS.CloudProvider.GetDisk(ctx, volKey)
if err != nil {
if gce.IsGCEError(err, "notFound") {
Expand Down Expand Up @@ -496,6 +536,11 @@ func (gceCS *GCEControllerServer) CreateSnapshot(ctx context.Context, req *csi.C
return nil, status.Error(codes.NotFound, fmt.Sprintf("Could not find volume with ID %v: %v", volumeID, err))
}

if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired {
return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID)
}
defer gceCS.volumeLocks.Release(volumeID)

// Check if snapshot already exists
var snapshot *compute.Snapshot
snapshot, err = gceCS.CloudProvider.GetSnapshot(ctx, req.Name)
Expand Down
67 changes: 67 additions & 0 deletions pkg/gce-pd-csi-driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1413,3 +1413,70 @@ func TestPickRandAndConsecutive(t *testing.T) {

}
}

// TestVolumeOperationConcurrency verifies the per-volume operation lock in the
// controller server: a second operation on a volume that already has one in
// flight must be rejected with codes.Aborted, while operations on a different
// volume proceed unaffected. It uses FakeBlockingCloudProvider to hold the
// first CreateSnapshot open at a known point while the others are issued.
func TestVolumeOperationConcurrency(t *testing.T) {
	readyToExecute := make(chan chan struct{}, 1)
	gceDriver := initBlockingGCEDriver(t, nil, readyToExecute)
	cs := gceDriver.cs

	// Two snapshots of the same volume ("1") and one of a different volume ("2").
	vol1CreateSnapshotAReq := &csi.CreateSnapshotRequest{
		Name:           name + "1A",
		SourceVolumeId: testVolumeId + "1",
	}
	vol1CreateSnapshotBReq := &csi.CreateSnapshotRequest{
		Name:           name + "1B",
		SourceVolumeId: testVolumeId + "1",
	}
	vol2CreateSnapshotReq := &csi.CreateSnapshotRequest{
		Name:           name + "2",
		SourceVolumeId: testVolumeId + "2",
	}

	// runRequest issues a CreateSnapshot on its own goroutine and returns a
	// channel that yields the call's error once it completes.
	runRequest := func(req *csi.CreateSnapshotRequest) <-chan error {
		response := make(chan error)
		go func() {
			_, err := cs.CreateSnapshot(context.Background(), req)
			response <- err
		}()
		return response
	}

	// Start first valid request vol1CreateSnapshotA and block until it reaches the CreateSnapshot
	vol1CreateSnapshotAResp := runRequest(vol1CreateSnapshotAReq)
	execVol1CreateSnapshotA := <-readyToExecute

	// Start vol1CreateSnapshotB while A still holds volume 1's lock, then check
	// for an Aborted error. If a non-Aborted error is received, or if the
	// operation reached the cloud provider (a send on readyToExecute), then
	// there is a problem with volume locking.
	vol1CreateSnapshotBResp := runRequest(vol1CreateSnapshotBReq)
	select {
	case err := <-vol1CreateSnapshotBResp:
		if err != nil {
			serverError, ok := status.FromError(err)
			if !ok {
				t.Fatalf("Could not get error status code from err: %v", err)
			}
			if serverError.Code() != codes.Aborted {
				t.Errorf("Expected error code: %v, got: %v. err : %v", codes.Aborted, serverError.Code(), err)
			}
		} else {
			t.Errorf("Expected error: %v, got no error", codes.Aborted)
		}
	case <-readyToExecute:
		t.Errorf("The operation for vol1CreateSnapshotB should have been aborted, but was started")
	}

	// Start vol2CreateSnapshot and allow it to execute to completion. Then check for success.
	// Volume 2 is not locked, so this must not be aborted.
	vol2CreateSnapshotResp := runRequest(vol2CreateSnapshotReq)
	execVol2CreateSnapshot := <-readyToExecute
	execVol2CreateSnapshot <- struct{}{}
	if err := <-vol2CreateSnapshotResp; err != nil {
		t.Errorf("Unexpected error: %v", err)
	}

	// To clean up, allow the vol1CreateSnapshotA to complete
	execVol1CreateSnapshotA <- struct{}{}
	if err := <-vol1CreateSnapshotAResp; err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
}
3 changes: 3 additions & 0 deletions pkg/gce-pd-csi-driver/gce-pd-driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/status"
"k8s.io/klog"
"k8s.io/kubernetes/pkg/util/mount"
common "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common"
gce "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/gce-cloud-provider/compute"
metadataservice "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/gce-cloud-provider/metadata"
mountmanager "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/mount-manager"
Expand Down Expand Up @@ -136,6 +137,7 @@ func NewNodeServer(gceDriver *GCEDriver, mounter *mount.SafeFormatAndMount, devi
Mounter: mounter,
DeviceUtils: deviceUtils,
MetadataService: meta,
volumeLocks: common.NewVolumeLocks(),
}
}

Expand All @@ -144,6 +146,7 @@ func NewControllerServer(gceDriver *GCEDriver, cloudProvider gce.GCECompute, met
Driver: gceDriver,
CloudProvider: cloudProvider,
MetadataService: meta,
volumeLocks: common.NewVolumeLocks(),
}
}

Expand Down
24 changes: 20 additions & 4 deletions pkg/gce-pd-csi-driver/gce-pd-driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,29 @@ import (
)

// initGCEDriver builds a GCEDriver backed by a plain (non-blocking) fake cloud
// provider pre-populated with cloudDisks. It fails the test on any setup error.
//
// Fix: the scraped diff fused the old and new implementations, leaving a
// duplicate `fakeCloudProvider, err :=` declaration (a compile error) and
// stray `vendorVersion`/`gceDriver` locals that were moved into
// initGCEDriverWithCloudProvider. Only the post-refactor body is kept.
func initGCEDriver(t *testing.T, cloudDisks []*gce.CloudDisk) *GCEDriver {
	fakeCloudProvider, err := gce.CreateFakeCloudProvider(project, zone, cloudDisks)
	if err != nil {
		t.Fatalf("Failed to create fake cloud provider: %v", err)
	}
	return initGCEDriverWithCloudProvider(t, fakeCloudProvider)
}

// initBlockingGCEDriver builds a GCEDriver whose cloud provider blocks each
// CreateSnapshot call until the test signals it via readyToExecute, letting
// concurrency tests control the interleaving of operations.
//
// Fix: removed the stray line `err = gceDriver.SetupGCEDriver(...)` left over
// from the pre-refactor initGCEDriver — it referenced an undefined `gceDriver`
// and its error was never checked; setup now happens solely inside
// initGCEDriverWithCloudProvider.
func initBlockingGCEDriver(t *testing.T, cloudDisks []*gce.CloudDisk, readyToExecute chan chan struct{}) *GCEDriver {
	fakeCloudProvider, err := gce.CreateFakeCloudProvider(project, zone, cloudDisks)
	if err != nil {
		t.Fatalf("Failed to create fake cloud provider: %v", err)
	}
	fakeBlockingBlockProvider := &gce.FakeBlockingCloudProvider{
		FakeCloudProvider: fakeCloudProvider,
		ReadyToExecute:    readyToExecute,
	}
	return initGCEDriverWithCloudProvider(t, fakeBlockingBlockProvider)
}

func initGCEDriverWithCloudProvider(t *testing.T, cloudProvider gce.GCECompute) *GCEDriver {
vendorVersion := "test-vendor"
gceDriver := GetGCEDriver()
err := gceDriver.SetupGCEDriver(cloudProvider, nil, nil, metadataservice.NewFakeService(), driver, vendorVersion)
if err != nil {
t.Fatalf("Failed to setup GCE Driver: %v", err)
}
Expand Down
Loading