Skip to content

Commit 60d6e66

Browse files
authored
feat(aws-stepfunctions-tasks): add environment property for SageMakerCreateTrainingJob (#18976)
Add environment property for SageMakerCreateTrainingJob. Fixes issue #18919. ---- *By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
1 parent f8d8fe4 commit 60d6e66

File tree

5 files changed

+24
-9
lines changed

5 files changed

+24
-9
lines changed

packages/@aws-cdk/aws-stepfunctions-tasks/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,8 @@ If your training job or model uses resources from AWS Marketplace,
11671167
[network isolation is required](https://docs.aws.amazon.com/sagemaker/latest/dg/mkt-algo-model-internet-free.html).
11681168
To do so, set the `enableNetworkIsolation` property to `true` for `SageMakerCreateModel` or `SageMakerCreateTrainingJob`.
11691169

1170+
To set environment variables for the Docker container use the `environment` property.
1171+
11701172
### Create Training Job
11711173

11721174
You can call the [`CreateTrainingJob`](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html) API from a `Task` state.

packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts

+9-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { Duration, Lazy, Size, Stack } from '@aws-cdk/core';
55
import { Construct } from 'constructs';
66
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
77
import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, S3DataType, StoppingCondition, VpcConfig } from './base-types';
8-
import { renderTags } from './private/utils';
8+
import { renderEnvironment, renderTags } from './private/utils';
99

1010
/**
1111
* Properties for creating an Amazon SageMaker training job
@@ -85,6 +85,13 @@ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps
8585
* @default - No VPC
8686
*/
8787
readonly vpcConfig?: VpcConfig;
88+
89+
/**
90+
* Environment variables to set in the Docker container.
91+
*
92+
* @default - No environment variables
93+
*/
94+
readonly environment?: { [key: string]: string };
8895
}
8996

9097
/**
@@ -234,6 +241,7 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
234241
...this.renderHyperparameters(this.props.hyperparameters),
235242
...renderTags(this.props.tags),
236243
...this.renderVpcConfig(this.props.vpcConfig),
244+
...renderEnvironment(this.props.environment),
237245
};
238246
}
239247

packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts

+2-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { Size, Stack, Token } from '@aws-cdk/core';
55
import { Construct } from 'constructs';
66
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
77
import { BatchStrategy, ModelClientOptions, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types';
8-
import { renderTags } from './private/utils';
8+
import { renderEnvironment, renderTags } from './private/utils';
99

1010
/**
1111
* Properties for creating an Amazon SageMaker transform job task
@@ -166,7 +166,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
166166
private renderParameters(): { [key: string]: any } {
167167
return {
168168
...(this.props.batchStrategy ? { BatchStrategy: this.props.batchStrategy } : {}),
169-
...this.renderEnvironment(this.props.environment),
169+
...renderEnvironment(this.props.environment),
170170
...(this.props.maxConcurrentTransforms ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}),
171171
...(this.props.maxPayload ? { MaxPayloadInMB: this.props.maxPayload.toMebibytes() } : {}),
172172
...this.props.modelClientOptions ? this.renderModelClientOptions(this.props.modelClientOptions) : {},
@@ -234,10 +234,6 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
234234
};
235235
}
236236

237-
private renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } {
238-
return environment ? { Environment: environment } : {};
239-
}
240-
241237
private makePolicyStatements(): iam.PolicyStatement[] {
242238
const stack = Stack.of(this);
243239

Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
21
export function renderTags(tags: { [key: string]: any } | undefined): { [key: string]: any } {
32
return tags ? { Tags: Object.keys(tags).map((key) => ({ Key: key, Value: tags[key] })) } : {};
4-
}
3+
}
4+
5+
export function renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } {
6+
return environment ? { Environment: environment } : {};
7+
}

packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts

+6
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ test('create complex training job', () => {
192192
vpcConfig: {
193193
vpc,
194194
},
195+
environment: {
196+
SOMEVAR: 'myvalue',
197+
},
195198
});
196199
trainTask.addSecurityGroup(securityGroup);
197200

@@ -285,6 +288,9 @@ test('create complex training job', () => {
285288
{ Ref: 'VPCPrivateSubnet2SubnetCFCDAA7A' },
286289
],
287290
},
291+
Environment: {
292+
SOMEVAR: 'myvalue',
293+
},
288294
},
289295
});
290296
});

0 commit comments

Comments
 (0)