@@ -35,11 +35,25 @@ type Weighted struct {
35
35
// Acquire acquires the semaphore with a weight of n, blocking until resources
36
36
// are available or ctx is done. On success, returns nil. On failure, returns
37
37
// ctx.Err() and leaves the semaphore unchanged.
38
- //
39
- // If ctx is already done, Acquire may still succeed without blocking.
40
38
func (s * Weighted ) Acquire (ctx context.Context , n int64 ) error {
39
+ done := ctx .Done ()
40
+
41
41
s .mu .Lock ()
42
+ select {
43
+ case <- done :
44
+ // ctx becoming done has "happened before" acquiring the semaphore,
45
+ // whether it became done before the call began or while we were
46
+ // waiting for the mutex. We prefer to fail even if we could acquire
47
+ // the mutex without blocking.
48
+ s .mu .Unlock ()
49
+ return ctx .Err ()
50
+ default :
51
+ }
42
52
if s .size - s .cur >= n && s .waiters .Len () == 0 {
53
+ // Since we hold s.mu and haven't synchronized since checking done, if
54
+ // ctx becomes done before we return here, it becoming done must have
55
+ // "happened concurrently" with this call - it cannot "happen before"
56
+ // we return in this branch. So, we're ok to always acquire here.
43
57
s .cur += n
44
58
s .mu .Unlock ()
45
59
return nil
@@ -48,7 +62,7 @@ func (s *Weighted) Acquire(ctx context.Context, n int64) error {
48
62
if n > s .size {
49
63
// Don't make other Acquire calls block on one that's doomed to fail.
50
64
s .mu .Unlock ()
51
- <- ctx . Done ()
65
+ <- done
52
66
return ctx .Err ()
53
67
}
54
68
@@ -58,14 +72,14 @@ func (s *Weighted) Acquire(ctx context.Context, n int64) error {
58
72
s .mu .Unlock ()
59
73
60
74
select {
61
- case <- ctx .Done ():
62
- err := ctx .Err ()
75
+ case <- done :
63
76
s .mu .Lock ()
64
77
select {
65
78
case <- ready :
66
- // Acquired the semaphore after we were canceled. Rather than trying to
67
- // fix up the queue, just pretend we didn't notice the cancelation.
68
- err = nil
79
+ // Acquired the semaphore after we were canceled.
80
+ // Pretend we didn't and put the tokens back.
81
+ s .cur -= n
82
+ s .notifyWaiters ()
69
83
default :
70
84
isFront := s .waiters .Front () == elem
71
85
s .waiters .Remove (elem )
@@ -75,9 +89,19 @@ func (s *Weighted) Acquire(ctx context.Context, n int64) error {
75
89
}
76
90
}
77
91
s .mu .Unlock ()
78
- return err
92
+ return ctx . Err ()
79
93
80
94
case <- ready :
95
+ // Acquired the semaphore. Check that ctx isn't already done.
96
+ // We check the done channel instead of calling ctx.Err because we
97
+ // already have the channel, and ctx.Err is O(n) with the nesting
98
+ // depth of ctx.
99
+ select {
100
+ case <- done :
101
+ s .Release (n )
102
+ return ctx .Err ()
103
+ default :
104
+ }
81
105
return nil
82
106
}
83
107
}
0 commit comments