13
13
# language governing permissions and limitations under the License.
14
14
from __future__ import absolute_import
15
15
16
+ import os
16
17
import json
17
18
18
19
import pytest
19
20
import sagemaker
20
21
import warnings
21
22
22
23
from sagemaker .workflow .pipeline_context import PipelineSession
24
+ from sagemaker .workflow .parameters import ParameterString
23
25
24
26
from sagemaker .workflow .steps import TrainingStep
25
27
from sagemaker .workflow .pipeline import Pipeline
46
48
from sagemaker .amazon .ntm import NTM
47
49
from sagemaker .amazon .object2vec import Object2Vec
48
50
51
+ from tests .integ import DATA_DIR
49
52
50
53
from sagemaker .inputs import TrainingInput
51
54
52
55
REGION = "us-west-2"
53
56
IMAGE_URI = "fakeimage"
54
57
MODEL_NAME = "gisele"
58
+ DUMMY_LOCAL_SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
55
59
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
56
60
DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
57
61
INSTANCE_TYPE = "ml.m4.xlarge"
@@ -119,6 +123,36 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
119
123
assert step .properties .TrainingJobName .expr == {"Get" : "Steps.MyTrainingStep.TrainingJobName" }
120
124
121
125
126
+ def test_estimator_with_parameterized_output (pipeline_session , training_input ):
127
+ output_path = ParameterString (name = "OutputPath" )
128
+ estimator = XGBoost (
129
+ framework_version = "1.3-1" ,
130
+ py_version = "py3" ,
131
+ role = sagemaker .get_execution_role (),
132
+ instance_type = INSTANCE_TYPE ,
133
+ instance_count = 1 ,
134
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
135
+ output_path = output_path ,
136
+ sagemaker_session = pipeline_session ,
137
+ )
138
+ step_args = estimator .fit (inputs = training_input )
139
+ step = TrainingStep (
140
+ name = "MyTrainingStep" ,
141
+ step_args = step_args ,
142
+ description = "TrainingStep description" ,
143
+ display_name = "MyTrainingStep" ,
144
+ )
145
+ pipeline = Pipeline (
146
+ name = "MyPipeline" ,
147
+ steps = [step ],
148
+ sagemaker_session = pipeline_session ,
149
+ )
150
+ step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
151
+ assert step_def ["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == {
152
+ "Get" : "Parameters.OutputPath"
153
+ }
154
+
155
+
122
156
@pytest .mark .parametrize (
123
157
"estimator" ,
124
158
[
@@ -128,23 +162,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
128
162
instance_type = INSTANCE_TYPE ,
129
163
instance_count = 1 ,
130
164
role = sagemaker .get_execution_role (),
131
- entry_point = "entry_point.py" ,
165
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
132
166
),
133
167
PyTorch (
134
168
role = sagemaker .get_execution_role (),
135
169
instance_type = INSTANCE_TYPE ,
136
170
instance_count = 1 ,
137
171
framework_version = "1.8.0" ,
138
172
py_version = "py36" ,
139
- entry_point = "entry_point.py" ,
173
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
140
174
),
141
175
TensorFlow (
142
176
role = sagemaker .get_execution_role (),
143
177
instance_type = INSTANCE_TYPE ,
144
178
instance_count = 1 ,
145
179
framework_version = "2.0" ,
146
180
py_version = "py3" ,
147
- entry_point = "entry_point.py" ,
181
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
148
182
),
149
183
HuggingFace (
150
184
transformers_version = "4.6" ,
@@ -153,23 +187,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
153
187
instance_type = "ml.p3.2xlarge" ,
154
188
instance_count = 1 ,
155
189
py_version = "py36" ,
156
- entry_point = "entry_point.py" ,
190
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
157
191
),
158
192
XGBoost (
159
193
framework_version = "1.3-1" ,
160
194
py_version = "py3" ,
161
195
role = sagemaker .get_execution_role (),
162
196
instance_type = INSTANCE_TYPE ,
163
197
instance_count = 1 ,
164
- entry_point = "entry_point.py" ,
198
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
165
199
),
166
200
MXNet (
167
201
framework_version = "1.4.1" ,
168
202
py_version = "py3" ,
169
203
role = sagemaker .get_execution_role (),
170
204
instance_type = INSTANCE_TYPE ,
171
205
instance_count = 1 ,
172
- entry_point = "entry_point.py" ,
206
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
173
207
),
174
208
RLEstimator (
175
209
entry_point = "cartpole.py" ,
@@ -182,7 +216,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
182
216
),
183
217
Chainer (
184
218
role = sagemaker .get_execution_role (),
185
- entry_point = "entry_point.py" ,
219
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
186
220
use_mpi = True ,
187
221
num_processes = 4 ,
188
222
framework_version = "5.0.0" ,
0 commit comments