Skip to content

Commit 4ed68e1

Browse files
czeslavodolmen
authored andcommitted
fix: make EventuallyWithT concurrency safe
1 parent 11a6452 commit 4ed68e1

File tree

2 files changed

+67
-22
lines changed

2 files changed

+67
-22
lines changed

assert/assertions.go

+18-19
Original file line numberDiff line numberDiff line change
@@ -1873,23 +1873,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
18731873
}
18741874

18751875
// FailNow panics.
1876-
func (c *CollectT) FailNow() {
1876+
func (*CollectT) FailNow() {
18771877
panic("Assertion failed")
18781878
}
18791879

1880-
// Reset clears the collected errors.
1881-
func (c *CollectT) Reset() {
1882-
c.errors = nil
1880+
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
1881+
func (*CollectT) Reset() {
1882+
panic("Reset() is deprecated")
18831883
}
18841884

1885-
// Copy copies the collected errors to the supplied t.
1886-
func (c *CollectT) Copy(t TestingT) {
1887-
if tt, ok := t.(tHelper); ok {
1888-
tt.Helper()
1889-
}
1890-
for _, err := range c.errors {
1891-
t.Errorf("%v", err)
1892-
}
1885+
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
1886+
func (*CollectT) Copy(TestingT) {
1887+
panic("Copy() is deprecated")
18931888
}
18941889

18951890
// EventuallyWithT asserts that given condition will be met in waitFor time,
@@ -1915,8 +1910,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
19151910
h.Helper()
19161911
}
19171912

1918-
collect := new(CollectT)
1919-
ch := make(chan bool, 1)
1913+
var lastFinishedTickErrs []error
1914+
ch := make(chan []error, 1)
19201915

19211916
timer := time.NewTimer(waitFor)
19221917
defer timer.Stop()
@@ -1927,19 +1922,23 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
19271922
for tick := ticker.C; ; {
19281923
select {
19291924
case <-timer.C:
1930-
collect.Copy(t)
1925+
for _, err := range lastFinishedTickErrs {
1926+
t.Errorf("%v", err)
1927+
}
19311928
return Fail(t, "Condition never satisfied", msgAndArgs...)
19321929
case <-tick:
19331930
tick = nil
1934-
collect.Reset()
19351931
go func() {
1932+
collect := new(CollectT)
19361933
condition(collect)
1937-
ch <- len(collect.errors) == 0
1934+
ch <- collect.errors
19381935
}()
1939-
case v := <-ch:
1940-
if v {
1936+
case errs := <-ch:
1937+
if len(errs) == 0 {
19411938
return true
19421939
}
1940+
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
1941+
lastFinishedTickErrs = errs
19431942
tick = ticker.C
19441943
}
19451944
}

assert/assertions_test.go

+49-3
Original file line numberDiff line numberDiff line change
@@ -2766,19 +2766,30 @@ func TestEventuallyTrue(t *testing.T) {
27662766
True(t, Eventually(t, condition, 100*time.Millisecond, 20*time.Millisecond))
27672767
}
27682768

2769+
// errorsCapturingT is a mock implementation of TestingT that captures errors reported with Errorf.
2770+
type errorsCapturingT struct {
2771+
errors []error
2772+
}
2773+
2774+
func (t *errorsCapturingT) Errorf(format string, args ...interface{}) {
2775+
t.errors = append(t.errors, fmt.Errorf(format, args...))
2776+
}
2777+
2778+
func (t *errorsCapturingT) Helper() {}
2779+
27692780
func TestEventuallyWithTFalse(t *testing.T) {
2770-
mockT := new(CollectT)
2781+
mockT := new(errorsCapturingT)
27712782

27722783
condition := func(collect *CollectT) {
2773-
True(collect, false)
2784+
Fail(collect, "condition fixed failure")
27742785
}
27752786

27762787
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
27772788
Len(t, mockT.errors, 2)
27782789
}
27792790

27802791
func TestEventuallyWithTTrue(t *testing.T) {
2781-
mockT := new(CollectT)
2792+
mockT := new(errorsCapturingT)
27822793

27832794
state := 0
27842795
condition := func(collect *CollectT) {
@@ -2792,6 +2803,41 @@ func TestEventuallyWithTTrue(t *testing.T) {
27922803
Len(t, mockT.errors, 0)
27932804
}
27942805

2806+
func TestEventuallyWithT_ConcurrencySafe(t *testing.T) {
2807+
mockT := new(errorsCapturingT)
2808+
2809+
condition := func(collect *CollectT) {
2810+
Fail(collect, "condition fixed failure")
2811+
}
2812+
2813+
// To trigger race conditions, we run EventuallyWithT with a nanosecond tick.
2814+
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, time.Nanosecond))
2815+
Len(t, mockT.errors, 2)
2816+
}
2817+
2818+
func TestEventuallyWithT_ReturnsTheLatestFinishedConditionErrors(t *testing.T) {
2819+
// We'll use a channel to control whether a condition should sleep or not.
2820+
mustSleep := make(chan bool, 2)
2821+
mustSleep <- false
2822+
mustSleep <- true
2823+
close(mustSleep)
2824+
2825+
condition := func(collect *CollectT) {
2826+
if <-mustSleep {
2827+
// Sleep to ensure that the second condition runs longer than timeout.
2828+
time.Sleep(time.Second)
2829+
return
2830+
}
2831+
2832+
// The first condition will fail. We expect to get this error as a result.
2833+
Fail(collect, "condition fixed failure")
2834+
}
2835+
2836+
mockT := new(errorsCapturingT)
2837+
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
2838+
Len(t, mockT.errors, 2)
2839+
}
2840+
27952841
func TestNeverFalse(t *testing.T) {
27962842
condition := func() bool {
27972843
return false

0 commit comments

Comments
 (0)