Skip to content

Commit d75a7c3

Browse files
committed
Support multiple CoWebFilter changing the context
This commit ensures CoWebFilter merges the exchange CoroutineContext with the filter one if needed. Closes gh-31792
1 parent e2c2268 commit d75a7c3

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import kotlinx.coroutines.currentCoroutineContext
2222
import kotlinx.coroutines.reactor.awaitSingleOrNull
2323
import kotlinx.coroutines.reactor.mono
2424
import reactor.core.publisher.Mono
25+
import kotlin.coroutines.CoroutineContext
2526

2627
/**
2728
* Kotlin-specific implementation of the [WebFilter] interface that allows for
@@ -34,7 +35,8 @@ import reactor.core.publisher.Mono
3435
abstract class CoWebFilter : WebFilter {
3536

3637
final override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
37-
return mono(Dispatchers.Unconfined) {
38+
val context = exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext?
39+
return mono(context ?: Dispatchers.Unconfined) {
3840
filter(exchange, object : CoWebFilterChain {
3941
override suspend fun filter(exchange: ServerWebExchange) {
4042
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)

spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
2626
import org.springframework.web.testfixture.server.MockServerWebExchange
2727
import reactor.core.publisher.Mono
2828
import reactor.test.StepVerifier
29+
import kotlin.coroutines.AbstractCoroutineContextElement
2930
import kotlin.coroutines.CoroutineContext
3031

3132
/**
@@ -44,10 +45,24 @@ class CoWebFilterTests {
4445
val filter = MyCoWebFilter()
4546
val result = filter.filter(exchange, chain)
4647

47-
StepVerifier.create(result)
48-
.verifyComplete()
48+
StepVerifier.create(result).verifyComplete()
49+
50+
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
51+
}
52+
53+
@Test
54+
fun multipleFilters() {
55+
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
56+
57+
val chain = Mockito.mock(WebFilterChain::class.java)
58+
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilter().filter(exchange,chain) }.willReturn(Mono.empty())
59+
60+
val result = MyCoWebFilter().filter(exchange, chain)
61+
62+
StepVerifier.create(result).verifyComplete()
4963

5064
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
65+
assertThat(exchange.attributes["foofoo"]).isEqualTo("barbar")
5166
}
5267

5368
@Test
@@ -69,6 +84,28 @@ class CoWebFilterTests {
6984
assertThat(coroutineName.name).isEqualTo("foo")
7085
}
7186

87+
@Test
88+
fun multipleFiltersWithContext() {
89+
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
90+
91+
val chain = Mockito.mock(WebFilterChain::class.java)
92+
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilterWithContext().filter(exchange,chain) }.willReturn(Mono.empty())
93+
94+
val filter = MyCoWebFilterWithContext()
95+
val result = filter.filter(exchange, chain)
96+
97+
StepVerifier.create(result).verifyComplete()
98+
99+
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
100+
assertThat(context).isNotNull()
101+
val coroutineName = context[CoroutineName.Key] as CoroutineName
102+
assertThat(coroutineName).isNotNull()
103+
assertThat(coroutineName.name).isEqualTo("foo")
104+
val coroutineDescription = context[CoroutineDescription.Key] as CoroutineDescription
105+
assertThat(coroutineDescription).isNotNull()
106+
assertThat(coroutineDescription.description).isEqualTo("foofoo")
107+
}
108+
72109
}
73110

74111

@@ -79,10 +116,33 @@ private class MyCoWebFilter : CoWebFilter() {
79116
}
80117
}
81118

119+
private class MyOtherCoWebFilter : CoWebFilter() {
120+
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
121+
exchange.attributes["foofoo"] = "barbar"
122+
chain.filter(exchange)
123+
}
124+
}
125+
82126
private class MyCoWebFilterWithContext : CoWebFilter() {
83127
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
84128
withContext(CoroutineName("foo")) {
85129
chain.filter(exchange)
86130
}
87131
}
88132
}
133+
134+
private class MyOtherCoWebFilterWithContext : CoWebFilter() {
135+
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
136+
withContext(CoroutineDescription("foofoo")) {
137+
chain.filter(exchange)
138+
}
139+
}
140+
}
141+
142+
data class CoroutineDescription(val description: String) : AbstractCoroutineContextElement(CoroutineDescription) {
143+
144+
companion object Key : CoroutineContext.Key<CoroutineDescription>
145+
146+
override fun toString(): String = "CoroutineDescription($description)"
147+
}
148+

0 commit comments

Comments
 (0)