22
22
from sagemaker .workflow .pipeline import Pipeline
23
23
from sagemaker .workflow .lambda_step import LambdaStep , LambdaOutput , LambdaOutputTypeEnum
24
24
from sagemaker .lambda_helper import Lambda
25
+ from sagemaker .workflow .steps import CacheConfig
25
26
26
27
27
28
@pytest .fixture ()
@@ -38,10 +39,25 @@ def sagemaker_session():
38
39
return session_mock
39
40
40
41
42
+ @pytest .fixture ()
43
+ def sagemaker_session_cn ():
44
+ boto_mock = Mock (name = "boto_session" , region_name = "cn-north-1" )
45
+ session_mock = MagicMock (
46
+ name = "sagemaker_session" ,
47
+ boto_session = boto_mock ,
48
+ boto_region_name = "cn-north-1" ,
49
+ config = None ,
50
+ local_mode = False ,
51
+ )
52
+ session_mock .account_id .return_value = "234567890123"
53
+ return session_mock
54
+
55
+
41
56
def test_lambda_step (sagemaker_session ):
42
57
param = ParameterInteger (name = "MyInt" )
43
- outputParam1 = LambdaOutput (output_name = "output1" , output_type = LambdaOutputTypeEnum .String )
44
- outputParam2 = LambdaOutput (output_name = "output2" , output_type = LambdaOutputTypeEnum .Boolean )
58
+ output_param1 = LambdaOutput (output_name = "output1" , output_type = LambdaOutputTypeEnum .String )
59
+ output_param2 = LambdaOutput (output_name = "output2" , output_type = LambdaOutputTypeEnum .Boolean )
60
+ cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" )
45
61
lambda_step = LambdaStep (
46
62
name = "MyLambdaStep" ,
47
63
depends_on = ["TestStep" ],
@@ -52,10 +68,17 @@ def test_lambda_step(sagemaker_session):
52
68
display_name = "MyLambdaStep" ,
53
69
description = "MyLambdaStepDescription" ,
54
70
inputs = {"arg1" : "foo" , "arg2" : 5 , "arg3" : param },
55
- outputs = [outputParam1 , outputParam2 ],
71
+ outputs = [output_param1 , output_param2 ],
72
+ cache_config = cache_config ,
56
73
)
57
74
lambda_step .add_depends_on (["SecondTestStep" ])
58
- assert lambda_step .to_request () == {
75
+ pipeline = Pipeline (
76
+ name = "MyPipeline" ,
77
+ parameters = [param ],
78
+ steps = [lambda_step ],
79
+ sagemaker_session = sagemaker_session ,
80
+ )
81
+ assert json .loads (pipeline .definition ())["Steps" ][0 ] == {
59
82
"Name" : "MyLambdaStep" ,
60
83
"Type" : "Lambda" ,
61
84
"DependsOn" : ["TestStep" , "SecondTestStep" ],
@@ -66,7 +89,8 @@ def test_lambda_step(sagemaker_session):
66
89
{"OutputName" : "output1" , "OutputType" : "String" },
67
90
{"OutputName" : "output2" , "OutputType" : "Boolean" },
68
91
],
69
- "Arguments" : {"arg1" : "foo" , "arg2" : 5 , "arg3" : param },
92
+ "Arguments" : {"arg1" : "foo" , "arg2" : 5 , "arg3" : {"Get" : "Parameters.MyInt" }},
93
+ "CacheConfig" : {"Enabled" : True , "ExpireAfter" : "PT1H" },
70
94
}
71
95
72
96
@@ -95,8 +119,8 @@ def test_lambda_step_output_expr(sagemaker_session):
95
119
96
120
def test_pipeline_interpolates_lambda_outputs (sagemaker_session ):
97
121
parameter = ParameterString ("MyStr" )
98
- outputParam1 = LambdaOutput (output_name = "output1" , output_type = LambdaOutputTypeEnum .String )
99
- outputParam2 = LambdaOutput (output_name = "output2" , output_type = LambdaOutputTypeEnum .String )
122
+ output_param1 = LambdaOutput (output_name = "output1" , output_type = LambdaOutputTypeEnum .String )
123
+ output_param2 = LambdaOutput (output_name = "output2" , output_type = LambdaOutputTypeEnum .String )
100
124
lambda_step1 = LambdaStep (
101
125
name = "MyLambdaStep1" ,
102
126
depends_on = ["TestStep" ],
@@ -105,7 +129,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
105
129
session = sagemaker_session ,
106
130
),
107
131
inputs = {"arg1" : "foo" },
108
- outputs = [outputParam1 ],
132
+ outputs = [output_param1 ],
109
133
)
110
134
lambda_step2 = LambdaStep (
111
135
name = "MyLambdaStep2" ,
@@ -114,8 +138,8 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session):
114
138
function_arn = "arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda" ,
115
139
session = sagemaker_session ,
116
140
),
117
- inputs = {"arg1" : outputParam1 },
118
- outputs = [outputParam2 ],
141
+ inputs = {"arg1" : output_param1 },
142
+ outputs = [output_param2 ],
119
143
)
120
144
121
145
pipeline = Pipeline (
@@ -207,3 +231,37 @@ def test_lambda_step_without_function_arn(sagemaker_session):
207
231
)
208
232
lambda_step ._get_function_arn ()
209
233
sagemaker_session .account_id .assert_called_once ()
234
+
235
+
236
+ def test_lambda_step_without_function_arn_and_with_error (sagemaker_session_cn ):
237
+ lambda_func = MagicMock (
238
+ function_arn = None ,
239
+ function_name = "name" ,
240
+ execution_role_arn = "arn:aws:lambda:us-west-2:123456789012:execution_role" ,
241
+ zipped_code_dir = "" ,
242
+ handler = "" ,
243
+ session = sagemaker_session_cn ,
244
+ )
245
+ # The raised ValueError contains ResourceConflictException
246
+ lambda_func .create .side_effect = ValueError ("ResourceConflictException" )
247
+ lambda_step1 = LambdaStep (
248
+ name = "MyLambdaStep1" ,
249
+ depends_on = ["TestStep" ],
250
+ lambda_func = lambda_func ,
251
+ inputs = {},
252
+ outputs = [],
253
+ )
254
+ function_arn = lambda_step1 ._get_function_arn ()
255
+ assert function_arn == "arn:aws-cn:lambda:cn-north-1:234567890123:function:name"
256
+
257
+ # The raised ValueError does not contain ResourceConflictException
258
+ lambda_func .create .side_effect = ValueError ()
259
+ lambda_step2 = LambdaStep (
260
+ name = "MyLambdaStep2" ,
261
+ depends_on = ["TestStep" ],
262
+ lambda_func = lambda_func ,
263
+ inputs = {},
264
+ outputs = [],
265
+ )
266
+ with pytest .raises (ValueError ):
267
+ lambda_step2 ._get_function_arn ()
0 commit comments