Skip to content

Commit 9756954

Browse files
authored
fix(client-sts): make role assumer source creds refreshable (#2353)
1 parent e5b876f commit 9756954

File tree

3 files changed

+73
-2
lines changed

3 files changed

+73
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import { HttpResponse } from "@aws-sdk/protocol-http";
2+
import { Readable } from "stream";
3+
const assumeRoleResponse = `<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
4+
<AssumeRoleResult>
5+
<AssumedRoleUser>
6+
<AssumedRoleId>AROAZOX2IL27GNRBJHWC2:session</AssumedRoleId>
7+
<Arn>arn:aws:sts::123:assumed-role/assume-role-test/session</Arn>
8+
</AssumedRoleUser>
9+
<Credentials>
10+
<AccessKeyId>key</AccessKeyId>
11+
<SecretAccessKey>secrete</SecretAccessKey>
12+
<SessionToken>session-token</SessionToken>
13+
<Expiration>2021-05-05T23:22:08Z</Expiration>
14+
</Credentials>
15+
</AssumeRoleResult>
16+
<ResponseMetadata>
17+
<RequestId>12345678id</RequestId>
18+
</ResponseMetadata>
19+
</AssumeRoleResponse>`;
20+
const mockHandle = jest.fn().mockResolvedValue({
21+
response: new HttpResponse({
22+
statusCode: 200,
23+
body: Readable.from([""]),
24+
}),
25+
});
26+
jest.mock("@aws-sdk/node-http-handler", () => ({
27+
NodeHttpHandler: jest.fn().mockImplementation(() => ({
28+
destroy: () => {},
29+
handle: mockHandle,
30+
})),
31+
streamCollector: async () => Buffer.from(assumeRoleResponse),
32+
}));
33+
34+
import { getDefaultRoleAssumer } from "./defaultRoleAssumers";
35+
import type { AssumeRoleCommandInput } from "./commands/AssumeRoleCommand";
36+
37+
describe("getDefaultRoleAssumer", () => {
38+
beforeEach(() => {
39+
jest.clearAllMocks();
40+
});
41+
it("should use supplied source credentials", async () => {
42+
const roleAssumer = getDefaultRoleAssumer();
43+
const params: AssumeRoleCommandInput = {
44+
RoleArn: "arn:aws:foo",
45+
RoleSessionName: "session",
46+
};
47+
const sourceCred1 = { accessKeyId: "key1", secretAccessKey: "secrete1" };
48+
await roleAssumer(sourceCred1, params);
49+
expect(mockHandle).toBeCalledTimes(1);
50+
// Validate request is signed by sourceCred1
51+
expect(mockHandle.mock.calls[0][0].headers?.authorization).toEqual(
52+
expect.stringContaining("AWS4-HMAC-SHA256 Credential=key1/")
53+
);
54+
const sourceCred2 = { accessKeyId: "key2", secretAccessKey: "secrete1" };
55+
await roleAssumer(sourceCred2, params);
56+
// Validate request is signed by sourceCred2
57+
expect(mockHandle).toBeCalledTimes(2);
58+
expect(mockHandle.mock.calls[1][0].headers?.authorization).toEqual(
59+
expect.stringContaining("AWS4-HMAC-SHA256 Credential=key2/")
60+
);
61+
});
62+
});

clients/client-sts/defaultStsRoleAssumers.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,15 @@ export const getDefaultRoleAssumer = (
3939
stsClientCtor: new (options: STSClientConfig) => STSClient
4040
): RoleAssumer => {
4141
let stsClient: STSClient;
42+
let closureSourceCreds: Credentials;
4243
return async (sourceCreds, params) => {
44+
closureSourceCreds = sourceCreds;
4345
if (!stsClient) {
4446
const { logger, region } = stsOptions;
4547
stsClient = new stsClientCtor({
4648
logger,
47-
credentials: sourceCreds,
49+
// A hack to make sts client uses the credential in current closure.
50+
credentialDefaultProvider: () => async () => closureSourceCreds,
4851
region: decorateDefaultRegion(region),
4952
});
5053
}

packages/credential-provider-node/src/index.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ export const defaultProvider = (
5555
): CredentialProvider => {
5656
const options = { profile: process.env[ENV_PROFILE], ...init };
5757
if (!options.loadedConfig) options.loadedConfig = loadSharedConfigFiles(init);
58-
const providers = [fromSSO(options), fromIni(options), fromProcess(options), fromTokenFile(options), remoteProvider(options)];
58+
const providers = [
59+
fromSSO(options),
60+
fromIni(options),
61+
fromProcess(options),
62+
fromTokenFile(options),
63+
remoteProvider(options),
64+
];
5965
if (!options.profile) providers.unshift(fromEnv());
6066
const providerChain = chain(...providers);
6167

0 commit comments

Comments
 (0)