13
13
# language governing permissions and limitations under the License.
14
14
from __future__ import absolute_import
15
15
16
+ import os
16
17
import json
17
18
18
19
import pytest
19
20
import sagemaker
20
21
import warnings
21
22
22
23
from sagemaker .workflow .pipeline_context import PipelineSession
24
+ from sagemaker .workflow .parameters import ParameterString
23
25
24
26
from sagemaker .workflow .steps import TrainingStep
25
27
from sagemaker .workflow .pipeline import Pipeline
46
48
from sagemaker .amazon .ntm import NTM
47
49
from sagemaker .amazon .object2vec import Object2Vec
48
50
51
+ from tests .integ import DATA_DIR
49
52
50
53
from sagemaker .inputs import TrainingInput
51
54
from tests .unit .sagemaker .workflow .helpers import CustomStep
52
55
53
56
REGION = "us-west-2"
54
57
IMAGE_URI = "fakeimage"
55
58
MODEL_NAME = "gisele"
59
+ DUMMY_LOCAL_SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
56
60
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
57
61
DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
58
62
INSTANCE_TYPE = "ml.m4.xlarge"
@@ -122,6 +126,36 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
122
126
assert step .properties .TrainingJobName .expr == {"Get" : "Steps.MyTrainingStep.TrainingJobName" }
123
127
124
128
129
+ def test_estimator_with_parameterized_output (pipeline_session , training_input ):
130
+ output_path = ParameterString (name = "OutputPath" )
131
+ estimator = XGBoost (
132
+ framework_version = "1.3-1" ,
133
+ py_version = "py3" ,
134
+ role = sagemaker .get_execution_role (),
135
+ instance_type = INSTANCE_TYPE ,
136
+ instance_count = 1 ,
137
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
138
+ output_path = output_path ,
139
+ sagemaker_session = pipeline_session ,
140
+ )
141
+ step_args = estimator .fit (inputs = training_input )
142
+ step = TrainingStep (
143
+ name = "MyTrainingStep" ,
144
+ step_args = step_args ,
145
+ description = "TrainingStep description" ,
146
+ display_name = "MyTrainingStep" ,
147
+ )
148
+ pipeline = Pipeline (
149
+ name = "MyPipeline" ,
150
+ steps = [step ],
151
+ sagemaker_session = pipeline_session ,
152
+ )
153
+ step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
154
+ assert step_def ["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == {
155
+ "Get" : "Parameters.OutputPath"
156
+ }
157
+
158
+
125
159
@pytest .mark .parametrize (
126
160
"estimator" ,
127
161
[
@@ -131,23 +165,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
131
165
instance_type = INSTANCE_TYPE ,
132
166
instance_count = 1 ,
133
167
role = sagemaker .get_execution_role (),
134
- entry_point = "entry_point.py" ,
168
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
135
169
),
136
170
PyTorch (
137
171
role = sagemaker .get_execution_role (),
138
172
instance_type = INSTANCE_TYPE ,
139
173
instance_count = 1 ,
140
174
framework_version = "1.8.0" ,
141
175
py_version = "py36" ,
142
- entry_point = "entry_point.py" ,
176
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
143
177
),
144
178
TensorFlow (
145
179
role = sagemaker .get_execution_role (),
146
180
instance_type = INSTANCE_TYPE ,
147
181
instance_count = 1 ,
148
182
framework_version = "2.0" ,
149
183
py_version = "py3" ,
150
- entry_point = "entry_point.py" ,
184
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
151
185
),
152
186
HuggingFace (
153
187
transformers_version = "4.6" ,
@@ -156,23 +190,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
156
190
instance_type = "ml.p3.2xlarge" ,
157
191
instance_count = 1 ,
158
192
py_version = "py36" ,
159
- entry_point = "entry_point.py" ,
193
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
160
194
),
161
195
XGBoost (
162
196
framework_version = "1.3-1" ,
163
197
py_version = "py3" ,
164
198
role = sagemaker .get_execution_role (),
165
199
instance_type = INSTANCE_TYPE ,
166
200
instance_count = 1 ,
167
- entry_point = "entry_point.py" ,
201
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
168
202
),
169
203
MXNet (
170
204
framework_version = "1.4.1" ,
171
205
py_version = "py3" ,
172
206
role = sagemaker .get_execution_role (),
173
207
instance_type = INSTANCE_TYPE ,
174
208
instance_count = 1 ,
175
- entry_point = "entry_point.py" ,
209
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
176
210
),
177
211
RLEstimator (
178
212
entry_point = "cartpole.py" ,
@@ -185,7 +219,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
185
219
),
186
220
Chainer (
187
221
role = sagemaker .get_execution_role (),
188
- entry_point = "entry_point.py" ,
222
+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
189
223
use_mpi = True ,
190
224
num_processes = 4 ,
191
225
framework_version = "5.0.0" ,
0 commit comments