@@ -26,6 +26,7 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
26
26
import org.springframework.web.testfixture.server.MockServerWebExchange
27
27
import reactor.core.publisher.Mono
28
28
import reactor.test.StepVerifier
29
+ import kotlin.coroutines.AbstractCoroutineContextElement
29
30
import kotlin.coroutines.CoroutineContext
30
31
31
32
/* *
@@ -44,10 +45,24 @@ class CoWebFilterTests {
44
45
val filter = MyCoWebFilter ()
45
46
val result = filter.filter(exchange, chain)
46
47
47
- StepVerifier .create(result)
48
- .verifyComplete()
48
+ StepVerifier .create(result).verifyComplete()
49
+
50
+ assertThat(exchange.attributes[" foo" ]).isEqualTo(" bar" )
51
+ }
52
+
53
+ @Test
54
+ fun multipleFilters () {
55
+ val exchange = MockServerWebExchange .from(MockServerHttpRequest .get(" https://example.com" ))
56
+
57
+ val chain = Mockito .mock(WebFilterChain ::class .java)
58
+ given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilter ().filter(exchange,chain) }.willReturn(Mono .empty())
59
+
60
+ val result = MyCoWebFilter ().filter(exchange, chain)
61
+
62
+ StepVerifier .create(result).verifyComplete()
49
63
50
64
assertThat(exchange.attributes[" foo" ]).isEqualTo(" bar" )
65
+ assertThat(exchange.attributes[" foofoo" ]).isEqualTo(" barbar" )
51
66
}
52
67
53
68
@Test
@@ -69,6 +84,28 @@ class CoWebFilterTests {
69
84
assertThat(coroutineName.name).isEqualTo(" foo" )
70
85
}
71
86
87
+ @Test
88
+ fun multipleFiltersWithContext () {
89
+ val exchange = MockServerWebExchange .from(MockServerHttpRequest .get(" https://example.com" ))
90
+
91
+ val chain = Mockito .mock(WebFilterChain ::class .java)
92
+ given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilterWithContext ().filter(exchange,chain) }.willReturn(Mono .empty())
93
+
94
+ val filter = MyCoWebFilterWithContext ()
95
+ val result = filter.filter(exchange, chain)
96
+
97
+ StepVerifier .create(result).verifyComplete()
98
+
99
+ val context = exchange.attributes[CoWebFilter .COROUTINE_CONTEXT_ATTRIBUTE ] as CoroutineContext
100
+ assertThat(context).isNotNull()
101
+ val coroutineName = context[CoroutineName .Key ] as CoroutineName
102
+ assertThat(coroutineName).isNotNull()
103
+ assertThat(coroutineName.name).isEqualTo(" foo" )
104
+ val coroutineDescription = context[CoroutineDescription .Key ] as CoroutineDescription
105
+ assertThat(coroutineDescription).isNotNull()
106
+ assertThat(coroutineDescription.description).isEqualTo(" foofoo" )
107
+ }
108
+
72
109
}
73
110
74
111
@@ -79,10 +116,33 @@ private class MyCoWebFilter : CoWebFilter() {
79
116
}
80
117
}
81
118
119
+ private class MyOtherCoWebFilter : CoWebFilter () {
120
+ override suspend fun filter (exchange : ServerWebExchange , chain : CoWebFilterChain ) {
121
+ exchange.attributes[" foofoo" ] = " barbar"
122
+ chain.filter(exchange)
123
+ }
124
+ }
125
+
82
126
private class MyCoWebFilterWithContext : CoWebFilter () {
83
127
override suspend fun filter (exchange : ServerWebExchange , chain : CoWebFilterChain ) {
84
128
withContext(CoroutineName (" foo" )) {
85
129
chain.filter(exchange)
86
130
}
87
131
}
88
132
}
133
+
134
+ private class MyOtherCoWebFilterWithContext : CoWebFilter () {
135
+ override suspend fun filter (exchange : ServerWebExchange , chain : CoWebFilterChain ) {
136
+ withContext(CoroutineDescription (" foofoo" )) {
137
+ chain.filter(exchange)
138
+ }
139
+ }
140
+ }
141
+
142
+ data class CoroutineDescription (val description : String ) : AbstractCoroutineContextElement(CoroutineDescription ) {
143
+
144
+ companion object Key : CoroutineContext.Key<CoroutineDescription>
145
+
146
+ override fun toString (): String = " CoroutineDescription($description )"
147
+ }
148
+
0 commit comments