23
23
from sagemaker .parameter import IntegerParameter
24
24
from sagemaker .tuner import HyperparameterTuner
25
25
from sagemaker .workflow .pipeline_context import PipelineSession
26
+ from tests .unit .sagemaker .workflow .helpers import CustomStep
26
27
27
28
from sagemaker .workflow .steps import TransformStep , TransformInput
28
29
from sagemaker .workflow .pipeline import Pipeline
29
30
from sagemaker .workflow .parameters import ParameterString
31
+ from sagemaker .workflow .functions import Join
32
+ from sagemaker .workflow import is_pipeline_variable
30
33
31
34
from sagemaker .transformer import Transformer
32
35
@@ -53,6 +56,7 @@ def client():
53
56
client_mock ._client_config .user_agent = (
54
57
"Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
55
58
)
59
+ client_mock .describe_model .return_value = {"PrimaryContainer" : {}, "Containers" : {}}
56
60
return client_mock
57
61
58
62
@@ -80,18 +84,35 @@ def pipeline_session(boto_session, client):
80
84
)
81
85
82
86
83
- def test_transform_step_with_transformer (pipeline_session ):
84
- model_name = ParameterString ("ModelName" )
87
+ @pytest .mark .parametrize ("model_name" , [
88
+ "my-model" ,
89
+ ParameterString ("ModelName" ),
90
+ ParameterString ("ModelName" , default_value = "my-model" ),
91
+ Join (on = "-" , values = ["my" , "model" ]),
92
+ CustomStep (name = "custom-step" ).properties .RoleArn
93
+ ])
94
+ @pytest .mark .parametrize ("data" , [
95
+ "s3://my-bucket/my-data" ,
96
+ ParameterString ("MyTransformInput" ),
97
+ ParameterString ("MyTransformInput" , default_value = "s3://my-model" ),
98
+ Join (on = "/" , values = ["s3://my-bucket" , "my-transform-data" , "input" ]),
99
+ CustomStep (name = "custom-step" ).properties .OutputDataConfig .S3OutputPath
100
+ ])
101
+ @pytest .mark .parametrize ("output_path" , [
102
+ "s3://my-bucket/my-output-path" ,
103
+ ParameterString ("MyOutputPath" ),
104
+ ParameterString ("MyOutputPath" , default_value = "s3://my-output" ),
105
+ Join (on = "/" , values = ["s3://my-bucket" , "my-transform-data" , "output" ]),
106
+ CustomStep (name = "custom-step" ).properties .OutputDataConfig .S3OutputPath
107
+ ])
108
+ def test_transform_step_with_transformer (model_name , data , output_path , pipeline_session ):
85
109
transformer = Transformer (
86
110
model_name = model_name ,
87
111
instance_type = "ml.m5.xlarge" ,
88
112
instance_count = 1 ,
89
- output_path = f"s3:// { pipeline_session . default_bucket () } /Transform" ,
113
+ output_path = output_path ,
90
114
sagemaker_session = pipeline_session ,
91
115
)
92
- data = ParameterString (
93
- name = "Data" , default_value = f"s3://{ pipeline_session .default_bucket ()} /batch-data"
94
- )
95
116
transform_inputs = TransformInput (data = data )
96
117
97
118
with warnings .catch_warnings (record = True ) as w :
@@ -123,12 +144,21 @@ def test_transform_step_with_transformer(pipeline_session):
123
144
parameters = [model_name , data ],
124
145
sagemaker_session = pipeline_session ,
125
146
)
126
- step_args .args ["ModelName" ] = model_name .expr
127
- step_args .args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = data .expr
128
- assert json .loads (pipeline .definition ())["Steps" ][0 ] == {
147
+ step_args = step_args .args
148
+ step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
149
+ step_args ["ModelName" ] = model_name .expr if is_pipeline_variable (model_name ) else model_name
150
+ step_args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = data .expr if is_pipeline_variable (data ) else data
151
+ step_args ["TransformOutput" ]["S3OutputPath" ] = output_path .expr if is_pipeline_variable (
152
+ output_path ) else output_path
153
+
154
+ del step_args ["ModelName" ], step_args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ], \
155
+ step_args ["TransformOutput" ]["S3OutputPath" ]
156
+ del step_def ['Arguments' ]["ModelName" ], step_def ['Arguments' ]["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ], \
157
+ step_def ['Arguments' ]["TransformOutput" ]["S3OutputPath" ]
158
+ assert step_def == {
129
159
"Name" : "MyTransformStep" ,
130
160
"Type" : "Transform" ,
131
- "Arguments" : step_args . args ,
161
+ "Arguments" : step_args
132
162
}
133
163
134
164
0 commit comments