23
23
from sagemaker .parameter import IntegerParameter
24
24
from sagemaker .tuner import HyperparameterTuner
25
25
from sagemaker .workflow .pipeline_context import PipelineSession
26
+ from tests .unit .sagemaker .workflow .helpers import CustomStep
26
27
27
28
from sagemaker .workflow .steps import TransformStep , TransformInput
28
29
from sagemaker .workflow .pipeline import Pipeline
29
30
from sagemaker .workflow .parameters import ParameterString
31
+ from sagemaker .workflow .functions import Join
32
+ from sagemaker .workflow import is_pipeline_variable
30
33
31
34
from sagemaker .transformer import Transformer
32
35
@@ -53,6 +56,7 @@ def client():
53
56
client_mock ._client_config .user_agent = (
54
57
"Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
55
58
)
59
+ client_mock .describe_model .return_value = {"PrimaryContainer" : {}, "Containers" : {}}
56
60
return client_mock
57
61
58
62
@@ -80,18 +84,44 @@ def pipeline_session(boto_session, client):
80
84
)
81
85
82
86
83
- def test_transform_step_with_transformer (pipeline_session ):
84
- model_name = ParameterString ("ModelName" )
87
+ @pytest .mark .parametrize (
88
+ "model_name" ,
89
+ [
90
+ "my-model" ,
91
+ ParameterString ("ModelName" ),
92
+ ParameterString ("ModelName" , default_value = "my-model" ),
93
+ Join (on = "-" , values = ["my" , "model" ]),
94
+ CustomStep (name = "custom-step" ).properties .RoleArn ,
95
+ ],
96
+ )
97
+ @pytest .mark .parametrize (
98
+ "data" ,
99
+ [
100
+ "s3://my-bucket/my-data" ,
101
+ ParameterString ("MyTransformInput" ),
102
+ ParameterString ("MyTransformInput" , default_value = "s3://my-model" ),
103
+ Join (on = "/" , values = ["s3://my-bucket" , "my-transform-data" , "input" ]),
104
+ CustomStep (name = "custom-step" ).properties .OutputDataConfig .S3OutputPath ,
105
+ ],
106
+ )
107
+ @pytest .mark .parametrize (
108
+ "output_path" ,
109
+ [
110
+ "s3://my-bucket/my-output-path" ,
111
+ ParameterString ("MyOutputPath" ),
112
+ ParameterString ("MyOutputPath" , default_value = "s3://my-output" ),
113
+ Join (on = "/" , values = ["s3://my-bucket" , "my-transform-data" , "output" ]),
114
+ CustomStep (name = "custom-step" ).properties .OutputDataConfig .S3OutputPath ,
115
+ ],
116
+ )
117
+ def test_transform_step_with_transformer (model_name , data , output_path , pipeline_session ):
85
118
transformer = Transformer (
86
119
model_name = model_name ,
87
120
instance_type = "ml.m5.xlarge" ,
88
121
instance_count = 1 ,
89
- output_path = f"s3:// { pipeline_session . default_bucket () } /Transform" ,
122
+ output_path = output_path ,
90
123
sagemaker_session = pipeline_session ,
91
124
)
92
- data = ParameterString (
93
- name = "Data" , default_value = f"s3://{ pipeline_session .default_bucket ()} /batch-data"
94
- )
95
125
transform_inputs = TransformInput (data = data )
96
126
97
127
with warnings .catch_warnings (record = True ) as w :
@@ -123,13 +153,27 @@ def test_transform_step_with_transformer(pipeline_session):
123
153
parameters = [model_name , data ],
124
154
sagemaker_session = pipeline_session ,
125
155
)
126
- step_args .args ["ModelName" ] = model_name .expr
127
- step_args .args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = data .expr
128
- assert json .loads (pipeline .definition ())["Steps" ][0 ] == {
129
- "Name" : "MyTransformStep" ,
130
- "Type" : "Transform" ,
131
- "Arguments" : step_args .args ,
132
- }
156
+ step_args = step_args .args
157
+ step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
158
+ step_args ["ModelName" ] = model_name .expr if is_pipeline_variable (model_name ) else model_name
159
+ step_args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = (
160
+ data .expr if is_pipeline_variable (data ) else data
161
+ )
162
+ step_args ["TransformOutput" ]["S3OutputPath" ] = (
163
+ output_path .expr if is_pipeline_variable (output_path ) else output_path
164
+ )
165
+
166
+ del (
167
+ step_args ["ModelName" ],
168
+ step_args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ],
169
+ step_args ["TransformOutput" ]["S3OutputPath" ],
170
+ )
171
+ del (
172
+ step_def ["Arguments" ]["ModelName" ],
173
+ step_def ["Arguments" ]["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ],
174
+ step_def ["Arguments" ]["TransformOutput" ]["S3OutputPath" ],
175
+ )
176
+ assert step_def == {"Name" : "MyTransformStep" , "Type" : "Transform" , "Arguments" : step_args }
133
177
134
178
135
179
@pytest .mark .parametrize (
0 commit comments