1
+ /*
2
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3
+ */
4
+
5
+ package kotlinx.coroutines.reactive
6
+
7
+ import kotlinx.coroutines.*
8
+ import kotlinx.coroutines.flow.*
9
+ import kotlinx.coroutines.flow.Flow
10
+ import org.junit.*
11
+ import org.reactivestreams.*
12
+ import java.util.concurrent.*
13
+ import java.util.concurrent.atomic.*
14
+ import kotlin.coroutines.*
15
+
16
+ @Suppress(" ReactiveStreamsSubscriberImplementation" )
17
+ class PublisherRequestStressTest : TestBase () {
18
+ private val testDurationSec = 3 * stressTestMultiplier
19
+
20
+ private val minDemand = 8L
21
+ private val maxDemand = 10L
22
+ private val nEmitThreads = 4
23
+
24
+ private val emitThreadNo = AtomicInteger ()
25
+
26
+ private val emitPool = Executors .newFixedThreadPool(nEmitThreads) { r ->
27
+ Thread (r, " PublisherRequestStressTest-emit-${emitThreadNo.incrementAndGet()} " )
28
+ }
29
+
30
+ private val reqPool = Executors .newSingleThreadExecutor { r ->
31
+ Thread (r, " PublisherRequestStressTest-req" )
32
+ }
33
+
34
+ private val nextValue = AtomicLong (0 )
35
+
36
+ @After
37
+ fun tearDown () {
38
+ emitPool.shutdown()
39
+ reqPool.shutdown()
40
+ emitPool.awaitTermination(10 , TimeUnit .SECONDS )
41
+ reqPool.awaitTermination(10 , TimeUnit .SECONDS )
42
+ }
43
+
44
+ private lateinit var subscription: Subscription
45
+
46
+ @Test
47
+ fun testRequestStress () {
48
+ val expectedValue = AtomicLong (0 )
49
+ val requestedTill = AtomicLong (0 )
50
+ val completionLatch = CountDownLatch (1 )
51
+ val callingOnNext = AtomicInteger ()
52
+
53
+ val publisher = mtFlow().asPublisher()
54
+ var error = false
55
+
56
+ publisher.subscribe(object : Subscriber <Long > {
57
+ private var demand = 0L // only updated from reqPool
58
+
59
+ override fun onComplete () {
60
+ completionLatch.countDown()
61
+ }
62
+
63
+ override fun onSubscribe (sub : Subscription ) {
64
+ subscription = sub
65
+ maybeRequestMore()
66
+ }
67
+
68
+ private fun maybeRequestMore () {
69
+ if (demand >= minDemand) return
70
+ val more = maxDemand - demand
71
+ demand = maxDemand
72
+ requestedTill.addAndGet(more)
73
+ subscription.request(more)
74
+ }
75
+
76
+ override fun onNext (value : Long ) {
77
+ check(callingOnNext.getAndIncrement() == 0 ) // make sure it is not concurrent
78
+ // check for expected value
79
+ check(value == expectedValue.get())
80
+ // check that it does not exceed requested values
81
+ check(value < requestedTill.get())
82
+ val nextExpected = value + 1
83
+ expectedValue.set(nextExpected)
84
+ // send more requests from request thread
85
+ reqPool.execute {
86
+ demand-- // processed an item
87
+ maybeRequestMore()
88
+ }
89
+ callingOnNext.decrementAndGet()
90
+ }
91
+
92
+ override fun onError (ex : Throwable ? ) {
93
+ error = true
94
+ error(" Failed" , ex)
95
+ }
96
+ })
97
+ for (second in 1 .. testDurationSec) {
98
+ if (error) break
99
+ Thread .sleep(1000 )
100
+ println (" $second : nextValue = ${nextValue.get()} , expectedValue = ${expectedValue.get()} " )
101
+ }
102
+ if (! error) {
103
+ subscription.cancel()
104
+ completionLatch.await()
105
+ }
106
+ }
107
+
108
+ private fun mtFlow (): Flow <Long > = flow {
109
+ while (currentCoroutineContext().isActive) {
110
+ emit(aWait())
111
+ }
112
+ }
113
+
114
+ private suspend fun aWait (): Long = suspendCancellableCoroutine { cont ->
115
+ emitPool.execute(Runnable {
116
+ cont.resume(nextValue.getAndIncrement())
117
+ })
118
+ }
119
+ }
0 commit comments