Skip to content

Commit f43dd91

Browse files
committed
support estimator output path parameterization
1 parent: 8c52f1b · commit: f43dd91

File tree

2 files changed: 59 additions & 10 deletions

src/sagemaker/estimator.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -695,14 +695,19 @@ def _stage_user_code_in_s3(self) -> str:
695695
696696
Returns: S3 URI
697697
"""
698-
local_mode = self.output_path.startswith("file://")
698+
local_mode = not is_pipeline_variable(self.output_path) and self.output_path.startswith(
699+
"file://"
700+
)
699701

700702
if self.code_location is None and local_mode:
701703
code_bucket = self.sagemaker_session.default_bucket()
702704
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
703705
kms_key = None
704706
elif self.code_location is None:
705-
code_bucket, _ = parse_s3_url(self.output_path)
707+
if is_pipeline_variable(self.output_path):
708+
code_bucket = self.sagemaker_session.default_bucket()
709+
else:
710+
code_bucket, _ = parse_s3_url(self.output_path)
706711
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
707712
kms_key = self.output_kms_key
708713
elif local_mode:
@@ -713,7 +718,10 @@ def _stage_user_code_in_s3(self) -> str:
713718
code_bucket, key_prefix = parse_s3_url(self.code_location)
714719
code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
715720

716-
output_bucket, _ = parse_s3_url(self.output_path)
721+
if is_pipeline_variable(self.output_path):
722+
output_bucket = self.sagemaker_session.default_bucket()
723+
else:
724+
output_bucket, _ = parse_s3_url(self.output_path)
717725
kms_key = self.output_kms_key if code_bucket == output_bucket else None
718726

719727
return tar_and_upload_dir(

tests/unit/sagemaker/workflow/test_training_step.py

Lines changed: 48 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,15 @@
1313
# language governing permissions and limitations under the License.
1414
from __future__ import absolute_import
1515

16+
import os
1617
import json
1718

1819
import pytest
1920
import sagemaker
2021
import warnings
2122

2223
from sagemaker.workflow.pipeline_context import PipelineSession
24+
from sagemaker.workflow.parameters import ParameterString
2325

2426
from sagemaker.workflow.steps import TrainingStep
2527
from sagemaker.workflow.pipeline import Pipeline
@@ -46,12 +48,14 @@
4648
from sagemaker.amazon.ntm import NTM
4749
from sagemaker.amazon.object2vec import Object2Vec
4850

51+
from tests.integ import DATA_DIR
4952

5053
from sagemaker.inputs import TrainingInput
5154

5255
REGION = "us-west-2"
5356
IMAGE_URI = "fakeimage"
5457
MODEL_NAME = "gisele"
58+
DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
5559
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
5660
DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
5761
INSTANCE_TYPE = "ml.m4.xlarge"
@@ -119,6 +123,43 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
119123
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
120124

121125

126+
def test_estimator_with_parameterized_output(pipeline_session, training_input):
127+
output_path = ParameterString(name="OutputPath")
128+
estimator = XGBoost(
129+
framework_version="1.3-1",
130+
py_version="py3",
131+
role=sagemaker.get_execution_role(),
132+
instance_type=INSTANCE_TYPE,
133+
instance_count=1,
134+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
135+
output_path=output_path,
136+
sagemaker_session=pipeline_session,
137+
)
138+
step_args = estimator.fit(inputs=training_input)
139+
step = TrainingStep(
140+
name="MyTrainingStep",
141+
step_args=step_args,
142+
description="TrainingStep description",
143+
display_name="MyTrainingStep",
144+
)
145+
pipeline = Pipeline(
146+
name="MyPipeline",
147+
steps=[step],
148+
sagemaker_session=pipeline_session,
149+
)
150+
step_def = json.loads(pipeline.definition())["Steps"][0]
151+
assert step_def == {
152+
"Name": "MyTrainingStep",
153+
"Description": "TrainingStep description",
154+
"DisplayName": "MyTrainingStep",
155+
"Type": "Training",
156+
"Arguments": step_args,
157+
}
158+
assert step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] == {
159+
"Get": "Parameters.OutputPath"
160+
}
161+
162+
122163
@pytest.mark.parametrize(
123164
"estimator",
124165
[
@@ -128,23 +169,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
128169
instance_type=INSTANCE_TYPE,
129170
instance_count=1,
130171
role=sagemaker.get_execution_role(),
131-
entry_point="entry_point.py",
172+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
132173
),
133174
PyTorch(
134175
role=sagemaker.get_execution_role(),
135176
instance_type=INSTANCE_TYPE,
136177
instance_count=1,
137178
framework_version="1.8.0",
138179
py_version="py36",
139-
entry_point="entry_point.py",
180+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
140181
),
141182
TensorFlow(
142183
role=sagemaker.get_execution_role(),
143184
instance_type=INSTANCE_TYPE,
144185
instance_count=1,
145186
framework_version="2.0",
146187
py_version="py3",
147-
entry_point="entry_point.py",
188+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
148189
),
149190
HuggingFace(
150191
transformers_version="4.6",
@@ -153,23 +194,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
153194
instance_type="ml.p3.2xlarge",
154195
instance_count=1,
155196
py_version="py36",
156-
entry_point="entry_point.py",
197+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
157198
),
158199
XGBoost(
159200
framework_version="1.3-1",
160201
py_version="py3",
161202
role=sagemaker.get_execution_role(),
162203
instance_type=INSTANCE_TYPE,
163204
instance_count=1,
164-
entry_point="entry_point.py",
205+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
165206
),
166207
MXNet(
167208
framework_version="1.4.1",
168209
py_version="py3",
169210
role=sagemaker.get_execution_role(),
170211
instance_type=INSTANCE_TYPE,
171212
instance_count=1,
172-
entry_point="entry_point.py",
213+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
173214
),
174215
RLEstimator(
175216
entry_point="cartpole.py",
@@ -182,7 +223,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
182223
),
183224
Chainer(
184225
role=sagemaker.get_execution_role(),
185-
entry_point="entry_point.py",
226+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
186227
use_mpi=True,
187228
num_processes=4,
188229
framework_version="5.0.0",

0 commit comments

Comments
 (0)