Skip to content

Commit d87ad1f

Browse files
authored
feat: fail test on trapped but unreleased calls (#15)
part of #13 Will fail test with an error message if you trap a call and don't release it. I've also added 2 test cases to the unit test. One is a skipped test that you can unskip to see what it looks like if you don't release the trapped call. One tests what happens if two different traps catch the same call. It gets a little awkward because when you `Release()` the call, it waits for the call to complete. One of the main purposes is to ensure that a timer or ticker is set, and if we don't wait for the call to complete, you can't ensure the timer is set before you advance the clock. The awkward consequence is that the `Release()` calls will deadlock if they happen on the same goroutine because _all_ traps have to release the call before it will return. We separate out the trapping of the call and releasing of the call so that you have a chance to manipulate the clock before the call returns. But, actually, there are really 3 phases to a trapped call: 1. Call is trapped 2. All traps released, we get the time and do the work (e.g. actually setting the timer) 3. Call completes After `trap.Wait()` returns, we know phase 1 is complete. But, `Release()` actually conflates phase 2 and 3, so there is no way to release the trap without waiting for phase 3. Generally we don't care that much about the distinction, it's really only in the case of multple traps that you'd need to release without waiting to avoid the deadlock. We could make those phases explicit: `trap.Wait().Release().WaitForComplete()`, but that seems pretty involved for what I think is generally an edge case. WDYT?
1 parent d4dbd83 commit d87ad1f

File tree

2 files changed

+165
-15
lines changed

2 files changed

+165
-15
lines changed

mock.go

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func (m *Mock) removeEventLocked(e event) {
190190
}
191191
}
192192

