1
+ /*
2
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3
+ */
4
+
5
+ package kotlinx.coroutines.reactive
6
+
7
+ import kotlinx.coroutines.*
8
+ import kotlinx.coroutines.flow.*
9
+ import kotlinx.coroutines.flow.Flow
10
+ import org.junit.*
11
+ import org.reactivestreams.*
12
+ import java.util.concurrent.*
13
+ import java.util.concurrent.atomic.*
14
+ import kotlin.coroutines.*
15
+ import kotlin.random.*
16
+
17
+ /* *
18
+ * This stress-test is self-contained reproducer for the race in [Flow.asPublisher] extension
19
+ * that was originally reported in the issue
20
+ * [#2109](https://github.com/Kotlin/kotlinx.coroutines/issues/2109).
21
+ * The original reproducer used a flow that loads a file using AsynchronousFileChannel
22
+ * (that issues completion callbacks from multiple threads)
23
+ * and uploads it to S3 via Amazon SDK, which internally uses netty for I/O
24
+ * (which uses a single thread for connection-related callbacks).
25
+ *
26
+ * This stress-test essentially mimics the logic in multiple interacting threads: several emitter threads that form
27
+ * the flow and a single requesting thread works on the subscriber's side to periodically request more
28
+ * values when the number of items requested drops below the threshold.
29
+ */
30
+ @Suppress(" ReactiveStreamsSubscriberImplementation" )
31
+ class PublisherRequestStressTest : TestBase () {
32
+ private val testDurationSec = 3 * stressTestMultiplier
33
+
34
+ // Original code in Amazon SDK uses 4 and 16 as low/high watermarks.
35
+ // There constants were chosen so that problem reproduces asap with particular this code.
36
+ private val minDemand = 8L
37
+ private val maxDemand = 16L
38
+
39
+ private val nEmitThreads = 4
40
+
41
+ private val emitThreadNo = AtomicInteger ()
42
+
43
+ private val emitPool = Executors .newFixedThreadPool(nEmitThreads) { r ->
44
+ Thread (r, " PublisherRequestStressTest-emit-${emitThreadNo.incrementAndGet()} " )
45
+ }
46
+
47
+ private val reqPool = Executors .newSingleThreadExecutor { r ->
48
+ Thread (r, " PublisherRequestStressTest-req" )
49
+ }
50
+
51
+ private val nextValue = AtomicLong (0 )
52
+
53
+ @After
54
+ fun tearDown () {
55
+ emitPool.shutdown()
56
+ reqPool.shutdown()
57
+ emitPool.awaitTermination(10 , TimeUnit .SECONDS )
58
+ reqPool.awaitTermination(10 , TimeUnit .SECONDS )
59
+ }
60
+
61
+ private lateinit var subscription: Subscription
62
+
63
+ @Test
64
+ fun testRequestStress () {
65
+ val expectedValue = AtomicLong (0 )
66
+ val requestedTill = AtomicLong (0 )
67
+ val completionLatch = CountDownLatch (1 )
68
+ val callingOnNext = AtomicInteger ()
69
+
70
+ val publisher = mtFlow().asPublisher()
71
+ var error = false
72
+
73
+ publisher.subscribe(object : Subscriber <Long > {
74
+ private var demand = 0L // only updated from reqPool
75
+
76
+ override fun onComplete () {
77
+ completionLatch.countDown()
78
+ }
79
+
80
+ override fun onSubscribe (sub : Subscription ) {
81
+ subscription = sub
82
+ maybeRequestMore()
83
+ }
84
+
85
+ private fun maybeRequestMore () {
86
+ if (demand >= minDemand) return
87
+ val nextDemand = Random .nextLong(minDemand + 1 .. maxDemand)
88
+ val more = nextDemand - demand
89
+ demand = nextDemand
90
+ requestedTill.addAndGet(more)
91
+ subscription.request(more)
92
+ }
93
+
94
+ override fun onNext (value : Long ) {
95
+ check(callingOnNext.getAndIncrement() == 0 ) // make sure it is not concurrent
96
+ // check for expected value
97
+ check(value == expectedValue.get())
98
+ // check that it does not exceed requested values
99
+ check(value < requestedTill.get())
100
+ val nextExpected = value + 1
101
+ expectedValue.set(nextExpected)
102
+ // send more requests from request thread
103
+ reqPool.execute {
104
+ demand-- // processed an item
105
+ maybeRequestMore()
106
+ }
107
+ callingOnNext.decrementAndGet()
108
+ }
109
+
110
+ override fun onError (ex : Throwable ? ) {
111
+ error = true
112
+ error(" Failed" , ex)
113
+ }
114
+ })
115
+ var prevExpected = - 1L
116
+ for (second in 1 .. testDurationSec) {
117
+ if (error) break
118
+ Thread .sleep(1000 )
119
+ val expected = expectedValue.get()
120
+ println (" $second : expectedValue = $expected " )
121
+ check(expected > prevExpected) // should have progress
122
+ prevExpected = expected
123
+ }
124
+ if (! error) {
125
+ subscription.cancel()
126
+ completionLatch.await()
127
+ }
128
+ }
129
+
130
+ private fun mtFlow (): Flow <Long > = flow {
131
+ while (currentCoroutineContext().isActive) {
132
+ emit(aWait())
133
+ }
134
+ }
135
+
136
+ private suspend fun aWait (): Long = suspendCancellableCoroutine { cont ->
137
+ emitPool.execute(Runnable {
138
+ cont.resume(nextValue.getAndIncrement())
139
+ })
140
+ }
141
+ }
0 commit comments