Skip to content

Commit 43ee3c1

Browse files
authored
chore(middleware-flexible-checksums): delay checksum validation until stream read (#6629)
1 parent 0670605 commit 43ee3c1

File tree

6 files changed

+99
-41
lines changed

6 files changed

+99
-41
lines changed

clients/client-s3/test/e2e/S3.e2e.spec.ts

+22-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import "@aws-sdk/signature-v4-crt";
22

3-
import { S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3";
3+
import { ChecksumAlgorithm, S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3";
44
import { afterAll, afterEach, beforeAll, describe, expect, test as it } from "vitest";
55

66
import { getIntegTestResources } from "../../../../tests/e2e/get-integ-test-resources";
@@ -24,9 +24,7 @@ describe("@aws-sdk/client-s3", () => {
2424

2525
Key = ``;
2626

27-
client = new S3({
28-
region,
29-
});
27+
client = new S3({ region });
3028
});
3129

3230
describe("PutObject", () => {
@@ -74,26 +72,43 @@ describe("@aws-sdk/client-s3", () => {
7472
await client.deleteObject({ Bucket, Key });
7573
});
7674

77-
it("should succeed with valid body payload", async () => {
75+
it("should succeed with valid body payload with checksums", async () => {
7876
// prepare the object.
7977
const body = createBuffer("1MB");
78+
let bodyChecksum = "";
79+
80+
const bodyChecksumReader = (next) => async (args) => {
81+
const checksumValue = args.request.headers["x-amz-checksum-crc32"];
82+
if (checksumValue) {
83+
bodyChecksum = checksumValue;
84+
}
85+
return next(args);
86+
};
87+
client.middlewareStack.addRelativeTo(bodyChecksumReader, {
88+
name: "bodyChecksumReader",
89+
relation: "before",
90+
toMiddleware: "deserializerMiddleware",
91+
});
8092

8193
try {
82-
await client.putObject({ Bucket, Key, Body: body });
94+
await client.putObject({ Bucket, Key, Body: body, ChecksumAlgorithm: ChecksumAlgorithm.CRC32 });
8395
} catch (e) {
8496
console.error("failed to put");
8597
throw e;
8698
}
8799

100+
expect(bodyChecksum).not.toEqual("");
101+
88102
try {
89103
// eslint-disable-next-line no-var
90-
var result = await client.getObject({ Bucket, Key });
104+
var result = await client.getObject({ Bucket, Key, ChecksumMode: "ENABLED" });
91105
} catch (e) {
92106
console.error("failed to get");
93107
throw e;
94108
}
95109

96110
expect(result.$metadata.httpStatusCode).toEqual(200);
111+
expect(result.ChecksumCRC32).toEqual(bodyChecksum);
97112
const { Readable } = require("stream");
98113
expect(result.Body).toBeInstanceOf(Readable);
99114
});

packages/middleware-flexible-checksums/package.json

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"@smithy/types": "^3.6.0",
4141
"@smithy/util-middleware": "^3.0.8",
4242
"@smithy/util-utf8": "^3.0.0",
43+
"@smithy/util-stream": "^3.2.1",
4344
"tslib": "^2.6.2"
4445
},
4546
"devDependencies": {
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";
22

33
import { getChecksum } from "./getChecksum";
4-
import { isStreaming } from "./isStreaming";
54
import { stringHasher } from "./stringHasher";
65

7-
vi.mock("./isStreaming");
86
vi.mock("./stringHasher");
97

108
describe(getChecksum.name, () => {
119
const mockOptions = {
12-
streamHasher: vi.fn(),
1310
checksumAlgorithmFn: vi.fn(),
1411
base64Encoder: vi.fn(),
1512
};
@@ -26,21 +23,10 @@ describe(getChecksum.name, () => {
2623
vi.clearAllMocks();
2724
});
2825

29-
it("gets checksum from streamHasher if body is streaming", async () => {
30-
vi.mocked(isStreaming).mockReturnValue(true);
31-
mockOptions.streamHasher.mockResolvedValue(mockRawOutput);
32-
const checksum = await getChecksum(mockBody, mockOptions);
33-
expect(checksum).toEqual(mockOutput);
34-
expect(stringHasher).not.toHaveBeenCalled();
35-
expect(mockOptions.streamHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
36-
});
37-
38-
it("gets checksum from stringHasher if body is not streaming", async () => {
39-
vi.mocked(isStreaming).mockReturnValue(false);
26+
it("gets checksum from stringHasher", async () => {
4027
vi.mocked(stringHasher).mockResolvedValue(mockRawOutput);
4128
const checksum = await getChecksum(mockBody, mockOptions);
4229
expect(checksum).toEqual(mockOutput);
43-
expect(mockOptions.streamHasher).not.toHaveBeenCalled();
4430
expect(stringHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
4531
});
4632
});
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
1-
import { ChecksumConstructor, Encoder, HashConstructor, StreamHasher } from "@smithy/types";
1+
import { ChecksumConstructor, Encoder, HashConstructor } from "@smithy/types";
22

3-
import { isStreaming } from "./isStreaming";
43
import { stringHasher } from "./stringHasher";
54

65
export interface GetChecksumDigestOptions {
7-
streamHasher: StreamHasher<any>;
86
checksumAlgorithmFn: ChecksumConstructor | HashConstructor;
97
base64Encoder: Encoder;
108
}
119

12-
export const getChecksum = async (
13-
body: unknown,
14-
{ streamHasher, checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions
15-
) => {
16-
const digest = isStreaming(body) ? streamHasher(checksumAlgorithmFn, body) : stringHasher(checksumAlgorithmFn, body);
17-
return base64Encoder(await digest);
18-
};
10+
export const getChecksum = async (body: unknown, { checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions) =>
11+
base64Encoder(await stringHasher(checksumAlgorithmFn, body));

packages/middleware-flexible-checksums/src/validateChecksumFromResponse.spec.ts

+56-7
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
import { HttpResponse } from "@smithy/protocol-http";
2+
import { createChecksumStream } from "@smithy/util-stream";
23
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";
34

45
import { PreviouslyResolved } from "./configuration";
56
import { ChecksumAlgorithm } from "./constants";
67
import { getChecksum } from "./getChecksum";
78
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
89
import { getChecksumLocationName } from "./getChecksumLocationName";
10+
import { isStreaming } from "./isStreaming";
911
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
1012
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";
1113

14+
vi.mock("@smithy/util-stream");
1215
vi.mock("./getChecksum");
1316
vi.mock("./getChecksumLocationName");
1417
vi.mock("./getChecksumAlgorithmListForResponse");
18+
vi.mock("./isStreaming");
1519
vi.mock("./selectChecksumAlgorithmFunction");
1620

1721
describe(validateChecksumFromResponse.name, () => {
1822
const mockConfig = {
19-
streamHasher: vi.fn(),
2023
base64Encoder: vi.fn(),
2124
} as unknown as PreviouslyResolved;
2225

2326
const mockBody = {};
27+
const mockBodyStream = { isStream: true };
2428
const mockHeaders = {};
2529
const mockResponse = {
2630
body: mockBody,
@@ -50,6 +54,7 @@ describe(validateChecksumFromResponse.name, () => {
5054
vi.mocked(getChecksumAlgorithmListForResponse).mockImplementation((responseAlgorithms) => responseAlgorithms);
5155
vi.mocked(selectChecksumAlgorithmFunction).mockReturnValue(mockChecksumAlgorithmFn);
5256
vi.mocked(getChecksum).mockResolvedValue(mockChecksum);
57+
vi.mocked(createChecksumStream).mockReturnValue(mockBodyStream);
5358
});
5459

5560
afterEach(() => {
@@ -85,31 +90,56 @@ describe(validateChecksumFromResponse.name, () => {
8590
});
8691

8792
describe("successful validation", () => {
88-
afterEach(() => {
93+
const validateCalls = (isStream: boolean, checksumAlgoFn: ChecksumAlgorithm) => {
8994
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
9095
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
91-
expect(getChecksum).toHaveBeenCalledTimes(1);
92-
});
9396

94-
it("when checksum is populated for first algorithm", async () => {
97+
if (isStream) {
98+
expect(getChecksum).not.toHaveBeenCalled();
99+
expect(createChecksumStream).toHaveBeenCalledTimes(1);
100+
expect(createChecksumStream).toHaveBeenCalledWith({
101+
expectedChecksum: mockChecksum,
102+
checksumSourceLocation: checksumAlgoFn,
103+
checksum: new mockChecksumAlgorithmFn(),
104+
source: mockBody,
105+
base64Encoder: mockConfig.base64Encoder,
106+
});
107+
} else {
108+
expect(getChecksum).toHaveBeenCalledTimes(1);
109+
expect(getChecksum).toHaveBeenCalledWith(mockBody, {
110+
checksumAlgorithmFn: mockChecksumAlgorithmFn,
111+
base64Encoder: mockConfig.base64Encoder,
112+
});
113+
expect(createChecksumStream).not.toHaveBeenCalled();
114+
}
115+
};
116+
117+
it.each([false, true])("when checksum is populated for first algorithm when streaming: %s", async (isStream) => {
118+
vi.mocked(isStreaming).mockReturnValue(isStream);
95119
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], mockChecksum);
96120
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
97121
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
98122
expect(getChecksumLocationName).toHaveBeenCalledWith(mockResponseAlgorithms[0]);
123+
validateCalls(isStream, mockResponseAlgorithms[0]);
99124
});
100125

101-
it("when checksum is populated for second algorithm", async () => {
126+
it.each([false, true])("when checksum is populated for second algorithm when streaming: %s", async (isStream) => {
127+
vi.mocked(isStreaming).mockReturnValue(isStream);
102128
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[1], mockChecksum);
103129
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
104130
expect(getChecksumLocationName).toHaveBeenCalledTimes(2);
105131
expect(getChecksumLocationName).toHaveBeenNthCalledWith(1, mockResponseAlgorithms[0]);
106132
expect(getChecksumLocationName).toHaveBeenNthCalledWith(2, mockResponseAlgorithms[1]);
133+
validateCalls(isStream, mockResponseAlgorithms[1]);
107134
});
108135
});
109136

110-
it("throw error if checksum value is not accurate", async () => {
137+
it("throw error if checksum value is not accurate when not streaming", async () => {
138+
vi.mocked(isStreaming).mockReturnValue(false);
139+
111140
const incorrectChecksum = "incorrectChecksum";
112141
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);
142+
113143
try {
114144
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
115145
fail("should throw checksum mismatch error");
@@ -119,9 +149,28 @@ describe(validateChecksumFromResponse.name, () => {
119149
` in response header "${mockResponseAlgorithms[0]}".`
120150
);
121151
}
152+
122153
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
123154
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
124155
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
125156
expect(getChecksum).toHaveBeenCalledTimes(1);
157+
expect(createChecksumStream).not.toHaveBeenCalled();
158+
});
159+
160+
it("return if checksum value is not accurate when streaming, as error will be thrown when stream is consumed", async () => {
161+
vi.mocked(isStreaming).mockReturnValue(true);
162+
163+
// This override does not matter for the purpose of unit test, but is kept for completeness.
164+
const incorrectChecksum = "incorrectChecksum";
165+
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);
166+
167+
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
168+
169+
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
170+
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
171+
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
172+
expect(getChecksum).not.toHaveBeenCalled();
173+
expect(createChecksumStream).toHaveBeenCalledTimes(1);
174+
expect(responseWithChecksum.body).toBe(mockBodyStream);
126175
});
127176
});

packages/middleware-flexible-checksums/src/validateChecksumFromResponse.ts

+16-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import { HttpResponse } from "@smithy/protocol-http";
2+
import { Checksum } from "@smithy/types";
3+
import { createChecksumStream } from "@smithy/util-stream";
24

35
import { PreviouslyResolved } from "./configuration";
46
import { ChecksumAlgorithm } from "./constants";
57
import { getChecksum } from "./getChecksum";
68
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
79
import { getChecksumLocationName } from "./getChecksumLocationName";
10+
import { isStreaming } from "./isStreaming";
811
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
912

1013
export interface ValidateChecksumFromResponseOptions {
@@ -29,9 +32,20 @@ export const validateChecksumFromResponse = async (
2932
const checksumFromResponse = responseHeaders[responseHeader];
3033
if (checksumFromResponse) {
3134
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(algorithm as ChecksumAlgorithm, config);
32-
const { streamHasher, base64Encoder } = config;
33-
const checksum = await getChecksum(responseBody, { streamHasher, checksumAlgorithmFn, base64Encoder });
35+
const { base64Encoder } = config;
3436

37+
if (isStreaming(responseBody)) {
38+
response.body = createChecksumStream({
39+
expectedChecksum: checksumFromResponse,
40+
checksumSourceLocation: responseHeader,
41+
checksum: new checksumAlgorithmFn() as Checksum,
42+
source: responseBody,
43+
base64Encoder,
44+
});
45+
return;
46+
}
47+
48+
const checksum = await getChecksum(responseBody, { checksumAlgorithmFn, base64Encoder });
3549
if (checksum === checksumFromResponse) {
3650
// The checksum for response payload is valid.
3751
break;

0 commit comments

Comments
 (0)