Skip to content

feat(tracer): close & restore segments when other middlewares return #1545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion packages/tracer/src/middleware/middy.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { TRACER_KEY } from '@aws-lambda-powertools/commons/lib/middleware';
import type { Tracer } from '../Tracer';
import type { Segment, Subsegment } from 'aws-xray-sdk-core';
import type { CaptureLambdaHandlerOptions } from '../types';
Expand Down Expand Up @@ -40,6 +41,18 @@ const captureLambdaHandler = (
let lambdaSegment: Segment;
let handlerSegment: Subsegment;

/**
* Set the cleanup function to be called in case other middlewares return early.
*
* @param request - The request object
*/
const setCleanupFunction = (request: MiddyLikeRequest): void => {
request.internal = {
...request.internal,
[TRACER_KEY]: close,
};
};

const open = (): void => {
const segment = target.getSegment();
if (segment === undefined) {
Expand All @@ -61,9 +74,12 @@ const captureLambdaHandler = (
target.setSegment(lambdaSegment);
};

const captureLambdaHandlerBefore = async (): Promise<void> => {
const captureLambdaHandlerBefore = async (
request: MiddyLikeRequest
): Promise<void> => {
if (target.isTracingEnabled()) {
open();
setCleanupFunction(request);
target.annotateColdStart();
target.addServiceNameAnnotation();
}
Expand Down
64 changes: 63 additions & 1 deletion packages/tracer/tests/unit/middy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
*
* @group unit/tracer/all
*/

import { captureLambdaHandler } from '../../src/middleware/middy';
import middy from '@middy/core';
import { Tracer } from './../../src';
Expand All @@ -13,6 +12,7 @@ import {
setContextMissingStrategy,
Subsegment,
} from 'aws-xray-sdk-core';
import { cleanupMiddlewares } from '@aws-lambda-powertools/commons/lib/middleware';

jest.spyOn(console, 'debug').mockImplementation(() => null);
jest.spyOn(console, 'warn').mockImplementation(() => null);
Expand Down Expand Up @@ -306,4 +306,66 @@ describe('Middy middleware', () => {
'hello-world'
);
});

test('when enabled, and another middleware returns early, it still closes and restores the segments correctly', async () => {
// Prepare
const tracer = new Tracer();
const setSegmentSpy = jest
.spyOn(tracer.provider, 'setSegment')
.mockImplementation(() => ({}));
jest.spyOn(tracer, 'annotateColdStart').mockImplementation(() => ({}));
jest
.spyOn(tracer, 'addServiceNameAnnotation')
.mockImplementation(() => ({}));
const facadeSegment1 = new Segment('facade');
const handlerSubsegment1 = new Subsegment('## index.handlerA');
jest
.spyOn(facadeSegment1, 'addNewSubsegment')
.mockImplementation(() => handlerSubsegment1);
const facadeSegment2 = new Segment('facade');
const handlerSubsegment2 = new Subsegment('## index.handlerB');
jest
.spyOn(facadeSegment2, 'addNewSubsegment')
.mockImplementation(() => handlerSubsegment2);
jest
.spyOn(tracer.provider, 'getSegment')
.mockImplementationOnce(() => facadeSegment1)
.mockImplementationOnce(() => facadeSegment2);
const myCustomMiddleware = (): middy.MiddlewareObj => {
const before = async (
request: middy.Request
): Promise<undefined | string> => {
// Return early on the second invocation
if (request.event.idx === 1) {
// Cleanup Powertools resources
await cleanupMiddlewares(request);

// Then return early
return 'foo';
}
};

return {
before,
};
};
const handler = middy((): void => {
console.log('Hello world!');
})
.use(captureLambdaHandler(tracer, { captureResponse: false }))
.use(myCustomMiddleware());

// Act
await handler({}, context);
await handler({}, context);

// Assess
// Check that the subsegments are closed
expect(handlerSubsegment1.isClosed()).toBe(true);
expect(handlerSubsegment2.isClosed()).toBe(true);
// Check that the segments are restored
expect(setSegmentSpy).toHaveBeenCalledTimes(4);
expect(setSegmentSpy).toHaveBeenNthCalledWith(2, facadeSegment1);
expect(setSegmentSpy).toHaveBeenNthCalledWith(4, facadeSegment2);
});
});