193-
func (m *Mock) matchCallLocked(c *Call) {
193+
func (m *Mock) matchCallLocked(c *apiCall) {
194194
var traps []*Trap
195195
for _, t := range m.traps {
196196
if t.matches(c) {
@@ -435,7 +435,7 @@ func (m *Mock) newTrap(fn clockFunction, tags []string) *Trap {
435435
fn: fn,
436436
tags: tags,
437437
mock: m,
438-
calls: make(chan *Call),
438+
calls: make(chan *apiCall),
439439
done: make(chan struct{}),
440440
}
441441
m.traps = append(m.traps, tr)
@@ -557,9 +557,10 @@ const (
557557
clockFunctionUntil
558558
)
559559

560-
type callArg func(c *Call)
560+
type callArg func(c *apiCall)
561561

562-
type Call struct {
562+
// apiCall represents a single call to one of the Clock APIs.
563+
type apiCall struct {
563564
Time time.Time
564565
Duration time.Duration
565566
Tags []string
@@ -569,25 +570,36 @@ type Call struct {
569570
complete chan struct{}
570571
}
571572

573+
// Call represents an apiCall that has been trapped.
574+
type Call struct {
575+
Time time.Time
576+
Duration time.Duration
577+
Tags []string
578+
579+
apiCall *apiCall
580+
trap *Trap
581+
}
582+
572583
func (c *Call) Release() {
573-
c.releases.Done()
574-
<-c.complete
584+
c.apiCall.releases.Done()
585+
<-c.apiCall.complete
586+
c.trap.callReleased()
575587
}
576588

577589
func withTime(t time.Time) callArg {
578-
return func(c *Call) {
590+
return func(c *apiCall) {
579591
c.Time = t
580592
}
581593
}
582594

583595
func withDuration(d time.Duration) callArg {
584-
return func(c *Call) {
596+
return func(c *apiCall) {
585597
c.Duration = d
586598
}
587599
}
588600

589-
func newCall(fn clockFunction, tags []string, args ...callArg) *Call {
590-
c := &Call{
601+
func newCall(fn clockFunction, tags []string, args ...callArg) *apiCall {
602+
c := &apiCall{
591603
fn: fn,
592604
Tags: tags,
593605
complete: make(chan struct{}),
@@ -602,19 +614,23 @@ type Trap struct {
602614
fn clockFunction
603615
tags []string
604616
mock *Mock
605-
calls chan *Call
617+
calls chan *apiCall
606618
done chan struct{}
619+
620+
// mu protects the unreleasedCalls count
621+
mu sync.Mutex
622+
unreleasedCalls int
607623
}
608624

609-
func (t *Trap) catch(c *Call) {
625+
func (t *Trap) catch(c *apiCall) {
610626
select {
611627
case t.calls <- c:
612628
case <-t.done:
613-
c.Release()
629+
c.releases.Done()
614630
}
615631
}
616632

617-
func (t *Trap) matches(c *Call) bool {
633+
func (t *Trap) matches(c *apiCall) bool {
618634
if t.fn != c.fn {
619635
return false
620636
}
@@ -629,6 +645,10 @@ func (t *Trap) matches(c *Call) bool {
629645
func (t *Trap) Close() {
630646
t.mock.mu.Lock()
631647
defer t.mock.mu.Unlock()
648+
if t.unreleasedCalls != 0 {
649+
t.mock.tb.Helper()
650+
t.mock.tb.Errorf("trap Closed() with %d unreleased calls", t.unreleasedCalls)
651+
}
632652
for i, tr := range t.mock.traps {
633653
if t == tr {
634654
t.mock.traps = append(t.mock.traps[:i], t.mock.traps[i+1:]...)
@@ -637,6 +657,12 @@ func (t *Trap) Close() {
637657
close(t.done)
638658
}
639659

660+
func (t *Trap) callReleased() {
661+
t.mu.Lock()
662+
defer t.mu.Unlock()
663+
t.unreleasedCalls--
664+
}
665+
640666
var ErrTrapClosed = errors.New("trap closed")
641667

642668
func (t *Trap) Wait(ctx context.Context) (*Call, error) {
@@ -645,7 +671,17 @@ func (t *Trap) Wait(ctx context.Context) (*Call, error) {
645671
return nil, ctx.Err()
646672
case <-t.done:
647673
return nil, ErrTrapClosed
648-
case c := <-t.calls:
674+
case a := <-t.calls:
675+
c := &Call{
676+
Time: a.Time,
677+
Duration: a.Duration,
678+
Tags: a.Tags,
679+
apiCall: a,
680+
trap: t,
681+
}
682+
t.mu.Lock()
683+
defer t.mu.Unlock()
684+
t.unreleasedCalls++
649685
return c, nil
650686
}
651687
}

mock_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,117 @@ func TestTickerFunc_LongCallback(t *testing.T) {
319319
}
320320
w.MustWait(testCtx)
321321
}
322+
323+
func Test_MultipleTraps(t *testing.T) {
324+
t.Parallel()
325+
testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
326+
defer testCancel()
327+
mClock := quartz.NewMock(t)
328+
329+
trap0 := mClock.Trap().Now("0")
330+
defer trap0.Close()
331+
trap1 := mClock.Trap().Now("1")
332+
defer trap1.Close()
333+
334+
timeCh := make(chan time.Time)
335+
go func() {
336+
timeCh <- mClock.Now("0", "1")
337+
}()
338+
339+
c0 := trap0.MustWait(testCtx)
340+
mClock.Advance(time.Second)
341+
// the two trapped call instances need to be released on separate goroutines since they each wait for the Now() call
342+
// to return, which is blocked on both releases happening. If you release them on the same goroutine, in either
343+
// order, it will deadlock.
344+
done := make(chan struct{})
345+
go func() {
346+
defer close(done)
347+
c0.Release()
348+
}()
349+
c1 := trap1.MustWait(testCtx)
350+
mClock.Advance(time.Second)
351+
c1.Release()
352+
353+
select {
354+
case <-done:
355+
case <-testCtx.Done():
356+
t.Fatal("timed out waiting for c0.Release()")
357+
}
358+
359+
select {
360+
case got := <-timeCh:
361+
end := mClock.Now("end")
362+
if !got.Equal(end) {
363+
t.Fatalf("expected %s got %s", end, got)
364+
}
365+
case <-testCtx.Done():
366+
t.Fatal("timed out waiting for Now()")
367+
}
368+
}
369+
370+
func Test_UnreleasedCalls(t *testing.T) {
371+
t.Parallel()
372+
tRunFail(t, func(t testing.TB) {
373+
testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
374+
defer testCancel()
375+
mClock := quartz.NewMock(t)
376+
377+
trap := mClock.Trap().Now()
378+
defer trap.Close()
379+
380+
go func() {
381+
_ = mClock.Now()
382+
}()
383+
384+
trap.MustWait(testCtx) // missing release
385+
})
386+
}
387+
388+
type captureFailTB struct {
389+
failed bool
390+
testing.TB
391+
}
392+
393+
func (t *captureFailTB) Errorf(format string, args ...any) {
394+
t.Helper()
395+
t.Logf(format, args...)
396+
t.failed = true
397+
}
398+
399+
func (t *captureFailTB) Error(args ...any) {
400+
t.Helper()
401+
t.Log(args...)
402+
t.failed = true
403+
}
404+
405+
func (t *captureFailTB) Fatal(args ...any) {
406+
t.Helper()
407+
t.Log(args...)
408+
t.failed = true
409+
}
410+
411+
func (t *captureFailTB) Fatalf(format string, args ...any) {
412+
t.Helper()
413+
t.Logf(format, args...)
414+
t.failed = true
415+
}
416+
417+
func (t *captureFailTB) Fail() {
418+
t.failed = true
419+
}
420+
421+
func (t *captureFailTB) FailNow() {
422+
t.failed = true
423+
}
424+
425+
func (t *captureFailTB) Failed() bool {
426+
return t.failed
427+
}
428+
429+
func tRunFail(t testing.TB, f func(t testing.TB)) {
430+
tb := &captureFailTB{TB: t}
431+
f(tb)
432+
if !tb.Failed() {
433+
t.Fatal("want test to fail")
434+
}
435+
}

0 commit comments

Comments
 (0)