Skip to content

Commit 507b709

Browse files
authored
feat(eks): trainium instance types (#29155)
@freschri – It's a little hard to find docs on this, but I think this is what you're after?

Closes #29131.

----

*By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
1 parent 98e9fbe commit 507b709

File tree

4 files changed

+60
-9
lines changed

4 files changed

+60
-9
lines changed

packages/aws-cdk-lib/aws-eks/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ cluster.addNodegroupCapacity('custom-node-group', {
228228
});
229229
```
230230

231-
> **NOTE:** If you add instances with the inferentia (`inf1` or `inf2`) class the
232-
> [neuron plugin](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/containers/dlc-then-eks-devflow.html)
231+
> **NOTE:** If you add instances with the Inferentia class (`inf1` or `inf2`) or Trainium class (`trn1` or `trn1n`),
232+
> the [neuron plugin](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/containers/dlc-then-eks-devflow.html)
233233
> will be automatically installed in the Kubernetes cluster.
234234
235235
#### Node Groups with IPv6 Support

packages/aws-cdk-lib/aws-eks/lib/cluster.ts

+21-7
Original file line numberDiff line numberDiff line change
@@ -1800,7 +1800,8 @@ export class Cluster extends ClusterBase {
18001800
spotInterruptHandler: options.spotInterruptHandler,
18011801
});
18021802

1803-
if (nodeTypeForInstanceType(options.instanceType) === NodeType.INFERENTIA) {
1803+
if (nodeTypeForInstanceType(options.instanceType) === NodeType.INFERENTIA ||
1804+
nodeTypeForInstanceType(options.instanceType) === NodeType.TRAINIUM ) {
18041805
this.addNeuronDevicePlugin();
18051806
}
18061807

@@ -1817,11 +1818,13 @@ export class Cluster extends ClusterBase {
18171818
* @param options options for creating a new nodegroup
18181819
*/
18191820
public addNodegroupCapacity(id: string, options?: NodegroupOptions): Nodegroup {
1820-
const hasInferentiaInstanceType = [
1821+
const hasInferentiaOrTrainiumInstanceType = [
18211822
options?.instanceType,
18221823
...options?.instanceTypes ?? [],
1823-
].some(i => i && nodeTypeForInstanceType(i) === NodeType.INFERENTIA);
1824-
if (hasInferentiaInstanceType) {
1824+
].some(i => i && (nodeTypeForInstanceType(i) === NodeType.INFERENTIA ||
1825+
nodeTypeForInstanceType(i) === NodeType.TRAINIUM));
1826+
1827+
if (hasInferentiaOrTrainiumInstanceType) {
18251828
this.addNeuronDevicePlugin();
18261829
}
18271830
return new Nodegroup(this, `Nodegroup${id}`, {
@@ -2373,6 +2376,7 @@ export class EksOptimizedImage implements ec2.IMachineImage {
23732376
'amazon-linux-2/' : 'amazon-linux-2-arm64/' : '')
23742377
+ (this.nodeType === NodeType.GPU ? 'amazon-linux-2-gpu/' : '')
23752378
+ (this.nodeType === NodeType.INFERENTIA ? 'amazon-linux-2-gpu/' : '')
2379+
+ (this.nodeType === NodeType.TRAINIUM ? 'amazon-linux-2-gpu/' : '')
23762380
+ 'recommended/image_id';
23772381
}
23782382

@@ -2410,6 +2414,11 @@ export enum NodeType {
24102414
* Inferentia instances
24112415
*/
24122416
INFERENTIA = 'INFERENTIA',
2417+
2418+
/**
2419+
* Trainium instances
2420+
*/
2421+
TRAINIUM = 'TRAINIUM',
24132422
}
24142423

24152424
/**
@@ -2471,9 +2480,14 @@ export enum MachineImageType {
24712480
}
24722481

24732482
function nodeTypeForInstanceType(instanceType: ec2.InstanceType) {
2474-
return INSTANCE_TYPES.gpu.includes(instanceType.toString().substring(0, 2)) ? NodeType.GPU :
2475-
INSTANCE_TYPES.inferentia.includes(instanceType.toString().substring(0, 4)) ? NodeType.INFERENTIA :
2476-
NodeType.STANDARD;
2483+
if (INSTANCE_TYPES.gpu.includes(instanceType.toString().substring(0, 2))) {
2484+
return NodeType.GPU;
2485+
} else if (INSTANCE_TYPES.inferentia.includes(instanceType.toString().substring(0, 4))) {
2486+
return NodeType.INFERENTIA;
2487+
} else if (INSTANCE_TYPES.trainium.includes(instanceType.toString().substring(0, 4))) {
2488+
return NodeType.TRAINIUM;
2489+
}
2490+
return NodeType.STANDARD;
24772491
}
24782492

24792493
function cpuArchForInstanceType(instanceType: ec2.InstanceType) {

packages/aws-cdk-lib/aws-eks/lib/instance-types.ts

+1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ export const INSTANCE_TYPES = {
44
graviton: ['a1'],
55
graviton2: ['c6g', 'm6g', 'r6g', 't4g'],
66
graviton3: ['c7g'],
7+
trainium: ['trn1', 'trn1n'],
78
};

packages/aws-cdk-lib/aws-eks/test/cluster.test.ts

+36
Original file line numberDiff line numberDiff line change
@@ -2209,6 +2209,42 @@ describe('cluster', () => {
22092209
Manifest: JSON.stringify([sanitized]),
22102210
});
22112211
});
2212+
test('trn1 instances are supported', () => {
2213+
// GIVEN
2214+
const { stack } = testFixtureNoVpc();
2215+
const cluster = new eks.Cluster(stack, 'Cluster', { defaultCapacity: 0, version: CLUSTER_VERSION, prune: false });
2216+
2217+
// WHEN
2218+
cluster.addAutoScalingGroupCapacity('TrainiumInstances', {
2219+
instanceType: new ec2.InstanceType('trn1.2xlarge'),
2220+
minCapacity: 1,
2221+
});
2222+
const fileContents = fs.readFileSync(path.join(__dirname, '..', 'lib', 'addons', 'neuron-device-plugin.yaml'), 'utf8');
2223+
const sanitized = YAML.parse(fileContents);
2224+
2225+
// THEN
2226+
Template.fromStack(stack).hasResourceProperties(eks.KubernetesManifest.RESOURCE_TYPE, {
2227+
Manifest: JSON.stringify([sanitized]),
2228+
});
2229+
});
2230+
test('trn1n instances are supported', () => {
2231+
// GIVEN
2232+
const { stack } = testFixtureNoVpc();
2233+
const cluster = new eks.Cluster(stack, 'Cluster', { defaultCapacity: 0, version: CLUSTER_VERSION, prune: false });
2234+
2235+
// WHEN
2236+
cluster.addAutoScalingGroupCapacity('TrainiumInstances', {
2237+
instanceType: new ec2.InstanceType('trn1n.2xlarge'),
2238+
minCapacity: 1,
2239+
});
2240+
const fileContents = fs.readFileSync(path.join(__dirname, '..', 'lib', 'addons', 'neuron-device-plugin.yaml'), 'utf8');
2241+
const sanitized = YAML.parse(fileContents);
2242+
2243+
// THEN
2244+
Template.fromStack(stack).hasResourceProperties(eks.KubernetesManifest.RESOURCE_TYPE, {
2245+
Manifest: JSON.stringify([sanitized]),
2246+
});
2247+
});
22122248

22132249
test('inf1 instances are supported in addNodegroupCapacity', () => {
22142250
// GIVEN

0 commit comments

Comments
 (0)