@@ -7,12 +7,37 @@ package kotlinx.coroutines.flow
7
7
import kotlinx.coroutines.*
8
8
import kotlinx.coroutines.channels.*
9
9
import kotlin.coroutines.*
10
+ import kotlin.reflect.*
10
11
import kotlin.test.*
11
12
12
13
class FlowInvariantsTest : TestBase () {
13
14
15
+ private fun <T > runParametrizedTest (
16
+ expectedException : KClass <out Throwable >? = null,
17
+ testBody : suspend (flowFactory: (suspend FlowCollector <T >.() -> Unit ) -> Flow <T >) -> Unit
18
+ ) = runTest {
19
+ val r1 = runCatching { testBody { flow(it) } }.exceptionOrNull()
20
+ check(r1, expectedException)
21
+ reset()
22
+
23
+ val r2 = runCatching { testBody { abstractFlow(it) } }.exceptionOrNull()
24
+ check(r2, expectedException)
25
+ }
26
+
27
+ private fun <T > abstractFlow (block : suspend FlowCollector <T >.() -> Unit ): Flow <T > = object : AbstractFlow <T >() {
28
+ override suspend fun collectSafely (collector : FlowCollector <T >) {
29
+ collector.block()
30
+ }
31
+ }
32
+
33
+ private fun check (exception : Throwable ? , expectedException : KClass <out Throwable >? ) {
34
+ if (expectedException != null && exception == null ) fail(" Expected $expectedException , but test completed successfully" )
35
+ if (expectedException != null && exception != null ) assertTrue(expectedException.isInstance(exception))
36
+ if (expectedException == null && exception != null ) throw exception
37
+ }
38
+
14
39
@Test
15
- fun testWithContextContract () = runTest({ it is IllegalStateException } ) {
40
+ fun testWithContextContract () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
16
41
flow {
17
42
kotlinx.coroutines.withContext(NonCancellable ) {
18
43
emit(1 )
@@ -23,7 +48,7 @@ class FlowInvariantsTest : TestBase() {
23
48
}
24
49
25
50
@Test
26
- fun testWithDispatcherContractViolated () = runTest({ it is IllegalStateException } ) {
51
+ fun testWithDispatcherContractViolated () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
27
52
flow {
28
53
kotlinx.coroutines.withContext(NamedDispatchers (" foo" )) {
29
54
emit(1 )
@@ -34,7 +59,7 @@ class FlowInvariantsTest : TestBase() {
34
59
}
35
60
36
61
@Test
37
- fun testCachedInvariantCheckResult () = runTest {
62
+ fun testCachedInvariantCheckResult () = runParametrizedTest< Int > { flow ->
38
63
flow {
39
64
emit(1 )
40
65
@@ -55,7 +80,7 @@ class FlowInvariantsTest : TestBase() {
55
80
}
56
81
57
82
@Test
58
- fun testWithNameContractViolated () = runTest({ it is IllegalStateException } ) {
83
+ fun testWithNameContractViolated () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
59
84
flow {
60
85
kotlinx.coroutines.withContext(CoroutineName (" foo" )) {
61
86
emit(1 )
@@ -86,25 +111,25 @@ class FlowInvariantsTest : TestBase() {
86
111
}
87
112
88
113
@Test
89
- fun testScopedJob () = runTest({ it is IllegalStateException } ) {
90
- flow { emit(1 ) }.buffer(EmptyCoroutineContext ).collect {
114
+ fun testScopedJob () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
115
+ flow { emit(1 ) }.buffer(EmptyCoroutineContext , flow ).collect {
91
116
expect(1 )
92
117
}
93
118
94
119
finish(2 )
95
120
}
96
121
97
122
@Test
98
- fun testScopedJobWithViolation () = runTest({ it is IllegalStateException } ) {
99
- flow { emit(1 ) }.buffer(Dispatchers .Unconfined ).collect {
123
+ fun testScopedJobWithViolation () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
124
+ flow { emit(1 ) }.buffer(Dispatchers .Unconfined , flow ).collect {
100
125
expect(1 )
101
126
}
102
127
103
128
finish(2 )
104
129
}
105
130
106
131
@Test
107
- fun testMergeViolation () = runTest {
132
+ fun testMergeViolation () = runParametrizedTest< Int > { flow ->
108
133
fun Flow<Int>.merge (other : Flow <Int >): Flow <Int > = flow {
109
134
coroutineScope {
110
135
launch {
@@ -130,17 +155,6 @@ class FlowInvariantsTest : TestBase() {
130
155
assertFailsWith<IllegalStateException > { flow.trickyMerge(flow).toList() }
131
156
}
132
157
133
- // TODO merge artifact
134
- private fun <T > channelFlow (bufferSize : Int = 16, @BuilderInference block : suspend ProducerScope <T >.() -> Unit ): Flow <T > =
135
- flow {
136
- coroutineScope {
137
- val channel = produce(capacity = bufferSize, block = block)
138
- channel.consumeEach { value ->
139
- emit(value)
140
- }
141
- }
142
- }
143
-
144
158
@Test
145
159
fun testNoMergeViolation () = runTest {
146
160
fun Flow<Int>.merge (other : Flow <Int >): Flow <Int > = channelFlow {
@@ -167,7 +181,7 @@ class FlowInvariantsTest : TestBase() {
167
181
}
168
182
169
183
@Test
170
- fun testScopedCoroutineNoViolation () = runTest {
184
+ fun testScopedCoroutineNoViolation () = runParametrizedTest< Int > { flow ->
171
185
fun Flow<Int>.buffer (): Flow <Int > = flow {
172
186
coroutineScope {
173
187
val channel = produce {
@@ -180,11 +194,10 @@ class FlowInvariantsTest : TestBase() {
180
194
}
181
195
}
182
196
}
183
-
184
197
assertEquals(listOf (1 , 1 ), flowOf(1 , 1 ).buffer().toList())
185
198
}
186
199
187
- private fun Flow<Int>.buffer (coroutineContext : CoroutineContext ): Flow <Int > = flow {
200
+ private fun Flow<Int>.buffer (coroutineContext : CoroutineContext , flow : (suspend FlowCollector < Int >.() -> Unit ) -> Flow < Int > ): Flow <Int > = flow {
188
201
coroutineScope {
189
202
val channel = Channel <Int >()
190
203
launch {
0 commit comments