13
13
# language governing permissions and limitations under the License.
14
14
from __future__ import absolute_import
15
15
16
+ import os
16
17
import json
17
18
18
19
import pytest
19
20
import sagemaker
20
21
import warnings
21
22
22
23
from sagemaker .workflow .pipeline_context import PipelineSession
24
+ from sagemaker .workflow .parameters import ParameterString
23
25
24
26
from sagemaker .workflow .steps import TrainingStep
25
27
from sagemaker .workflow .pipeline import Pipeline
46
48
from sagemaker .amazon .ntm import NTM
47
49
from sagemaker .amazon .object2vec import Object2Vec
48
50
51
+ from tests .integ import DATA_DIR
49
52
50
53
from sagemaker .inputs import TrainingInput
51
54
52
55
REGION = "us-west-2"
53
56
IMAGE_URI = "fakeimage"
54
57
MODEL_NAME = "gisele"
58
+ DUMMY_LOCAL_SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
55
59
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
56
60
DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
57
61
INSTANCE_TYPE = "ml.m4.xlarge"
@@ -119,6 +123,43 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
119
123
assert step .properties .TrainingJobName .expr == {"Get" : "Steps.MyTrainingStep.TrainingJobName" }
120
124
121
125
126
+ def test_estimator_with_parameterized_output (pipeline_session , training_input ):
127
+ output_path = ParameterString (name = "OutputPath" )
128
+ estimator = XGBoost (
129
+ framework_version = "1.3-1" ,
130
+ py_version = "py3" ,
131
+ role = sagemaker .get_execution_role (),
132
+ instance_type = INSTANCE_TYPE ,
133
+ instance_count = 1 ,
134
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
135
+ output_path = output_path ,
136
+ sagemaker_session = pipeline_session ,
137
+ )
138
+ step_args = estimator .fit (inputs = training_input )
139
+ step = TrainingStep (
140
+ name = "MyTrainingStep" ,
141
+ step_args = step_args ,
142
+ description = "TrainingStep description" ,
143
+ display_name = "MyTrainingStep" ,
144
+ )
145
+ pipeline = Pipeline (
146
+ name = "MyPipeline" ,
147
+ steps = [step ],
148
+ sagemaker_session = pipeline_session ,
149
+ )
150
+ step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
151
+ assert step_def == {
152
+ "Name" : "MyTrainingStep" ,
153
+ "Description" : "TrainingStep description" ,
154
+ "DisplayName" : "MyTrainingStep" ,
155
+ "Type" : "Training" ,
156
+ "Arguments" : step_args ,
157
+ }
158
+ assert step_def ["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == {
159
+ "Get" : "Parameters.OutputPath"
160
+ }
161
+
162
+
122
163
@pytest .mark .parametrize (
123
164
"estimator" ,
124
165
[
@@ -128,23 +169,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
128
169
instance_type = INSTANCE_TYPE ,
129
170
instance_count = 1 ,
130
171
role = sagemaker .get_execution_role (),
131
- entry_point = "entry_point.py" ,
172
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
132
173
),
133
174
PyTorch (
134
175
role = sagemaker .get_execution_role (),
135
176
instance_type = INSTANCE_TYPE ,
136
177
instance_count = 1 ,
137
178
framework_version = "1.8.0" ,
138
179
py_version = "py36" ,
139
- entry_point = "entry_point.py" ,
180
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
140
181
),
141
182
TensorFlow (
142
183
role = sagemaker .get_execution_role (),
143
184
instance_type = INSTANCE_TYPE ,
144
185
instance_count = 1 ,
145
186
framework_version = "2.0" ,
146
187
py_version = "py3" ,
147
- entry_point = "entry_point.py" ,
188
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
148
189
),
149
190
HuggingFace (
150
191
transformers_version = "4.6" ,
@@ -153,23 +194,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
153
194
instance_type = "ml.p3.2xlarge" ,
154
195
instance_count = 1 ,
155
196
py_version = "py36" ,
156
- entry_point = "entry_point.py" ,
197
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
157
198
),
158
199
XGBoost (
159
200
framework_version = "1.3-1" ,
160
201
py_version = "py3" ,
161
202
role = sagemaker .get_execution_role (),
162
203
instance_type = INSTANCE_TYPE ,
163
204
instance_count = 1 ,
164
- entry_point = "entry_point.py" ,
205
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
165
206
),
166
207
MXNet (
167
208
framework_version = "1.4.1" ,
168
209
py_version = "py3" ,
169
210
role = sagemaker .get_execution_role (),
170
211
instance_type = INSTANCE_TYPE ,
171
212
instance_count = 1 ,
172
- entry_point = "entry_point.py" ,
213
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
173
214
),
174
215
RLEstimator (
175
216
entry_point = "cartpole.py" ,
@@ -182,7 +223,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
182
223
),
183
224
Chainer (
184
225
role = sagemaker .get_execution_role (),
185
- entry_point = "entry_point.py" ,
226
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
186
227
use_mpi = True ,
187
228
num_processes = 4 ,
188
229
framework_version = "5.0.0" ,
0 commit comments