Skip to content

Commit b4c16b3

Browse files
committed
chore: update aws auth generator with default role assumers
1 parent 088ccf7 commit b4c16b3

File tree

5 files changed

+231
-3
lines changed

5 files changed

+231
-3
lines changed

codegen/smithy-aws-typescript-codegen/src/main/java/software/amazon/smithy/aws/typescript/codegen/AddAwsAuthPlugin.java

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.Set;
25+
import java.util.function.BiConsumer;
2526
import java.util.function.Consumer;
2627
import software.amazon.smithy.aws.traits.ServiceTrait;
2728
import software.amazon.smithy.codegen.core.SymbolProvider;
@@ -37,6 +38,7 @@
3738
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
3839
import software.amazon.smithy.typescript.codegen.integration.RuntimeClientPlugin;
3940
import software.amazon.smithy.typescript.codegen.integration.TypeScriptIntegration;
41+
import software.amazon.smithy.utils.IoUtils;
4042
import software.amazon.smithy.utils.ListUtils;
4143
import software.amazon.smithy.utils.MapUtils;
4244
import software.amazon.smithy.utils.SetUtils;
@@ -45,6 +47,9 @@
4547
* Configure clients with AWS auth configurations and plugin.
4648
*/
4749
public final class AddAwsAuthPlugin implements TypeScriptIntegration {
50+
static final String STS_CLIENT_PREFIX = "sts-client-";
51+
static final String ROLE_ASSUMERS_FILE = "defaultRoleAssumers";
52+
static final String STS_ROLE_ASSUMERS_FILE = "defaultStsRoleAssumers";
4853

4954
@Override
5055
public void addConfigInterfaceFields(
@@ -66,7 +71,13 @@ public List<RuntimeClientPlugin> getClientPlugins() {
6671
return ListUtils.of(
6772
RuntimeClientPlugin.builder()
6873
.withConventions(AwsDependency.MIDDLEWARE_SIGNING.dependency, "AwsAuth", HAS_CONFIG)
69-
.servicePredicate((m, s) -> !areAllOptionalAuthOperations(m, s))
74+
.servicePredicate((m, s) -> !areAllOptionalAuthOperations(m, s) && !testServiceId(s, "STS"))
75+
.build(),
76+
RuntimeClientPlugin.builder()
77+
.withConventions(AwsDependency.STS_MIDDLEWARE.dependency,
78+
"StsAuth", HAS_CONFIG)
79+
.additionalResolveFunctionParameters("STSClient")
80+
.servicePredicate((m, s) -> testServiceId(s, "STS"))
7081
.build(),
7182
RuntimeClientPlugin.builder()
7283
.withConventions(AwsDependency.MIDDLEWARE_SIGNING.dependency, "AwsAuth", HAS_MIDDLEWARE)
@@ -104,17 +115,70 @@ public Map<String, Consumer<TypeScriptWriter>> getRuntimeConfigWriters(
104115
case NODE:
105116
return MapUtils.of(
106117
"credentialDefaultProvider", writer -> {
118+
if (!testServiceId(service, "STS")) {
119+
writer.addDependency(AwsDependency.STS_CLIENT);
120+
writer.addImport("decorateDefaultCredentialProvider", "decorateDefaultCredentialProvider",
121+
AwsDependency.STS_CLIENT.packageName);
122+
} else {
123+
writer.addImport("decorateDefaultCredentialProvider", "decorateDefaultCredentialProvider",
124+
"./" + ROLE_ASSUMERS_FILE);
125+
}
107126
writer.addDependency(AwsDependency.CREDENTIAL_PROVIDER_NODE);
108127
writer.addImport("defaultProvider", "credentialDefaultProvider",
109128
AwsDependency.CREDENTIAL_PROVIDER_NODE.packageName);
110-
writer.write("credentialDefaultProvider,");
129+
writer.write("credentialDefaultProvider: decorateDefaultCredentialProvider("
130+
+ "credentialDefaultProvider),");
111131
}
112132
);
113133
default:
114134
return Collections.emptyMap();
115135
}
116136
}
117137

138+
@Override
139+
public void writeAdditionalFiles(
140+
TypeScriptSettings settings,
141+
Model model,
142+
SymbolProvider symbolProvider,
143+
BiConsumer<String, Consumer<TypeScriptWriter>> writerFactory
144+
) {
145+
ServiceShape service = settings.getService(model);
146+
if (!testServiceId(service, "STS")) {
147+
return;
148+
}
149+
writerFactory.accept("defaultRoleAssumers.ts", writer -> {
150+
String source = IoUtils.readUtf8Resource(getClass(),
151+
String.format("%s%s.ts", STS_CLIENT_PREFIX, ROLE_ASSUMERS_FILE));
152+
writer.write("$L", source);
153+
});
154+
writerFactory.accept("defaultStsRoleAssumers.ts", writer -> {
155+
String source = IoUtils.readUtf8Resource(getClass(),
156+
String.format("%s%s.ts", STS_CLIENT_PREFIX, STS_ROLE_ASSUMERS_FILE));
157+
writer.write("$L", source);
158+
});
159+
160+
// String utilsFileLocation = String.format("%s%s", docClientPrefix, DocumentClientUtils.CLIENT_UTILS_FILE);
161+
// writerFactory.accept(String.format("%s%s/%s.ts", docClientPrefix,
162+
// DocumentClientUtils.CLIENT_COMMANDS_FOLDER, DocumentClientUtils.CLIENT_UTILS_FILE), writer -> {
163+
// writer.write(IoUtils.readUtf8Resource(AddDocumentClientPlugin.class,
164+
// String.format("%s.ts", utilsFileLocation)));
165+
// });
166+
}
167+
168+
@Override
169+
public void writeAdditionalExports(
170+
TypeScriptSettings settings,
171+
Model model,
172+
SymbolProvider symbolProvider,
173+
TypeScriptWriter writer
174+
) {
175+
ServiceShape service = settings.getService(model);
176+
if (!testServiceId(service, "STS")) {
177+
return;
178+
}
179+
writer.write("export * from $S", "./" + ROLE_ASSUMERS_FILE);
180+
}
181+
118182
private static boolean testServiceId(Shape serviceShape, String expectedId) {
119183
return serviceShape.getTrait(ServiceTrait.class).map(ServiceTrait::getSdkId).orElse("").equals(expectedId);
120184
}

codegen/smithy-aws-typescript-codegen/src/main/java/software/amazon/smithy/aws/typescript/codegen/AwsDependency.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ public enum AwsDependency implements SymbolDependencyContainer {
5959
AWS_SDK_EVENTSTREAM_HANDLER_NODE(NORMAL_DEPENDENCY, "@aws-sdk/eventstream-handler-node", "^1.0.0-rc.1"),
6060
TRANSCRIBE_STREAMING_MIDDLEWARE(NORMAL_DEPENDENCY, "@aws-sdk/middleware-sdk-transcribe-streaming",
6161
"^1.0.0-rc.1"),
62+
STS_MIDDLEWARE(NORMAL_DEPENDENCY, "@aws-sdk/middleware-sdk-sts", "3.11.0"),
63+
STS_CLIENT(NORMAL_DEPENDENCY, "@aws-sdk/client-sts", "3.11.0"),
6264
RETRY_CONFIG_PROVIDER(NORMAL_DEPENDENCY, "@aws-sdk/retry-config-provider", "^1.0.0-rc.1"),
6365
NODE_CONFIG_PROVIDER(NORMAL_DEPENDENCY, "@aws-sdk/node-config-provider", "^1.0.0-rc.1"),
6466
MIDDLEWARE_LOGGER(NORMAL_DEPENDENCY, "@aws-sdk/middleware-logger", "^1.0.0-rc.1"),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import {
2+
DefaultCredentialProvider,
3+
getDefaultRoleAssumer as StsGetDefaultRoleAssumer,
4+
getDefaultRoleAssumerWithWebIdentity as StsGetDefaultRoleAssumerWithWebIdentity,
5+
RoleAssumer,
6+
RoleAssumerWithWebIdentity,
7+
} from "./defaultStsRoleAssumers";
8+
import { STSClient, STSClientConfig } from "./STSClient";
9+
10+
/**
11+
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
12+
*/
13+
export const getDefaultRoleAssumer = (stsOptions: Pick<STSClientConfig, "logger" | "region"> = {}): RoleAssumer =>
14+
StsGetDefaultRoleAssumer(stsOptions, STSClient);
15+
16+
/**
17+
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
18+
*/
19+
export const getDefaultRoleAssumerWithWebIdentity = (
20+
stsOptions: Pick<STSClientConfig, "logger" | "region"> = {}
21+
): RoleAssumerWithWebIdentity => StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, STSClient);
22+
23+
/**
24+
* The default credential providers depend STS client to assume role with desired API: sts:assumeRole,
25+
* sts:assumeRoleWithWebIdentity, etc. This function decorates the default credential provider with role assumers which
26+
* encapsulates the process of calling STS commands. This can only be imported by AWS client packages to avoid circular
27+
* dependencies.
28+
*
29+
* @internal
30+
*/
31+
export const decorateDefaultCredentialProvider = (provider: DefaultCredentialProvider): DefaultCredentialProvider => (
32+
input: any
33+
) =>
34+
provider({
35+
roleAssumer: getDefaultRoleAssumer(input),
36+
roleAssumerWithWebIdentity: getDefaultRoleAssumerWithWebIdentity(input),
37+
...input,
38+
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import { Credentials, Provider } from "@aws-sdk/types";
2+
3+
import { AssumeRoleCommand, AssumeRoleCommandInput } from "./commands/AssumeRoleCommand";
4+
import {
5+
AssumeRoleWithWebIdentityCommand,
6+
AssumeRoleWithWebIdentityCommandInput,
7+
} from "./commands/AssumeRoleWithWebIdentityCommand";
8+
import type { STSClient, STSClientConfig, STSClientResolvedConfig } from "./STSClient";
9+
10+
/**
11+
* @internal
12+
*/
13+
export type RoleAssumer = (sourceCreds: Credentials, params: AssumeRoleCommandInput) => Promise<Credentials>;
14+
15+
const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";
16+
17+
/**
18+
* Inject the fallback STS region of us-east-1.
19+
*/
20+
const decorateDefaultRegion = (region: string | Provider<string> | undefined): string | Provider<string> => {
21+
if (typeof region !== "function") {
22+
return region === undefined ? ASSUME_ROLE_DEFAULT_REGION : region;
23+
}
24+
return async () => {
25+
try {
26+
return await region();
27+
} catch (e) {
28+
return ASSUME_ROLE_DEFAULT_REGION;
29+
}
30+
};
31+
};
32+
33+
/**
34+
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
35+
* @internal
36+
*/
37+
export const getDefaultRoleAssumer = (
38+
stsOptions: Pick<STSClientConfig, "logger" | "region">,
39+
stsClientCtor: new (options: STSClientConfig) => STSClient
40+
): RoleAssumer => {
41+
let stsClient: STSClient;
42+
return async (sourceCreds, params) => {
43+
if (!stsClient) {
44+
const { logger, region } = stsOptions;
45+
stsClient = new stsClientCtor({
46+
logger,
47+
credentials: sourceCreds,
48+
region: decorateDefaultRegion(region),
49+
});
50+
}
51+
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
52+
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
53+
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
54+
}
55+
return {
56+
accessKeyId: Credentials.AccessKeyId,
57+
secretAccessKey: Credentials.SecretAccessKey,
58+
sessionToken: Credentials.SessionToken,
59+
expiration: Credentials.Expiration,
60+
};
61+
};
62+
};
63+
64+
/**
65+
* @internal
66+
*/
67+
export type RoleAssumerWithWebIdentity = (params: AssumeRoleWithWebIdentityCommandInput) => Promise<Credentials>;
68+
69+
/**
70+
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
71+
* @internal
72+
*/
73+
export const getDefaultRoleAssumerWithWebIdentity = (
74+
stsOptions: Pick<STSClientConfig, "logger" | "region">,
75+
stsClientCtor: new (options: STSClientConfig) => STSClient
76+
): RoleAssumerWithWebIdentity => {
77+
let stsClient: STSClient;
78+
return async (params) => {
79+
if (!stsClient) {
80+
const { logger, region } = stsOptions;
81+
stsClient = new stsClientCtor({
82+
logger,
83+
region: decorateDefaultRegion(region),
84+
});
85+
}
86+
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
87+
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
88+
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
89+
}
90+
return {
91+
accessKeyId: Credentials.AccessKeyId,
92+
secretAccessKey: Credentials.SecretAccessKey,
93+
sessionToken: Credentials.SessionToken,
94+
expiration: Credentials.Expiration,
95+
};
96+
};
97+
};
98+
99+
/**
100+
* @internal
101+
*/
102+
export type DefaultCredentialProvider = (input: any) => Provider<Credentials>;
103+
104+
/**
105+
* The default credential providers depend STS client to assume role with desired API: sts:assumeRole,
106+
* sts:assumeRoleWithWebIdentity, etc. This function decorates the default credential provider with role assumers which
107+
* encapsulates the process of calling STS commands. This can only be imported by AWS client packages to avoid circular
108+
* dependencies.
109+
*
110+
* @internal
111+
*/
112+
export const decorateDefaultCredentialProvider = (provider: DefaultCredentialProvider): DefaultCredentialProvider => (
113+
input: STSClientResolvedConfig
114+
) =>
115+
provider({
116+
roleAssumer: getDefaultRoleAssumer(input, input.stsClientCtor),
117+
roleAssumerWithWebIdentity: getDefaultRoleAssumerWithWebIdentity(input, input.stsClientCtor),
118+
...input,
119+
});

scripts/generate-clients/copy-to-clients.js

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,17 @@ const getOverwritablePredicate = (packageName) => (pathName) => {
2020
"endpoints.ts",
2121
"README.md",
2222
];
23+
const additionalGeneratedFiles = {
24+
"@aws-sdk/client-sts": ["defaultRoleAssumers.ts", "defaultStsRoleAssumers.ts"],
25+
};
2326
return (
2427
pathName
2528
.toLowerCase()
2629
.startsWith(
2730
packageName.toLowerCase().replace("@aws-sdk/client-", "").replace("@aws-sdk/aws-", "").replace(/-/g, "")
28-
) || overwritablePathnames.indexOf(pathName) >= 0
31+
) ||
32+
overwritablePathnames.indexOf(pathName) >= 0 ||
33+
additionalGeneratedFiles[packageName.toLowerCase()]?.indexOf(pathName) >= 0
2934
);
3035
};
3136

0 commit comments

Comments
 (0)