From 53a286c751bdf7b6c960617eb8f9787287b4897a Mon Sep 17 00:00:00 2001
From: "Hantao (Will) Wang"
Date: Mon, 24 Jun 2019 15:05:49 -0700
Subject: [PATCH] add volume level serialization for controller operations

---
 pkg/common/utils.go                         | 11 +++
 pkg/common/volume_lock.go                   | 58 +++++++++++++
 pkg/gce-cloud-provider/compute/fake-gce.go  | 19 ++++-
 pkg/gce-pd-csi-driver/controller.go         | 45 ++++++++++
 pkg/gce-pd-csi-driver/controller_test.go    | 67 +++++++++++++++
 pkg/gce-pd-csi-driver/gce-pd-driver.go      |  3 +
 pkg/gce-pd-csi-driver/gce-pd-driver_test.go | 24 +++++-
 pkg/gce-pd-csi-driver/node.go               | 29 ++++---
 pkg/gce-pd-csi-driver/node_test.go          | 91 ++++++++++-----------
 pkg/mount-manager/fake-safe-mounter.go      | 33 +++-----
 test/sanity/sanity_test.go                  |  2 +-
 11 files changed, 291 insertions(+), 91 deletions(-)
 create mode 100644 pkg/common/volume_lock.go

diff --git a/pkg/common/utils.go b/pkg/common/utils.go
index c191e6a4a..9cb028176 100644
--- a/pkg/common/utils.go
+++ b/pkg/common/utils.go
@@ -73,6 +73,17 @@ func VolumeIDToKey(id string) (*meta.Key, error) {
 	}
 }
 
+func KeyToVolumeID(volKey *meta.Key, project string) (string, error) {
+	switch volKey.Type() {
+	case meta.Zonal:
+		return fmt.Sprintf(volIDZonalFmt, project, volKey.Zone, volKey.Name), nil
+	case meta.Regional:
+		return fmt.Sprintf(volIDRegionalFmt, project, volKey.Region, volKey.Name), nil
+	default:
+		return "", fmt.Errorf("volume key %v neither zonal nor regional", volKey.Name)
+	}
+}
+
 func GenerateUnderspecifiedVolumeID(diskName string, isZonal bool) string {
 	if isZonal {
 		return fmt.Sprintf(volIDZonalFmt, UnspecifiedValue, UnspecifiedValue, diskName)
diff --git a/pkg/common/volume_lock.go b/pkg/common/volume_lock.go
new file mode 100644
index 000000000..65105e426
--- /dev/null
+++ b/pkg/common/volume_lock.go
@@ -0,0 +1,58 @@
+/*
+Copyright 2019 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package common
+
+import (
+	"sync"
+
+	"k8s.io/apimachinery/pkg/util/sets"
+)
+
+const (
+	VolumeOperationAlreadyExistsFmt = "An operation with the given Volume ID %s already exists"
+)
+
+// VolumeLocks implements a map with atomic operations. It stores a set of all volume IDs
+// with an ongoing operation.
+type VolumeLocks struct {
+	locks sets.String
+	mux   sync.Mutex
+}
+
+func NewVolumeLocks() *VolumeLocks {
+	return &VolumeLocks{
+		locks: sets.NewString(),
+	}
+}
+
+// TryAcquire tries to acquire the lock for operating on volumeID and returns true if successful.
+// If another operation is already using volumeID, returns false.
+func (vl *VolumeLocks) TryAcquire(volumeID string) bool { + vl.mux.Lock() + defer vl.mux.Unlock() + if vl.locks.Has(volumeID) { + return false + } + vl.locks.Insert(volumeID) + return true +} + +func (vl *VolumeLocks) Release(volumeID string) { + vl.mux.Lock() + defer vl.mux.Unlock() + vl.locks.Delete(volumeID) +} diff --git a/pkg/gce-cloud-provider/compute/fake-gce.go b/pkg/gce-cloud-provider/compute/fake-gce.go index bf6c423fa..83ccfa3bd 100644 --- a/pkg/gce-cloud-provider/compute/fake-gce.go +++ b/pkg/gce-cloud-provider/compute/fake-gce.go @@ -49,7 +49,7 @@ type FakeCloudProvider struct { var _ GCECompute = &FakeCloudProvider{} -func FakeCreateCloudProvider(project, zone string, cloudDisks []*CloudDisk) (*FakeCloudProvider, error) { +func CreateFakeCloudProvider(project, zone string, cloudDisks []*CloudDisk) (*FakeCloudProvider, error) { fcp := &FakeCloudProvider{ project: project, zone: zone, @@ -61,7 +61,6 @@ func FakeCreateCloudProvider(project, zone string, cloudDisks []*CloudDisk) (*Fa fcp.disks[d.GetName()] = d } return fcp, nil - } func (cloud *FakeCloudProvider) RepairUnderspecifiedVolumeKey(ctx context.Context, volumeKey *meta.Key) (*meta.Key, error) { @@ -422,6 +421,22 @@ func (cloud *FakeCloudProvider) getGlobalSnapshotURI(snapshotName string) string snapshotName) } +type FakeBlockingCloudProvider struct { + *FakeCloudProvider + ReadyToExecute chan chan struct{} +} + +// FakeBlockingCloudProvider's method adds functionality to finely control the order of execution of CreateSnapshot calls. +// Upon starting a CreateSnapshot, it passes a chan 'executeCreateSnapshot' into readyToExecute, then blocks on executeCreateSnapshot. +// The test calling this function can block on readyToExecute to ensure that the operation has started and +// allowed the CreateSnapshot to continue by passing a struct into executeCreateSnapshot. 
+func (cloud *FakeBlockingCloudProvider) CreateSnapshot(ctx context.Context, volKey *meta.Key, snapshotName string) (*compute.Snapshot, error) { + executeCreateSnapshot := make(chan struct{}) + cloud.ReadyToExecute <- executeCreateSnapshot + <-executeCreateSnapshot + return cloud.FakeCloudProvider.CreateSnapshot(ctx, volKey, snapshotName) +} + func notFoundError() *googleapi.Error { return &googleapi.Error{ Errors: []googleapi.ErrorItem{ diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go index 898c20097..f3228ac2c 100644 --- a/pkg/gce-pd-csi-driver/controller.go +++ b/pkg/gce-pd-csi-driver/controller.go @@ -42,6 +42,10 @@ type GCEControllerServer struct { Driver *GCEDriver CloudProvider gce.GCECompute MetadataService metadataservice.MetadataService + + // A map storing all volumes with ongoing operations so that additional operations + // for that same volume (as defined by Volume Key) return an Aborted error + volumeLocks *common.VolumeLocks } var _ csi.ControllerServer = &GCEControllerServer{} @@ -139,6 +143,15 @@ func (gceCS *GCEControllerServer) CreateVolume(ctx context.Context, req *csi.Cre return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("CreateVolume replication type '%s' is not supported", replicationType)) } + volumeID, err := common.KeyToVolumeID(volKey, gceCS.MetadataService.GetProject()) + if err != nil { + return nil, status.Errorf(codes.Internal, "Failed to convert volume key to volume ID: %v", err) + } + if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID) + } + defer gceCS.volumeLocks.Release(volumeID) + // Validate if disk already exists existingDisk, err := gceCS.CloudProvider.GetDisk(ctx, volKey) if err != nil { @@ -222,6 +235,11 @@ func (gceCS *GCEControllerServer) DeleteVolume(ctx context.Context, req *csi.Del return nil, status.Error(codes.NotFound, fmt.Sprintf("Could not find volume with ID %v: %v", volumeID, err)) } + if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID) + } + defer gceCS.volumeLocks.Release(volumeID) + err = gceCS.CloudProvider.DeleteDisk(ctx, volKey) if err != nil { return nil, status.Error(codes.Internal, fmt.Sprintf("unknown Delete disk error: %v", err)) @@ -258,6 +276,14 @@ func (gceCS *GCEControllerServer) ControllerPublishVolume(ctx context.Context, r return nil, status.Error(codes.NotFound, fmt.Sprintf("Could not find volume with ID %v: %v", volumeID, err)) } + // Acquires the lock for the volume on that node only, because we need to support the ability + // to publish the same volume onto different nodes concurrently + lockingVolumeID := fmt.Sprintf("%s/%s", nodeID, volumeID) + if acquired := gceCS.volumeLocks.TryAcquire(lockingVolumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, lockingVolumeID) + } + defer gceCS.volumeLocks.Release(lockingVolumeID) + // TODO(#253): Check volume capability matches for ALREADY_EXISTS if err = validateVolumeCapability(volumeCapability); err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("VolumeCapabilities is invalid: %v", err)) @@ -343,6 +369,14 @@ func (gceCS *GCEControllerServer) ControllerUnpublishVolume(ctx context.Context, return nil, err } + // Acquires the lock for the volume on that node only, because we need to support the ability + // to unpublish the same volume from 
different nodes concurrently + lockingVolumeID := fmt.Sprintf("%s/%s", nodeID, volumeID) + if acquired := gceCS.volumeLocks.TryAcquire(lockingVolumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, lockingVolumeID) + } + defer gceCS.volumeLocks.Release(lockingVolumeID) + instanceZone, instanceName, err := common.NodeIDToZoneAndName(nodeID) if err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("could not split nodeID: %v", err)) @@ -389,6 +423,12 @@ func (gceCS *GCEControllerServer) ValidateVolumeCapabilities(ctx context.Context if err != nil { return nil, status.Error(codes.NotFound, fmt.Sprintf("Volume ID is of improper format, got %v", volumeID)) } + + if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID) + } + defer gceCS.volumeLocks.Release(volumeID) + _, err = gceCS.CloudProvider.GetDisk(ctx, volKey) if err != nil { if gce.IsGCEError(err, "notFound") { @@ -496,6 +536,11 @@ func (gceCS *GCEControllerServer) CreateSnapshot(ctx context.Context, req *csi.C return nil, status.Error(codes.NotFound, fmt.Sprintf("Could not find volume with ID %v: %v", volumeID, err)) } + if acquired := gceCS.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID) + } + defer gceCS.volumeLocks.Release(volumeID) + // Check if snapshot already exists var snapshot *compute.Snapshot snapshot, err = gceCS.CloudProvider.GetSnapshot(ctx, req.Name) diff --git a/pkg/gce-pd-csi-driver/controller_test.go b/pkg/gce-pd-csi-driver/controller_test.go index 1562ae572..07c2c0478 100644 --- a/pkg/gce-pd-csi-driver/controller_test.go +++ b/pkg/gce-pd-csi-driver/controller_test.go @@ -1413,3 +1413,70 @@ func TestPickRandAndConsecutive(t *testing.T) { } } + +func TestVolumeOperationConcurrency(t *testing.T) { + readyToExecute := make(chan chan struct{}, 1) + gceDriver := initBlockingGCEDriver(t, nil, readyToExecute) + cs := gceDriver.cs + + vol1CreateSnapshotAReq := &csi.CreateSnapshotRequest{ + Name: name + "1A", + SourceVolumeId: testVolumeId + "1", + } + vol1CreateSnapshotBReq := &csi.CreateSnapshotRequest{ + Name: name + "1B", + SourceVolumeId: testVolumeId + "1", + } + vol2CreateSnapshotReq := &csi.CreateSnapshotRequest{ + Name: name + "2", + SourceVolumeId: testVolumeId + "2", + } + + runRequest := func(req *csi.CreateSnapshotRequest) <-chan error { + response := make(chan error) + go func() { + _, err := cs.CreateSnapshot(context.Background(), req) + response <- err + }() + return response + } + + // Start first valid request vol1CreateSnapshotA and block until it reaches the CreateSnapshot + vol1CreateSnapshotAResp := runRequest(vol1CreateSnapshotAReq) + execVol1CreateSnapshotA := <-readyToExecute + + // Start vol1CreateSnapshotB and allow it to execute to completion. Then check for Aborted error. + // If a non Abort error is received or if the operation was started, then there is a problem + // with volume locking + vol1CreateSnapshotBResp := runRequest(vol1CreateSnapshotBReq) + select { + case err := <-vol1CreateSnapshotBResp: + if err != nil { + serverError, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from err: %v", err) + } + if serverError.Code() != codes.Aborted { + t.Errorf("Expected error code: %v, got: %v. 
err : %v", codes.Aborted, serverError.Code(), err) + } + } else { + t.Errorf("Expected error: %v, got no error", codes.Aborted) + } + case <-readyToExecute: + t.Errorf("The operation for vol1CreateSnapshotB should have been aborted, but was started") + } + + // Start vol2CreateSnapshot and allow it to execute to completion. Then check for success. + vol2CreateSnapshotResp := runRequest(vol2CreateSnapshotReq) + execVol2CreateSnapshot := <-readyToExecute + execVol2CreateSnapshot <- struct{}{} + if err := <-vol2CreateSnapshotResp; err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // To clean up, allow the vol1CreateSnapshotA to complete + execVol1CreateSnapshotA <- struct{}{} + if err := <-vol1CreateSnapshotAResp; err != nil { + t.Errorf("Unexpected error: %v", err) + } +} diff --git a/pkg/gce-pd-csi-driver/gce-pd-driver.go b/pkg/gce-pd-csi-driver/gce-pd-driver.go index 68d45cd62..c2c9196a9 100644 --- a/pkg/gce-pd-csi-driver/gce-pd-driver.go +++ b/pkg/gce-pd-csi-driver/gce-pd-driver.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc/status" "k8s.io/klog" "k8s.io/kubernetes/pkg/util/mount" + common "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common" gce "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/gce-cloud-provider/compute" metadataservice "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/gce-cloud-provider/metadata" mountmanager "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/mount-manager" @@ -136,6 +137,7 @@ func NewNodeServer(gceDriver *GCEDriver, mounter *mount.SafeFormatAndMount, devi Mounter: mounter, DeviceUtils: deviceUtils, MetadataService: meta, + volumeLocks: common.NewVolumeLocks(), } } @@ -144,6 +146,7 @@ func NewControllerServer(gceDriver *GCEDriver, cloudProvider gce.GCECompute, met Driver: gceDriver, CloudProvider: cloudProvider, MetadataService: meta, + volumeLocks: common.NewVolumeLocks(), } } diff --git a/pkg/gce-pd-csi-driver/gce-pd-driver_test.go b/pkg/gce-pd-csi-driver/gce-pd-driver_test.go index 501b01814..d06e7c23b 100644 --- a/pkg/gce-pd-csi-driver/gce-pd-driver_test.go +++ b/pkg/gce-pd-csi-driver/gce-pd-driver_test.go @@ -22,13 +22,29 @@ import ( ) func initGCEDriver(t *testing.T, cloudDisks []*gce.CloudDisk) *GCEDriver { - vendorVersion := "test-vendor" - gceDriver := GetGCEDriver() - fakeCloudProvider, err := gce.FakeCreateCloudProvider(project, zone, cloudDisks) + fakeCloudProvider, err := gce.CreateFakeCloudProvider(project, zone, cloudDisks) + if err != nil { + t.Fatalf("Failed to create fake cloud provider: %v", err) + } + return initGCEDriverWithCloudProvider(t, fakeCloudProvider) +} + +func initBlockingGCEDriver(t *testing.T, cloudDisks []*gce.CloudDisk, readyToExecute chan chan struct{}) *GCEDriver { + fakeCloudProvider, err := gce.CreateFakeCloudProvider(project, zone, cloudDisks) if err != nil { t.Fatalf("Failed to create fake cloud provider: %v", err) } - err = gceDriver.SetupGCEDriver(fakeCloudProvider, nil, nil, metadataservice.NewFakeService(), driver, vendorVersion) + fakeBlockingBlockProvider := &gce.FakeBlockingCloudProvider{ + FakeCloudProvider: fakeCloudProvider, + ReadyToExecute: readyToExecute, + } + return initGCEDriverWithCloudProvider(t, fakeBlockingBlockProvider) +} + +func initGCEDriverWithCloudProvider(t *testing.T, cloudProvider gce.GCECompute) *GCEDriver { + vendorVersion := "test-vendor" + gceDriver := GetGCEDriver() + err := gceDriver.SetupGCEDriver(cloudProvider, nil, nil, metadataservice.NewFakeService(), driver, vendorVersion) if err != nil { t.Fatalf("Failed to setup GCE Driver: 
%v", err) } diff --git a/pkg/gce-pd-csi-driver/node.go b/pkg/gce-pd-csi-driver/node.go index 025c176e1..8e4769663 100644 --- a/pkg/gce-pd-csi-driver/node.go +++ b/pkg/gce-pd-csi-driver/node.go @@ -18,7 +18,6 @@ import ( "fmt" "os" "strings" - "sync" "context" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -39,7 +38,7 @@ type GCENodeServer struct { // A map storing all volumes with ongoing operations so that additional operations // for that same volume (as defined by VolumeID) return an Aborted error - volumes sync.Map + volumeLocks *common.VolumeLocks } var _ csi.NodeServer = &GCENodeServer{} @@ -74,10 +73,10 @@ func (ns *GCENodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePub return nil, status.Error(codes.InvalidArgument, "NodePublishVolume Volume Capability must be provided") } - if _, alreadyExists := ns.volumes.LoadOrStore(volumeID, true); alreadyExists { - return nil, status.Error(codes.Aborted, fmt.Sprintf("An operation with the given Volume ID %s already exists", volumeID)) + if acquired := ns.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID) } - defer ns.volumes.Delete(volumeID) + defer ns.volumeLocks.Release(volumeID) if err := validateVolumeCapability(volumeCapability); err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("VolumeCapability is invalid: %v", err)) @@ -189,18 +188,18 @@ func (ns *GCENodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeU // Validate Arguments targetPath := req.GetTargetPath() - volID := req.GetVolumeId() - if len(volID) == 0 { + volumeID := req.GetVolumeId() + if len(volumeID) == 0 { return nil, status.Error(codes.InvalidArgument, "NodeUnpublishVolume Volume ID must be provided") } if len(targetPath) == 0 { return nil, status.Error(codes.InvalidArgument, "NodeUnpublishVolume Target Path must be provided") } - if _, alreadyExists := ns.volumes.LoadOrStore(volID, true); alreadyExists { - return nil, status.Error(codes.Aborted, fmt.Sprintf("An operation with the given Volume ID %s already exists", volID)) + if acquired := ns.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID) } - defer ns.volumes.Delete(volID) + defer ns.volumeLocks.Release(volumeID) err := mount.CleanupMountPoint(targetPath, ns.Mounter.Interface, false /* bind mount */) if err != nil { @@ -227,10 +226,10 @@ func (ns *GCENodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStage return nil, status.Error(codes.InvalidArgument, "NodeStageVolume Volume Capability must be provided") } - if _, alreadyExists := ns.volumes.LoadOrStore(volumeID, true); alreadyExists { - return nil, status.Error(codes.Aborted, fmt.Sprintf("An operation with the given Volume ID %s already exists", volumeID)) + if acquired := ns.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, volumeID) } - defer ns.volumes.Delete(volumeID) + defer ns.volumeLocks.Release(volumeID) if err := validateVolumeCapability(volumeCapability); err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("VolumeCapability is invalid: %v", err)) @@ -321,10 +320,10 @@ func (ns *GCENodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUns return nil, status.Error(codes.InvalidArgument, "NodeUnstageVolume Staging Target Path must be provided") } - if _, alreadyExists := 
ns.volumes.LoadOrStore(volumeID, true); alreadyExists { + if acquired := ns.volumeLocks.TryAcquire(volumeID); !acquired { return nil, status.Error(codes.Aborted, fmt.Sprintf("An operation with the given Volume ID %s already exists", volumeID)) } - defer ns.volumes.Delete(volumeID) + defer ns.volumeLocks.Release(volumeID) err := mount.CleanupMountPoint(stagingTargetPath, ns.Mounter.Interface, false /* bind mount */) if err != nil { diff --git a/pkg/gce-pd-csi-driver/node_test.go b/pkg/gce-pd-csi-driver/node_test.go index 6ff01f73e..cb486ee9b 100644 --- a/pkg/gce-pd-csi-driver/node_test.go +++ b/pkg/gce-pd-csi-driver/node_test.go @@ -15,9 +15,9 @@ limitations under the License. package gceGCEDriver import ( + "context" "testing" - "context" csi "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -38,9 +38,9 @@ func getTestGCEDriver(t *testing.T) *GCEDriver { return gceDriver } -func getTestBlockingGCEDriver(t *testing.T, mountToRun chan mountmanager.MountSourceAndTarget, readyToMount chan struct{}) *GCEDriver { +func getTestBlockingGCEDriver(t *testing.T, readyToExecute chan chan struct{}) *GCEDriver { gceDriver := GetGCEDriver() - err := gceDriver.SetupGCEDriver(nil, mountmanager.NewFakeSafeBlockingMounter(mountToRun, readyToMount), mountmanager.NewFakeDeviceUtils(), metadataservice.NewFakeService(), driver, "test-vendor") + err := gceDriver.SetupGCEDriver(nil, mountmanager.NewFakeSafeBlockingMounter(readyToExecute), mountmanager.NewFakeDeviceUtils(), metadataservice.NewFakeService(), driver, "test-vendor") if err != nil { t.Fatalf("Failed to setup GCE Driver: %v", err) } @@ -389,12 +389,10 @@ func TestNodeGetCapabilities(t *testing.T) { } func TestConcurrentNodeOperations(t *testing.T) { - mountToRun := make(chan mountmanager.MountSourceAndTarget, 3) - readyToMount := make(chan struct{}, 2) - reqFinished := make(chan error, 2) - - gceDriver := getTestBlockingGCEDriver(t, mountToRun, readyToMount) + readyToExecute := make(chan chan struct{}, 1) + gceDriver := getTestBlockingGCEDriver(t, readyToExecute) ns := gceDriver.ns + vol1PublishTargetAReq := &csi.NodePublishVolumeRequest{ VolumeId: defaultVolumeID + "1", TargetPath: defaultTargetPath + "a", @@ -417,52 +415,51 @@ func TestConcurrentNodeOperations(t *testing.T) { VolumeCapability: stdVolCap, } - runRequestInBackground := func(req *csi.NodePublishVolumeRequest) { - _, err := ns.NodePublishVolume(context.Background(), req) - reqFinished <- err + runRequest := func(req *csi.NodePublishVolumeRequest) chan error { + response := make(chan error) + go func() { + _, err := ns.NodePublishVolume(context.Background(), req) + response <- err + }() + return response } - // Start first valid request vol1PublishTargetAReq and block until it reaches the Mount - go runRequestInBackground(vol1PublishTargetAReq) - <-readyToMount + // Start first valid request vol1PublishTargetA and block until it reaches the Mount + vol1PublishTargetAResp := runRequest(vol1PublishTargetAReq) + execVol1PublishTargetA := <-readyToExecute - // Check that vol1PublishTargetBReq is rejected, due to same volume ID - // Also allow vol1PublishTargetBReq to complete, in case it is allowed to Mount - mountToRun <- mountmanager.MountSourceAndTarget{ - Source: vol1PublishTargetBReq.StagingTargetPath, - Target: vol1PublishTargetBReq.TargetPath, - } - _, err := ns.NodePublishVolume(context.Background(), vol1PublishTargetBReq) - if err != nil { - serverError, ok := status.FromError(err) - if !ok { - 
t.Fatalf("Could not get error status code from err: %v", err) - } - if serverError.Code() != codes.Aborted { - t.Fatalf("Expected error code: %v, got: %v. err : %v", codes.Aborted, serverError.Code(), err) + // Start vol1PublishTargetB and allow it to execute to completion. Then check for Aborted error. + // If a non Abort error is received or if the operation was started, then there is a problem + // with volume locking. + vol1PublishTargetBResp := runRequest(vol1PublishTargetBReq) + select { + case err := <-vol1PublishTargetBResp: + if err != nil { + serverError, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from err: %v", err) + } + if serverError.Code() != codes.Aborted { + t.Errorf("Expected error code: %v, got: %v. err : %v", codes.Aborted, serverError.Code(), err) + } + } else { + t.Errorf("Expected error: %v, got no error", codes.Aborted) } - } else { - t.Fatalf("Expected error: %v, got no error", codes.Aborted) + case <-readyToExecute: + t.Errorf("The operation for vol1PublishTargetB should have been aborted, but was started") } - // Start second valid request vol2PublishTargetCReq - go runRequestInBackground(vol2PublishTargetCReq) - - // Allow the vol2PublishTargetCReq to complete, which it can concurrently with vol1PublishTargetAReq - mountToRun <- mountmanager.MountSourceAndTarget{ - Source: vol2PublishTargetCReq.StagingTargetPath, - Target: vol2PublishTargetCReq.TargetPath, - } - if err = <-reqFinished; err != nil { - t.Fatalf("Unexpected error: %v", err) + // Start vol2PublishTargetC and allow it to execute to completion. Then check for success. + vol2PublishTargetCResp := runRequest(vol2PublishTargetCReq) + execVol2PublishTargetC := <-readyToExecute + execVol2PublishTargetC <- struct{}{} + if err := <-vol2PublishTargetCResp; err != nil { + t.Errorf("Unexpected error: %v", err) } - // To clean up, allow the vol1PublishTargetAReq to complete - mountToRun <- mountmanager.MountSourceAndTarget{ - Source: vol1PublishTargetAReq.StagingTargetPath, - Target: vol1PublishTargetAReq.TargetPath, - } - if err = <-reqFinished; err != nil { - t.Fatalf("Unexpected error: %v", err) + // To clean up, allow the vol1PublishTargetA to complete + execVol1PublishTargetA <- struct{}{} + if err := <-vol1PublishTargetAResp; err != nil { + t.Errorf("Unexpected error: %v", err) } } diff --git a/pkg/mount-manager/fake-safe-mounter.go b/pkg/mount-manager/fake-safe-mounter.go index 9216cd5e3..af8216d3e 100644 --- a/pkg/mount-manager/fake-safe-mounter.go +++ b/pkg/mount-manager/fake-safe-mounter.go @@ -52,35 +52,24 @@ func NewFakeSafeMounter() *mount.SafeFormatAndMount { type FakeBlockingMounter struct { *mount.FakeMounter - mountToRun chan MountSourceAndTarget - readyToMount chan struct{} + ReadyToExecute chan chan struct{} } -type MountSourceAndTarget struct { - Source string - Target string -} - -// FakeBlockingMounter's method adds two channels to the Mount process in order to provide functionality to finely -// control the order of execution of Mount calls. readToMount signals that a Mount operation has been called. -// Then it cycles through the mountToRun channel, waiting for permission to actually make the mount operation. +// Mount is ovverridden and adds functionality to finely control the order of execution of FakeMounter's Mount calls. +// Upon starting a Mount, it passes a chan 'executeMount' into readyToExecute, then blocks on executeMount. 
+// The test calling this function can block on readyToExecute to ensure that the operation has started and +// allowed the Mount to continue by passing a struct into executeMount. func (mounter *FakeBlockingMounter) Mount(source string, target string, fstype string, options []string) error { - mounter.readyToMount <- struct{}{} - for mountToRun := range mounter.mountToRun { - if mountToRun.Source == source && mountToRun.Target == target { - break - } else { - mounter.mountToRun <- mountToRun - } - } + executeMount := make(chan struct{}) + mounter.ReadyToExecute <- executeMount + <-executeMount return mounter.FakeMounter.Mount(source, target, fstype, options) } -func NewFakeSafeBlockingMounter(mountToRun chan MountSourceAndTarget, readyToMount chan struct{}) *mount.SafeFormatAndMount { +func NewFakeSafeBlockingMounter(readyToExecute chan chan struct{}) *mount.SafeFormatAndMount { fakeBlockingMounter := &FakeBlockingMounter{ - FakeMounter: fakeMounter, - mountToRun: mountToRun, - readyToMount: readyToMount, + FakeMounter: fakeMounter, + ReadyToExecute: readyToExecute, } return &mount.SafeFormatAndMount{ Interface: fakeBlockingMounter, diff --git a/test/sanity/sanity_test.go b/test/sanity/sanity_test.go index 4529bd0a5..6c95a92ae 100644 --- a/test/sanity/sanity_test.go +++ b/test/sanity/sanity_test.go @@ -41,7 +41,7 @@ func TestSanity(t *testing.T) { // Set up driver and env gceDriver := driver.GetGCEDriver() - cloudProvider, err := gce.FakeCreateCloudProvider(project, zone, nil) + cloudProvider, err := gce.CreateFakeCloudProvider(project, zone, nil) if err != nil { t.Fatalf("Failed to get cloud provider: %v", err) }
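For readers of this patch, the sketch below condenses the locking pattern that every guarded controller and node RPC above now follows: try to take the per-volume lock, return a gRPC Aborted error if another operation already holds it, and release the lock when the RPC returns. Only the VolumeLocks API and the VolumeOperationAlreadyExistsFmt constant come from this patch; the handler function, the example volume ID, and the placeholder node ID are illustrative.

package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	common "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common"
)

// doVolumeOperation is an illustrative stand-in for any of the guarded RPCs
// (CreateVolume, DeleteVolume, CreateSnapshot, NodeStageVolume, ...). It shows
// the acquire / Aborted / defer-release shape used throughout the patch.
func doVolumeOperation(locks *common.VolumeLocks, lockKey string) error {
	if acquired := locks.TryAcquire(lockKey); !acquired {
		return status.Errorf(codes.Aborted, common.VolumeOperationAlreadyExistsFmt, lockKey)
	}
	defer locks.Release(lockKey)

	// ...the actual disk or mount work happens here...
	return nil
}

func main() {
	locks := common.NewVolumeLocks()

	// Controller-wide operations lock on the volume ID itself.
	volumeID := "projects/my-project/zones/us-central1-a/disks/my-disk" // example shape only
	fmt.Println(doVolumeOperation(locks, volumeID))

	// ControllerPublish/Unpublish lock on a node-scoped key instead, so the same
	// volume can still be attached to or detached from different nodes concurrently.
	nodeID := "example-node-id" // placeholder; the driver's real node ID format is defined elsewhere
	lockingVolumeID := fmt.Sprintf("%s/%s", nodeID, volumeID)
	fmt.Println(doVolumeOperation(locks, lockingVolumeID))
}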
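The test fakes in this patch (FakeBlockingCloudProvider and FakeBlockingMounter) both rely on the same 'chan chan struct{}' handshake to pause an in-flight operation at a known point. The stand-alone sketch below shows just that handshake with illustrative names; it is not code from the driver.

package main

import "fmt"

// blockingOperation stands in for FakeBlockingCloudProvider.CreateSnapshot or
// FakeBlockingMounter.Mount: it announces that it has started by sending an
// 'execute' channel to the test, then blocks until the test signals it.
func blockingOperation(readyToExecute chan chan struct{}) string {
	execute := make(chan struct{})
	readyToExecute <- execute // tell the test the operation is now in flight
	<-execute                 // wait for permission to finish
	return "operation finished"
}

func main() {
	readyToExecute := make(chan chan struct{}, 1)
	done := make(chan string)

	go func() { done <- blockingOperation(readyToExecute) }()

	execute := <-readyToExecute // the operation has started but is held here;
	// a second request for the same volume issued now must hit the lock and
	// return Aborted, which is exactly what the concurrency tests assert.
	execute <- struct{}{} // release the first operation
	fmt.Println(<-done)
}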