Skip to content

Commit 357dcbb

Browse files
committed
* Add tests for SegmentQueue and fix remove
1 parent 9485488 commit 357dcbb

File tree

2 files changed

+287
-23
lines changed

2 files changed

+287
-23
lines changed

kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt

+34-23
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegment: Boolean
2424
val firstSegment = newSegment(0)
2525
if (_head.compareAndSet(null, firstSegment))
2626
startFrom = firstSegment
27-
else
27+
else {
2828
startFrom = head!!
29+
}
2930
}
3031
if (startFrom.id > id) return null
3132
// This method goes through `next` references and
@@ -49,6 +50,7 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegment: Boolean
4950
}
5051
cur = curNext
5152
}
53+
if (cur.id != id) return null
5254
return cur
5355
}
5456

@@ -65,7 +67,7 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegment: Boolean
6567
*/
6668
private fun moveHeadForward(new: S) {
6769
while (true) {
68-
val cur = _head.value!!
70+
val cur = head!!
6971
if (cur.id > new.id) return
7072
if (_head.compareAndSet(cur, new)) {
7173
new.prev.value = null
@@ -80,7 +82,7 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegment: Boolean
8082
*/
8183
private fun moveTailForward(new: S) {
8284
while (true) {
83-
val cur = _tail.value
85+
val cur = tail
8486
if (cur !== null && cur.id > new.id) return
8587
if (_tail.compareAndSet(cur, new)) return
8688
}
@@ -105,27 +107,36 @@ internal abstract class Segment<S: Segment<S>>(val id: Long, prev: S?) {
105107
* Removes this node from the waiting queue and cleans all references to it.
106108
*/
107109
fun remove() {
108-
var next = this.next.value ?: return // tail can't be removed
110+
check(removed) { " The segment should be logically removed at first "}
111+
val next = this.next.value ?: return // tail can't be removed
109112
// Find the first non-removed node (tail is always non-removed)
110-
while (next.removed) {
111-
next = this.next.value ?: return
112-
}
113-
// Find the first non-removed `prev` and remove this node
114-
var prev = prev.value
115-
while (true) {
116-
if (prev == null) {
117-
next.prev.value = null
118-
return
119-
}
120-
if (prev.removed) {
121-
prev = prev.prev.value
122-
continue
123-
}
124-
next.movePrevToLeft(prev)
125-
prev.movePrevNextToRight(next)
126-
if (next.removed || !prev.removed) return
127-
prev = prev.prev.value
128-
}
113+
val prev = prev.value ?: return // head cannot be removed
114+
next.movePrevToLeft(prev)
115+
prev.movePrevNextToRight(next)
116+
if (prev.removed)
117+
prev.remove()
118+
if (next.removed)
119+
next.remove()
120+
121+
// while (next.removed) {
122+
// next = next.next.value ?: return
123+
// }
124+
// // Find the first non-removed `prev` and remove this node
125+
// var prev = prev.value
126+
// while (true) {
127+
// if (prev === null) {
128+
// next.prev.value = null
129+
// return
130+
// }
131+
// if (prev.removed) {
132+
// prev = prev.prev.value
133+
// continue
134+
// }
135+
// next.movePrevToLeft(prev)
136+
// prev.movePrevNextToRight(next)
137+
// if (next.removed || !prev.removed) return
138+
// prev = prev.prev.value
139+
// }
129140
}
130141

131142
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
package kotlinx.coroutines.internal
2+
3+
import com.devexperts.dxlab.lincheck.LinChecker
4+
import com.devexperts.dxlab.lincheck.annotations.Operation
5+
import com.devexperts.dxlab.lincheck.annotations.Param
6+
import com.devexperts.dxlab.lincheck.paramgen.IntGen
7+
import com.devexperts.dxlab.lincheck.strategy.stress.StressCTest
8+
import kotlinx.atomicfu.atomic
9+
import org.junit.Test
10+
import org.junit.runner.RunWith
11+
import org.junit.runners.Parameterized
12+
import java.util.*
13+
import java.util.concurrent.CyclicBarrier
14+
import java.util.concurrent.atomic.AtomicInteger
15+
import kotlin.concurrent.thread
16+
import kotlin.random.Random
17+
import kotlin.test.assertEquals
18+
import kotlin.test.assertTrue
19+
20+
private class OneElementSegment<T>(id: Long, prev: OneElementSegment<T>?) : Segment<OneElementSegment<T>>(id, prev) {
21+
val element = atomic<Any?>(null)
22+
23+
override val removed get() = element.value === REMOVED
24+
25+
fun removeSegmentLogically() {
26+
element.value = REMOVED
27+
}
28+
29+
fun removeSegmentPhysically() {
30+
remove()
31+
}
32+
}
33+
34+
private class SegmentBasedQueue<T>(createFirstSegment: Boolean) : SegmentQueue<OneElementSegment<T>>(createFirstSegment) {
35+
override fun newSegment(id: Long, prev: OneElementSegment<T>?): OneElementSegment<T> = OneElementSegment(id, prev)
36+
37+
private val enqIdx = atomic(0L)
38+
private val deqIdx = atomic(0L)
39+
40+
fun add(element: T): OneElementSegment<T> {
41+
while (true) {
42+
var tail = this.tail
43+
val enqIdx = this.enqIdx.getAndIncrement()
44+
tail = getSegment(tail, enqIdx) ?: continue
45+
if (tail.element.value === BROKEN) continue
46+
if (tail.element.compareAndSet(null, element)) return tail
47+
}
48+
}
49+
50+
fun poll(): T? {
51+
while (true) {
52+
if (this.deqIdx.value >= this.enqIdx.value) return null
53+
var head = this.head
54+
val deqIdx = this.deqIdx.getAndIncrement()
55+
head = getSegmentAndMoveHeadForward(head, deqIdx) ?: continue
56+
var el = head.element.value
57+
if (el === null) {
58+
if (head.element.compareAndSet(null, BROKEN)) continue
59+
else el = head.element.value
60+
}
61+
if (el === REMOVED) continue
62+
return el as T
63+
}
64+
}
65+
66+
val numberOfSegments: Int get() {
67+
var s: OneElementSegment<T>? = head
68+
var i = 0
69+
while (s != null) {
70+
s = s.next.value
71+
i++
72+
}
73+
return i
74+
}
75+
}
76+
77+
private val BROKEN = Symbol("BROKEN")
78+
private val REMOVED = Symbol("REMOVED")
79+
80+
@RunWith(Parameterized::class)
81+
class SegmentQueueTest(private val createFirstSegment: Boolean) {
82+
companion object {
83+
@JvmStatic
84+
@Parameterized.Parameters(name = "createFirstSegment={0}")
85+
fun testArguments() = listOf(true, false)
86+
}
87+
88+
@Test
89+
fun simpleTest() {
90+
val q = SegmentBasedQueue<Int>(createFirstSegment)
91+
assertEquals(if (createFirstSegment) 1 else 0, q.numberOfSegments)
92+
assertEquals(null, q.poll())
93+
q.add(1)
94+
assertEquals(1, q.numberOfSegments)
95+
q.add(2)
96+
assertEquals(2, q.numberOfSegments)
97+
assertEquals(1, q.poll())
98+
assertEquals(2, q.numberOfSegments)
99+
assertEquals(2, q.poll())
100+
assertEquals(1, q.numberOfSegments)
101+
assertEquals(null, q.poll())
102+
103+
}
104+
105+
@Test
106+
fun testSegmentRemoving() {
107+
val q = SegmentBasedQueue<Int>(createFirstSegment)
108+
q.add(1)
109+
val s = q.add(2)
110+
q.add(3)
111+
assertEquals(3, q.numberOfSegments)
112+
s.removeSegmentLogically()
113+
s.removeSegmentPhysically()
114+
assertEquals(2, q.numberOfSegments)
115+
assertEquals(1, q.poll())
116+
assertEquals(3, q.poll())
117+
assertEquals(null, q.poll())
118+
}
119+
120+
@Test
121+
fun testRemoveHeadSegment() {
122+
val q = SegmentBasedQueue<Int>(createFirstSegment)
123+
q.add(1)
124+
val s = q.add(2)
125+
assertEquals(1, q.poll())
126+
q.add(3)
127+
s.removeSegmentLogically()
128+
s.removeSegmentPhysically()
129+
assertEquals(3, q.poll())
130+
assertEquals(null, q.poll())
131+
}
132+
133+
@Test
134+
fun testRemoveHeadLogically() {
135+
val q = SegmentBasedQueue<Int>(createFirstSegment)
136+
val s = q.add(1)
137+
s.removeSegmentLogically()
138+
assertEquals(null, q.poll())
139+
}
140+
141+
@Test
142+
fun stressTest() {
143+
val q = SegmentBasedQueue<Int>(createFirstSegment)
144+
val expectedQueue = ArrayDeque<Int>()
145+
val r = Random(0)
146+
repeat(1_000_000) {
147+
if (r.nextBoolean()) { // add
148+
val el = r.nextInt()
149+
q.add(el)
150+
expectedQueue.add(el)
151+
} else { // remove
152+
assertEquals(expectedQueue.poll(), q.poll())
153+
}
154+
}
155+
}
156+
157+
@Test
158+
fun stressTestRemoveSegmentsSerial() = stressTestRemoveSegments(false)
159+
160+
@Test
161+
fun stressTestRemoveSegmentsRandom() = stressTestRemoveSegments(true)
162+
163+
private fun stressTestRemoveSegments(random: Boolean) {
164+
val N = 100_000
165+
val T = 1
166+
val q = SegmentBasedQueue<Int>(createFirstSegment)
167+
val segments = (1..N).map { q.add(it) }.toMutableList()
168+
if (random) segments.shuffle()
169+
assertEquals(N, q.numberOfSegments)
170+
val nextSegmentIndex = AtomicInteger()
171+
val barrier = CyclicBarrier(T)
172+
(1..T).map {
173+
thread {
174+
while (true) {
175+
barrier.await()
176+
val i = nextSegmentIndex.getAndIncrement()
177+
if (i >= N) break
178+
segments[i].removeSegmentLogically()
179+
assertTrue(segments[i].removed)
180+
segments[i].removeSegmentPhysically()
181+
}
182+
}
183+
}.forEach { it.join() }
184+
assertEquals(2, q.numberOfSegments)
185+
}
186+
}
187+
188+
@StressCTest
189+
class SegmentQueueLFTest {
190+
private companion object {
191+
var createFirstSegment: Boolean = false
192+
}
193+
194+
private val q = SegmentBasedQueue<Int>(createFirstSegment)
195+
196+
@Volatile
197+
private var removedSegment1: OneElementSegment<Int>? = null
198+
@Volatile
199+
private var removedSegment2: OneElementSegment<Int>? = null
200+
@Volatile
201+
private var lastAddedSegment: OneElementSegment<Int>? = null
202+
203+
@Operation
204+
fun addAndSaveSegment(@Param(gen = IntGen::class) x: Int) {
205+
lastAddedSegment = q.add(x)
206+
}
207+
208+
@Operation
209+
fun add(@Param(gen = IntGen::class) x: Int) {
210+
q.add(x)
211+
}
212+
213+
@Operation
214+
fun removeSegmentLogically1() {
215+
val s = lastAddedSegment ?: return
216+
s.removeSegmentLogically()
217+
removedSegment1 = s
218+
}
219+
220+
@Operation
221+
fun removeSegmentPhysically1() {
222+
val s = removedSegment1 ?: return
223+
s.remove()
224+
}
225+
226+
@Operation
227+
fun removeSegmentLogically2() {
228+
val s = lastAddedSegment ?: return
229+
s.removeSegmentLogically()
230+
removedSegment2 = s
231+
}
232+
233+
@Operation
234+
fun removeSegmentPhysically2() {
235+
val s = removedSegment2 ?: return
236+
s.remove()
237+
}
238+
239+
@Operation
240+
fun remove(): Int? = q.poll()
241+
242+
@Test
243+
fun test() {
244+
createFirstSegment = true
245+
LinChecker.check(SegmentQueueLFTest::class.java)
246+
}
247+
248+
@Test
249+
fun testWithLazyFirstSegment() {
250+
createFirstSegment = false
251+
LinChecker.check(SegmentQueueLFTest::class.java)
252+
}
253+
}

0 commit comments

Comments
 (0)