@@ -1800,7 +1800,8 @@ export class Cluster extends ClusterBase {
1800
1800
spotInterruptHandler : options . spotInterruptHandler ,
1801
1801
} ) ;
1802
1802
1803
- if ( nodeTypeForInstanceType ( options . instanceType ) === NodeType . INFERENTIA ) {
1803
+ if ( nodeTypeForInstanceType ( options . instanceType ) === NodeType . INFERENTIA ||
1804
+ nodeTypeForInstanceType ( options . instanceType ) === NodeType . TRAINIUM ) {
1804
1805
this . addNeuronDevicePlugin ( ) ;
1805
1806
}
1806
1807
@@ -1817,11 +1818,13 @@ export class Cluster extends ClusterBase {
1817
1818
* @param options options for creating a new nodegroup
1818
1819
*/
1819
1820
public addNodegroupCapacity ( id : string , options ?: NodegroupOptions ) : Nodegroup {
1820
- const hasInferentiaInstanceType = [
1821
+ const hasInferentiaOrTrainiumInstanceType = [
1821
1822
options ?. instanceType ,
1822
1823
...options ?. instanceTypes ?? [ ] ,
1823
- ] . some ( i => i && nodeTypeForInstanceType ( i ) === NodeType . INFERENTIA ) ;
1824
- if ( hasInferentiaInstanceType ) {
1824
+ ] . some ( i => i && ( nodeTypeForInstanceType ( i ) === NodeType . INFERENTIA ||
1825
+ nodeTypeForInstanceType ( i ) === NodeType . TRAINIUM ) ) ;
1826
+
1827
+ if ( hasInferentiaOrTrainiumInstanceType ) {
1825
1828
this . addNeuronDevicePlugin ( ) ;
1826
1829
}
1827
1830
return new Nodegroup ( this , `Nodegroup${ id } ` , {
@@ -2373,6 +2376,7 @@ export class EksOptimizedImage implements ec2.IMachineImage {
2373
2376
'amazon-linux-2/' : 'amazon-linux-2-arm64/' : '' )
2374
2377
+ ( this . nodeType === NodeType . GPU ? 'amazon-linux-2-gpu/' : '' )
2375
2378
+ ( this . nodeType === NodeType . INFERENTIA ? 'amazon-linux-2-gpu/' : '' )
2379
+ + ( this . nodeType === NodeType . TRAINIUM ? 'amazon-linux-2-gpu/' : '' )
2376
2380
+ 'recommended/image_id' ;
2377
2381
}
2378
2382
@@ -2410,6 +2414,11 @@ export enum NodeType {
2410
2414
* Inferentia instances
2411
2415
*/
2412
2416
INFERENTIA = 'INFERENTIA' ,
2417
+
2418
+ /**
2419
+ * Trainium instances
2420
+ */
2421
+ TRAINIUM = 'TRAINIUM' ,
2413
2422
}
2414
2423
2415
2424
/**
@@ -2471,9 +2480,14 @@ export enum MachineImageType {
2471
2480
}
2472
2481
2473
2482
function nodeTypeForInstanceType ( instanceType : ec2 . InstanceType ) {
2474
- return INSTANCE_TYPES . gpu . includes ( instanceType . toString ( ) . substring ( 0 , 2 ) ) ? NodeType . GPU :
2475
- INSTANCE_TYPES . inferentia . includes ( instanceType . toString ( ) . substring ( 0 , 4 ) ) ? NodeType . INFERENTIA :
2476
- NodeType . STANDARD ;
2483
+ if ( INSTANCE_TYPES . gpu . includes ( instanceType . toString ( ) . substring ( 0 , 2 ) ) ) {
2484
+ return NodeType . GPU ;
2485
+ } else if ( INSTANCE_TYPES . inferentia . includes ( instanceType . toString ( ) . substring ( 0 , 4 ) ) ) {
2486
+ return NodeType . INFERENTIA ;
2487
+ } else if ( INSTANCE_TYPES . trainium . includes ( instanceType . toString ( ) . substring ( 0 , 4 ) ) ) {
2488
+ return NodeType . TRAINIUM ;
2489
+ }
2490
+ return NodeType . STANDARD ;
2477
2491
}
2478
2492
2479
2493
function cpuArchForInstanceType ( instanceType : ec2 . InstanceType ) {
0 commit comments