Skip to content

Commit ec1f186

Browse files
thediveoonsi
authored andcommitted
feat: receiver matcher accepting (POINTER, MATCHER), includes unit tests
Signed-off-by: thediveo <[email protected]>
1 parent 9999deb commit ec1f186

File tree

3 files changed

+110
-28
lines changed

3 files changed

+110
-28
lines changed

matchers.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,20 +194,21 @@ func BeClosed() types.GomegaMatcher {
194194
//
195195
// will repeatedly attempt to pull values out of `c` until a value matching "bar" is received.
196196
//
197-
// Finally, if you want to have a reference to the value *sent* to the channel you can pass the `Receive` matcher a pointer to a variable of the appropriate type:
197+
// Furthermore, if you want to have a reference to the value *sent* to the channel you can pass the `Receive` matcher a pointer to a variable of the appropriate type:
198198
//
199199
// var myThing thing
200200
// Eventually(thingChan).Should(Receive(&myThing))
201201
// Expect(myThing.Sprocket).Should(Equal("foo"))
202202
// Expect(myThing.IsValid()).Should(BeTrue())
203+
//
204+
// Finally, if you want to match the received object as well as get the actual received value into a variable, so you can reason further about the value received,
205+
// you can pass a pointer to a variable of the approriate type first, and second a matcher:
206+
//
207+
// var myThing thing
208+
// Eventually(thingChan).Should(Receive(&myThing, ContainSubstring("bar")))
203209
func Receive(args ...interface{}) types.GomegaMatcher {
204-
var arg interface{}
205-
if len(args) > 0 {
206-
arg = args[0]
207-
}
208-
209210
return &matchers.ReceiveMatcher{
210-
Arg: arg,
211+
Args: args,
211212
}
212213
}
213214

matchers/receive_matcher.go

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
package matchers
44

55
import (
6+
"errors"
67
"fmt"
78
"reflect"
89

910
"github.com/onsi/gomega/format"
1011
)
1112

1213
type ReceiveMatcher struct {
13-
Arg interface{}
14+
Args []interface{}
1415
receivedValue reflect.Value
1516
channelClosed bool
1617
}
@@ -29,15 +30,38 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro
2930

3031
var subMatcher omegaMatcher
3132
var hasSubMatcher bool
32-
33-
if matcher.Arg != nil {
34-
subMatcher, hasSubMatcher = (matcher.Arg).(omegaMatcher)
33+
var resultReference interface{}
34+
35+
// Valid arg formats are as follows, always with optional POINTER before
36+
// optional MATCHER:
37+
// - Receive()
38+
// - Receive(POINTER)
39+
// - Receive(MATCHER)
40+
// - Receive(POINTER, MATCHER)
41+
args := matcher.Args
42+
if len(args) > 0 {
43+
arg := args[0]
44+
_, isSubMatcher := arg.(omegaMatcher)
45+
if !isSubMatcher && reflect.ValueOf(arg).Kind() == reflect.Ptr {
46+
// Consume optional POINTER arg first, if it ain't no matcher ;)
47+
resultReference = arg
48+
args = args[1:]
49+
}
50+
}
51+
if len(args) > 0 {
52+
arg := args[0]
53+
subMatcher, hasSubMatcher = arg.(omegaMatcher)
3554
if !hasSubMatcher {
36-
argType := reflect.TypeOf(matcher.Arg)
37-
if argType.Kind() != reflect.Ptr {
38-
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nTo:\n%s\nYou need to pass a pointer!", format.Object(actual, 1), format.Object(matcher.Arg, 1))
39-
}
55+
// At this point we assume the dev user wanted to assign a received
56+
// value, so [POINTER,]MATCHER.
57+
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nTo:\n%s\nYou need to pass a pointer!", format.Object(actual, 1), format.Object(arg, 1))
4058
}
59+
// Consume optional MATCHER arg.
60+
args = args[1:]
61+
}
62+
if len(args) > 0 {
63+
// If there are still args present, reject all.
64+
return false, errors.New("Receive matcher expects at most an optional pointer and/or an optional matcher")
4165
}
4266

4367
winnerIndex, value, open := reflect.Select([]reflect.SelectCase{
@@ -58,16 +82,20 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro
5882
}
5983

6084
if hasSubMatcher {
61-
if didReceive {
62-
matcher.receivedValue = value
63-
return subMatcher.Match(matcher.receivedValue.Interface())
85+
if !didReceive {
86+
return false, nil
6487
}
65-
return false, nil
88+
matcher.receivedValue = value
89+
if match, err := subMatcher.Match(matcher.receivedValue.Interface()); err != nil || !match {
90+
return match, err
91+
}
92+
// if we received a match, then fall through in order to handle an
93+
// optional assignment of the received value to the specified reference.
6694
}
6795

6896
if didReceive {
69-
if matcher.Arg != nil {
70-
outValue := reflect.ValueOf(matcher.Arg)
97+
if resultReference != nil {
98+
outValue := reflect.ValueOf(resultReference)
7199

72100
if value.Type().AssignableTo(outValue.Elem().Type()) {
73101
outValue.Elem().Set(value)
@@ -77,7 +105,7 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro
77105
outValue.Elem().Set(value.Elem())
78106
return true, nil
79107
} else {
80-
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nType:\n%s\nTo:\n%s", format.Object(actual, 1), format.Object(value.Interface(), 1), format.Object(matcher.Arg, 1))
108+
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nType:\n%s\nTo:\n%s", format.Object(actual, 1), format.Object(value.Interface(), 1), format.Object(resultReference, 1))
81109
}
82110

83111
}
@@ -88,7 +116,11 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro
88116
}
89117

90118
func (matcher *ReceiveMatcher) FailureMessage(actual interface{}) (message string) {
91-
subMatcher, hasSubMatcher := (matcher.Arg).(omegaMatcher)
119+
var matcherArg interface{}
120+
if len(matcher.Args) > 0 {
121+
matcherArg = matcher.Args[len(matcher.Args)-1]
122+
}
123+
subMatcher, hasSubMatcher := (matcherArg).(omegaMatcher)
92124

93125
closedAddendum := ""
94126
if matcher.channelClosed {
@@ -105,7 +137,11 @@ func (matcher *ReceiveMatcher) FailureMessage(actual interface{}) (message strin
105137
}
106138

107139
func (matcher *ReceiveMatcher) NegatedFailureMessage(actual interface{}) (message string) {
108-
subMatcher, hasSubMatcher := (matcher.Arg).(omegaMatcher)
140+
var matcherArg interface{}
141+
if len(matcher.Args) > 0 {
142+
matcherArg = matcher.Args[len(matcher.Args)-1]
143+
}
144+
subMatcher, hasSubMatcher := (matcherArg).(omegaMatcher)
109145

110146
closedAddendum := ""
111147
if matcher.channelClosed {

matchers/receive_matcher_test.go

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,39 @@ var _ = Describe("ReceiveMatcher", func() {
5454
})
5555
})
5656

57+
Context("with too many arguments", func() {
58+
It("should error", func() {
59+
channel := make(chan bool, 1)
60+
var actual bool
61+
62+
channel <- true
63+
64+
success, err := (&ReceiveMatcher{Args: []interface{}{
65+
&actual,
66+
Equal(true),
67+
42,
68+
}}).Match(channel)
69+
Expect(success).To(BeFalse())
70+
Expect(err).To(HaveOccurred())
71+
})
72+
})
73+
74+
Context("with swapped arguments", func() {
75+
It("should error", func() {
76+
channel := make(chan bool, 1)
77+
var actual bool
78+
79+
channel <- true
80+
81+
success, err := (&ReceiveMatcher{Args: []interface{}{
82+
Equal(true),
83+
&actual,
84+
}}).Match(channel)
85+
Expect(success).To(BeFalse())
86+
Expect(err).To(HaveOccurred())
87+
})
88+
})
89+
5790
Context("with a pointer argument", func() {
5891
Context("of the correct type", func() {
5992
When("the channel has an interface type", func() {
@@ -134,12 +167,12 @@ var _ = Describe("ReceiveMatcher", func() {
134167

135168
var incorrectType bool
136169

137-
success, err := (&ReceiveMatcher{Arg: &incorrectType}).Match(channel)
170+
success, err := (&ReceiveMatcher{Args: []interface{}{&incorrectType}}).Match(channel)
138171
Expect(success).Should(BeFalse())
139172
Expect(err).Should(HaveOccurred())
140173

141174
var notAPointer int
142-
success, err = (&ReceiveMatcher{Arg: notAPointer}).Match(channel)
175+
success, err = (&ReceiveMatcher{Args: []interface{}{notAPointer}}).Match(channel)
143176
Expect(success).Should(BeFalse())
144177
Expect(err).Should(HaveOccurred())
145178
})
@@ -192,7 +225,7 @@ var _ = Describe("ReceiveMatcher", func() {
192225
It("should error", func() {
193226
channel := make(chan int, 1)
194227
channel <- 3
195-
success, err := (&ReceiveMatcher{Arg: ContainSubstring("three")}).Match(channel)
228+
success, err := (&ReceiveMatcher{Args: []interface{}{ContainSubstring("three")}}).Match(channel)
196229
Expect(success).Should(BeFalse())
197230
Expect(err).Should(HaveOccurred())
198231
})
@@ -201,13 +234,25 @@ var _ = Describe("ReceiveMatcher", func() {
201234
Context("if nothing is received", func() {
202235
It("should fail", func() {
203236
channel := make(chan int, 1)
204-
success, err := (&ReceiveMatcher{Arg: Equal(1)}).Match(channel)
237+
success, err := (&ReceiveMatcher{Args: []interface{}{Equal(1)}}).Match(channel)
205238
Expect(success).Should(BeFalse())
206239
Expect(err).ShouldNot(HaveOccurred())
207240
})
208241
})
209242
})
210243

244+
Context("with a pointer and a matcher argument", func() {
245+
It("should succeed", func() {
246+
channel := make(chan bool, 1)
247+
channel <- true
248+
249+
var received bool
250+
251+
Expect(channel).Should(Receive(&received, Equal(true)))
252+
Expect(received).Should(BeTrue())
253+
})
254+
})
255+
211256
Context("When actual is a *closed* channel", func() {
212257
Context("for a buffered channel", func() {
213258
It("should work until it hits the end of the buffer", func() {

0 commit comments

Comments
 (0)