diff --git a/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java index ed24f19f910..a8340fa789e 100644 --- a/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java +++ b/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java @@ -46,13 +46,14 @@ * wraps the chain in before and after observations * * @author Josh Cummings + * @author Nikita Konev * @since 6.0 */ public final class ObservationFilterChainDecorator implements FilterChainProxy.FilterChainDecorator { private static final Log logger = LogFactory.getLog(FilterChainProxy.class); - private static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation"; + static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation"; static final String UNSECURED_OBSERVATION_NAME = "spring.security.http.unsecured.requests"; @@ -250,6 +251,16 @@ private void wrapFilter(ServletRequest request, ServletResponse response, Filter private AroundFilterObservation parent(HttpServletRequest request) { FilterChainObservationContext beforeContext = FilterChainObservationContext.before(); FilterChainObservationContext afterContext = FilterChainObservationContext.after(); + + AroundFilterObservation existingParentObservation = (AroundFilterObservation) request + .getAttribute(ATTRIBUTE); + if (existingParentObservation != null) { + beforeContext + .setParentObservation(existingParentObservation.before().getContext().getParentObservation()); + afterContext + .setParentObservation(existingParentObservation.after().getContext().getParentObservation()); + } + Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry); Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry); AroundFilterObservation parent = AroundFilterObservation.create(before, after); diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 778efc545fc..5de65d234e5 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -310,6 +310,65 @@ public void doFilterWhenMatchesThenObservationRegistryObserves() throws Exceptio assertFilterChainObservation(contexts.next(), "after", 1); } + // gh-12610 + @Test + void parentObservationIsTakenIntoAccountDuringDispatchError() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + + given(this.matcher.matches(any())).willReturn(true); + SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter)); + FilterChainProxy fcp = new FilterChainProxy(sec); + fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry)); + Filter initialFilter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)) + .wrap(fcp); + + ServletRequest initialRequest = new MockHttpServletRequest("GET", "/"); + initialFilter.doFilter(initialRequest, new MockHttpServletResponse(), this.chain); + + // simulate request attribute copying in case dispatching to ERROR + ObservationFilterChainDecorator.AroundFilterObservation parentObservation = (ObservationFilterChainDecorator.AroundFilterObservation) initialRequest + .getAttribute(ObservationFilterChainDecorator.ATTRIBUTE); + assertThat(parentObservation).isNotNull(); + + // simulate dispatching error-related request + Filter errorRelatedFilter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)) + .wrap(fcp); + ServletRequest errorRelatedRequest = new MockHttpServletRequest("GET", "/error"); + errorRelatedRequest.setAttribute(ObservationFilterChainDecorator.ATTRIBUTE, parentObservation); + errorRelatedFilter.doFilter(errorRelatedRequest, new MockHttpServletResponse(), this.chain); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(8)).onStart(captor.capture()); + verify(handler, times(8)).onStop(any()); + List contexts = captor.getAllValues(); + + Observation.Context initialRequestObservationContextBefore = contexts.get(1); + Observation.Context initialRequestObservationContextAfter = contexts.get(3); + assertFilterChainObservation(initialRequestObservationContextBefore, "before", 1); + assertFilterChainObservation(initialRequestObservationContextAfter, "after", 1); + + assertThat(initialRequestObservationContextBefore.getParentObservation()).isNotNull(); + assertThat(initialRequestObservationContextBefore.getParentObservation()) + .isSameAs(initialRequestObservationContextAfter.getParentObservation()); + + Observation.Context errorRelatedRequestObservationContextBefore = contexts.get(5); + Observation.Context errorRelatedRequestObservationContextAfter = contexts.get(7); + assertFilterChainObservation(errorRelatedRequestObservationContextBefore, "before", 1); + assertFilterChainObservation(errorRelatedRequestObservationContextAfter, "after", 1); + + assertThat(errorRelatedRequestObservationContextBefore.getParentObservation()).isNotNull(); + assertThat(errorRelatedRequestObservationContextBefore.getParentObservation()) + .isSameAs(initialRequestObservationContextBefore.getParentObservation()); + assertThat(errorRelatedRequestObservationContextAfter.getParentObservation()).isNotNull(); + assertThat(errorRelatedRequestObservationContextAfter.getParentObservation()) + .isSameAs(initialRequestObservationContextBefore.getParentObservation()); + } + @Test public void doFilterWhenMultipleFiltersThenObservationRegistryObserves() throws Exception { ObservationHandler handler = mock(ObservationHandler.class);