// lockWithWaiters pairs a per-key lock with a count of the goroutines
// that currently hold it or are waiting to acquire it. The count is
// what lets the manager garbage-collect locks nobody wants anymore.
type lockWithWaiters struct {
	mux     sync.Locker
	waiters uint32
}

// LockManager hands out one lock per string key, allocating a lock the
// first time a key is acquired and deallocating it once the last
// interested goroutine releases it.
type LockManager struct {
	mux       sync.Mutex // guards locks and every waiters counter
	newLocker func(...interface{}) sync.Locker
	locks     map[string]*lockWithWaiters
}

// NewLockManager returns a LockManager that uses f as the factory for
// the per-key locks.
func NewLockManager(f func(...interface{}) sync.Locker) *LockManager {
	return &LockManager{
		newLocker: f,
		locks:     make(map[string]*lockWithWaiters),
	}
}

// NewSyncMutex is a lock factory that produces a plain sync.Mutex; the
// variadic parameters are ignored.
func NewSyncMutex(lockerParams ...interface{}) sync.Locker {
	return &sync.Mutex{}
}

// Acquire takes the lock for key, allocating it first if none exists.
// lockerParams are forwarded to the lock factory only on allocation.
// The waiters counter is bumped under lm.mux, but the per-key lock is
// taken after lm.mux is dropped so one busy key cannot stall others.
func (lm *LockManager) Acquire(key string, lockerParams ...interface{}) {
	lm.mux.Lock()
	entry, ok := lm.locks[key]
	if !ok {
		entry = &lockWithWaiters{
			mux: lm.newLocker(lockerParams...),
		}
		lm.locks[key] = entry
	}
	entry.waiters++
	lm.mux.Unlock()
	entry.mux.Lock()
}
+func (lm *LockManager) Release(key string) { + lm.mux.Lock() + lockForKey, ok := lm.locks[key] + if !ok { + // This should not happen, but if it does the only thing to do is to log the error + glog.Errorf("the key being released does not correspond to an existing lock") + return + } + lockForKey.waiters -= 1 + lockForKey.mux.Unlock() + if lockForKey.waiters == 0 { + delete(lm.locks, key) + } + lm.mux.Unlock() +} diff --git a/pkg/gce-pd-csi-driver/lock_manager_test.go b/pkg/gce-pd-csi-driver/lock_manager_test.go new file mode 100644 index 000000000..aba8333e1 --- /dev/null +++ b/pkg/gce-pd-csi-driver/lock_manager_test.go @@ -0,0 +1,145 @@ +package gceGCEDriver + +import ( + "sync" + "testing" + "time" +) + +// Checks that the lock manager has the expected number of locks allocated. +// this function is implementation dependant! It acquires the lock and directly +// checks the map of the lock manager. +func checkAllocation(lm *LockManager, expectedNumAllocated int, t *testing.T) { + lm.mux.Lock() + defer lm.mux.Unlock() + if len(lm.locks) != expectedNumAllocated { + t.Fatalf("expected %d locks allocated, but found %d", expectedNumAllocated, len(lm.locks)) + } +} + +// coinOperatedMutex is a mutex that only acquires if a "coin" is provided. Otherwise +// it sleeps until there is both a coin and the lock is free. This is used +// so a parent thread can control the execution of children's lock. 
+type coinOperatedMutex struct { + mux *sync.Mutex + cond *sync.Cond + held bool + coin chan coin + t *testing.T +} + +type coin struct{} + +func (m *coinOperatedMutex) Lock() { + m.mux.Lock() + defer m.mux.Unlock() + + for m.held || len(m.coin) == 0 { + m.cond.Wait() + } + <-m.coin + m.held = true +} + +func (m *coinOperatedMutex) Unlock() { + m.mux.Lock() + defer m.mux.Unlock() + + m.held = false + m.cond.Broadcast() +} + +func (m *coinOperatedMutex) Deposit() { + m.mux.Lock() + defer m.mux.Unlock() + + m.coin <- coin{} + m.cond.Broadcast() +} + +func passCoinOperatedMutex(lockerParams ...interface{}) sync.Locker { + return lockerParams[0].(*coinOperatedMutex) +} + +func TestLockManagerSingle(t *testing.T) { + lm := NewLockManager(NewSyncMutex) + lm.Acquire("A") + checkAllocation(lm, 1, t) + lm.Acquire("B") + checkAllocation(lm, 2, t) + lm.Release("A") + checkAllocation(lm, 1, t) + lm.Release("B") + checkAllocation(lm, 0, t) +} + +func TestLockManagerMultiple(t *testing.T) { + lm := NewLockManager(passCoinOperatedMutex) + m := &sync.Mutex{} + com := &coinOperatedMutex{ + mux: m, + cond: sync.NewCond(m), + coin: make(chan coin, 1), + held: false, + t: t, + } + + // start thread 1 + t1OperationFinished := make(chan coin, 1) + t1OkToRelease := make(chan coin, 1) + go func() { + lm.Acquire("A", com) + t1OperationFinished <- coin{} + <-t1OkToRelease + lm.Release("A") + t1OperationFinished <- coin{} + }() + + // this allows the acquire by thread 1 to acquire + com.Deposit() + <-t1OperationFinished + + // thread 1 should have acquired the lock, putting allocation at 1 + checkAllocation(lm, 1, t) + + // start thread 2 + // this should allow thread 2 to start the acquire for A through the + // lock manager, but block on the acquire Lock() of the lock for A. 
+ t2OperationFinished := make(chan coin, 1) + t2OkToRelease := make(chan coin, 1) + go func() { + lm.Acquire("A") + t2OperationFinished <- coin{} + <-t2OkToRelease + lm.Release("A") + t2OperationFinished <- coin{} + }() + + // because now thread 2 is the only thread that can run, we must wait + // until it runs until it is blocked on acquire. for simplicity just wait + // 5 seconds. + time.Sleep(time.Second * 3) + + // this allows the release by thread 1 to complete + // only the release can run because the acquire by thread 1 can only run if + // there is both a coin and the lock is free + t1OkToRelease <- coin{} + <-t1OperationFinished + + // check that the lock has not been deallocated, since thread 2 is still waiting to acquire it + checkAllocation(lm, 1, t) + + // this allows t2 to finish its acquire + com.Deposit() + <-t2OperationFinished + + // check that the lock has been deallocated, since thread 2 still holds it + checkAllocation(lm, 1, t) + + // this allows the release by thread 2 to release + t2OkToRelease <- coin{} + <-t2OperationFinished + + // check that the lock has been deallocated + checkAllocation(lm, 0, t) +} diff --git a/pkg/gce-pd-csi-driver/node.go b/pkg/gce-pd-csi-driver/node.go index 8354b196b..afcd474ef 100644 --- a/pkg/gce-pd-csi-driver/node.go +++ b/pkg/gce-pd-csi-driver/node.go @@ -18,7 +18,6 @@ import ( "fmt" "os" "strings" - "sync" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/glog" @@ -37,8 +36,7 @@ type GCENodeServer struct { Mounter *mount.SafeFormatAndMount DeviceUtils mountmanager.DeviceUtils MetadataService metadataservice.MetadataService - // TODO: Only lock mutually exclusive calls and make locking more fine grained - mux sync.Mutex + LockManager *LockManager } var _ csi.NodeServer = &GCENodeServer{} @@ -52,8 +50,6 @@ const ( ) func (ns *GCENodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { - 
ns.mux.Lock() - defer ns.mux.Unlock() glog.V(4).Infof("NodePublishVolume called with req: %#v", req) // Validate Arguments @@ -74,6 +70,8 @@ func (ns *GCENodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePub if volumeCapability == nil { return nil, status.Error(codes.InvalidArgument, "NodePublishVolume Volume Capability must be provided") } + ns.LockManager.Acquire(string(volumeID)) + defer ns.LockManager.Release(string(volumeID)) if err := validateVolumeCapability(volumeCapability); err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("VolumeCapability is invalid: %v", err)) @@ -181,8 +179,6 @@ func (ns *GCENodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePub } func (ns *GCENodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { - ns.mux.Lock() - defer ns.mux.Unlock() glog.V(4).Infof("NodeUnpublishVolume called with args: %v", req) // Validate Arguments targetPath := req.GetTargetPath() @@ -194,6 +190,9 @@ func (ns *GCENodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeU return nil, status.Error(codes.InvalidArgument, "NodeUnpublishVolume Target Path must be provided") } + ns.LockManager.Acquire(string(volID)) + defer ns.LockManager.Release(string(volID)) + err := volumeutils.UnmountMountPoint(targetPath, ns.Mounter.Interface, false /* bind mount */) if err != nil { return nil, status.Error(codes.Internal, fmt.Sprintf("Unmount failed: %v\nUnmounting arguments: %s\n", err, targetPath)) @@ -203,8 +202,6 @@ func (ns *GCENodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeU } func (ns *GCENodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { - ns.mux.Lock() - defer ns.mux.Unlock() glog.V(4).Infof("NodeStageVolume called with req: %#v", req) // Validate Arguments @@ -221,6 +218,9 @@ func (ns *GCENodeServer) NodeStageVolume(ctx context.Context, req 
*csi.NodeStage return nil, status.Error(codes.InvalidArgument, "NodeStageVolume Volume Capability must be provided") } + ns.LockManager.Acquire(string(volumeID)) + defer ns.LockManager.Release(string(volumeID)) + if err := validateVolumeCapability(volumeCapability); err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("VolumeCapability is invalid: %v", err)) } @@ -298,8 +298,6 @@ func (ns *GCENodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStage } func (ns *GCENodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { - ns.mux.Lock() - defer ns.mux.Unlock() glog.V(4).Infof("NodeUnstageVolume called with req: %#v", req) // Validate arguments volumeID := req.GetVolumeId() @@ -310,6 +308,8 @@ func (ns *GCENodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUns if len(stagingTargetPath) == 0 { return nil, status.Error(codes.InvalidArgument, "NodeUnstageVolume Staging Target Path must be provided") } + ns.LockManager.Acquire(string(volumeID)) + defer ns.LockManager.Release(string(volumeID)) err := volumeutils.UnmountMountPoint(stagingTargetPath, ns.Mounter.Interface, false /* bind mount */) if err != nil {