Skip to content

Commit 1cead3b

Browse files
authored
feat(stepfunctions-tasks): algorithmName validation for SageMakerCreateTrainingJob (#26877)
Referencing PR #26675, I have added validation for the `algorithmName` parameter in `SageMakerCreateTrainingJob`. However, it was suggested that changes for validation should be separated. So, I have created this PR. Docs for `algorithmName`: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#API_AlgorithmSpecification_Contents Exemption Request: This change does not alter the behavior. I believe the unit test `create-training-job.test.ts` that I have added is sufficient to test this change. ---- *By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
1 parent 4fd510e commit 1cead3b

File tree

2 files changed

+167
-1
lines changed

2 files changed

+167
-1
lines changed

packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts

+24-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { renderEnvironment, renderTags } from './private/utils';
44
import * as ec2 from '../../../aws-ec2';
55
import * as iam from '../../../aws-iam';
66
import * as sfn from '../../../aws-stepfunctions';
7-
import { Duration, Lazy, Size, Stack } from '../../../core';
7+
import { Duration, Lazy, Size, Stack, Token } from '../../../core';
88
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
99

1010
/**
@@ -163,6 +163,14 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
163163
throw new Error('Must define either an algorithm name or training image URI in the algorithm specification');
164164
}
165165

166+
// check that both algorithm name and image are not defined
167+
if (props.algorithmSpecification.algorithmName && props.algorithmSpecification.trainingImage) {
168+
throw new Error('Cannot define both an algorithm name and training image URI in the algorithm specification');
169+
}
170+
171+
// validate algorithm name
172+
this.validateAlgorithmName(props.algorithmSpecification.algorithmName);
173+
166174
// set the input mode to 'File' if not defined
167175
this.algorithmSpecification = props.algorithmSpecification.trainingInputMode
168176
? props.algorithmSpecification
@@ -324,6 +332,21 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
324332
: {};
325333
}
326334

335+
private validateAlgorithmName(algorithmName?: string): void {
336+
if (algorithmName === undefined || Token.isUnresolved(algorithmName)) {
337+
return;
338+
}
339+
340+
if (algorithmName.length < 1 || 170 < algorithmName.length) {
341+
throw new Error(`Algorithm name length must be between 1 and 170, but got ${algorithmName.length}`);
342+
}
343+
344+
const regex = /^(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]([a-zA-Z0-9-]){0,62})(?<!-)$/;
345+
if (!regex.test(algorithmName)) {
346+
throw new Error(`Expected algorithm name to match pattern ${regex.source}, but got ${algorithmName}`);
347+
}
348+
}
349+
327350
private makePolicyStatements(): iam.PolicyStatement[] {
328351
// set the sagemaker role or create new one
329352
this._grantPrincipal = this._role =

packages/aws-cdk-lib/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts

+143
Original file line numberDiff line numberDiff line change
@@ -408,3 +408,146 @@ test('Cannot create a SageMaker train task with both algorithm name and image na
408408
}))
409409
.toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/);
410410
});
411+
412+
test('Cannot create a SageMaker train task with both algorithm name and image name defined', () => {
413+
414+
expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
415+
trainingJobName: 'myTrainJob',
416+
algorithmSpecification: {
417+
algorithmName: 'BlazingText',
418+
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
419+
},
420+
inputDataConfig: [
421+
{
422+
channelName: 'train',
423+
dataSource: {
424+
s3DataSource: {
425+
s3DataType: tasks.S3DataType.S3_PREFIX,
426+
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
427+
},
428+
},
429+
},
430+
],
431+
outputDataConfig: {
432+
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
433+
},
434+
}))
435+
.toThrowError(/Cannot define both an algorithm name and training image URI in the algorithm specification/);
436+
});
437+
438+
test('create a SageMaker train task with trainingImage', () => {
439+
440+
const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
441+
trainingJobName: 'myTrainJob',
442+
algorithmSpecification: {
443+
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
444+
},
445+
inputDataConfig: [
446+
{
447+
channelName: 'train',
448+
dataSource: {
449+
s3DataSource: {
450+
s3DataType: tasks.S3DataType.S3_PREFIX,
451+
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
452+
},
453+
},
454+
},
455+
],
456+
outputDataConfig: {
457+
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
458+
},
459+
});
460+
461+
// THEN
462+
expect(stack.resolve(task.toStateJson())).toMatchObject({
463+
Parameters: {
464+
AlgorithmSpecification: {
465+
'TrainingImage.$': '$.Training.imageName',
466+
'TrainingInputMode': 'File',
467+
},
468+
},
469+
});
470+
});
471+
472+
test('create a SageMaker train task with image URI algorithmName', () => {
473+
474+
const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
475+
trainingJobName: 'myTrainJob',
476+
algorithmSpecification: {
477+
algorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
478+
},
479+
inputDataConfig: [
480+
{
481+
channelName: 'train',
482+
dataSource: {
483+
s3DataSource: {
484+
s3DataType: tasks.S3DataType.S3_PREFIX,
485+
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
486+
},
487+
},
488+
},
489+
],
490+
outputDataConfig: {
491+
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
492+
},
493+
});
494+
495+
// THEN
496+
expect(stack.resolve(task.toStateJson())).toMatchObject({
497+
Parameters: {
498+
AlgorithmSpecification: {
499+
AlgorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
500+
},
501+
},
502+
});
503+
});
504+
505+
test('Cannot create a SageMaker train task when algorithmName length is 171 or more', () => {
506+
507+
expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
508+
trainingJobName: 'myTrainJob',
509+
algorithmSpecification: {
510+
algorithmName: 'a'.repeat(171), // maximum length is 170
511+
},
512+
inputDataConfig: [
513+
{
514+
channelName: 'train',
515+
dataSource: {
516+
s3DataSource: {
517+
s3DataType: tasks.S3DataType.S3_PREFIX,
518+
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
519+
},
520+
},
521+
},
522+
],
523+
outputDataConfig: {
524+
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
525+
},
526+
}))
527+
.toThrowError(/Algorithm name length must be between 1 and 170, but got 171/);
528+
});
529+
530+
test('Cannot create a SageMaker train task with incorrect algorithmName', () => {
531+
532+
expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
533+
trainingJobName: 'myTrainJob',
534+
algorithmSpecification: {
535+
algorithmName: 'Blazing_Text', // underscores are not allowed
536+
},
537+
inputDataConfig: [
538+
{
539+
channelName: 'train',
540+
dataSource: {
541+
s3DataSource: {
542+
s3DataType: tasks.S3DataType.S3_PREFIX,
543+
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
544+
},
545+
},
546+
},
547+
],
548+
outputDataConfig: {
549+
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
550+
},
551+
}))
552+
.toThrowError(/Expected algorithm name to match pattern/);
553+
});

0 commit comments

Comments
 (0)