Skip to content

Commit e315711

Browse files
committed
support estimator output path parameterization
1 parent 9014064 commit e315711

File tree

3 files changed: 57 additions and 11 deletions.

3 files changed: 57 additions and 11 deletions.

src/sagemaker/estimator.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -695,14 +695,19 @@ def _stage_user_code_in_s3(self) -> str:
695695
696696
Returns: S3 URI
697697
"""
698-
local_mode = self.output_path.startswith("file://")
698+
local_mode = not is_pipeline_variable(self.output_path) and self.output_path.startswith(
699+
"file://"
700+
)
699701

700702
if self.code_location is None and local_mode:
701703
code_bucket = self.sagemaker_session.default_bucket()
702704
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
703705
kms_key = None
704706
elif self.code_location is None:
705-
code_bucket, _ = parse_s3_url(self.output_path)
707+
if is_pipeline_variable(self.output_path):
708+
code_bucket = self.sagemaker_session.default_bucket()
709+
else:
710+
code_bucket, _ = parse_s3_url(self.output_path)
706711
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
707712
kms_key = self.output_kms_key
708713
elif local_mode:
@@ -713,7 +718,10 @@ def _stage_user_code_in_s3(self) -> str:
713718
code_bucket, key_prefix = parse_s3_url(self.code_location)
714719
code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
715720

716-
output_bucket, _ = parse_s3_url(self.output_path)
721+
if is_pipeline_variable(self.output_path):
722+
output_bucket = self.sagemaker_session.default_bucket()
723+
else:
724+
output_bucket, _ = parse_s3_url(self.output_path)
717725
kms_key = self.output_kms_key if code_bucket == output_bucket else None
718726

719727
return tar_and_upload_dir(

tests/integ/sagemaker/workflow/test_training_steps.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,9 @@ def test_training_job_with_debugger_and_profiler(
6060
):
6161
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
6262
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
63+
output_path = ParameterString(
64+
name="OutputPath", default_value=f"s3://{sagemaker_session.default_bucket()}/test/"
65+
)
6366

6467
rules = [
6568
Rule.sagemaker(rule_configs.vanishing_gradient()),
@@ -88,6 +91,7 @@ def test_training_job_with_debugger_and_profiler(
8891
sagemaker_session=sagemaker_session,
8992
rules=rules,
9093
debugger_hook_config=debugger_hook_config,
94+
output_path=output_path,
9195
)
9296

9397
step_train = TrainingStep(
@@ -98,7 +102,7 @@ def test_training_job_with_debugger_and_profiler(
98102

99103
pipeline = Pipeline(
100104
name=pipeline_name,
101-
parameters=[instance_count, instance_type],
105+
parameters=[instance_count, instance_type, output_path],
102106
steps=[step_train],
103107
sagemaker_session=sagemaker_session,
104108
)

tests/unit/sagemaker/workflow/test_training_step.py

Lines changed: 41 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,15 @@
1313
# language governing permissions and limitations under the License.
1414
from __future__ import absolute_import
1515

16+
import os
1617
import json
1718

1819
import pytest
1920
import sagemaker
2021
import warnings
2122

2223
from sagemaker.workflow.pipeline_context import PipelineSession
24+
from sagemaker.workflow.parameters import ParameterString
2325

2426
from sagemaker.workflow.steps import TrainingStep
2527
from sagemaker.workflow.pipeline import Pipeline
@@ -46,12 +48,14 @@
4648
from sagemaker.amazon.ntm import NTM
4749
from sagemaker.amazon.object2vec import Object2Vec
4850

51+
from tests.integ import DATA_DIR
4952

5053
from sagemaker.inputs import TrainingInput
5154

5255
REGION = "us-west-2"
5356
IMAGE_URI = "fakeimage"
5457
MODEL_NAME = "gisele"
58+
DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
5559
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
5660
DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
5761
INSTANCE_TYPE = "ml.m4.xlarge"
@@ -119,6 +123,36 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
119123
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
120124

121125

126+
def test_estimator_with_parameterized_output(pipeline_session, training_input):
127+
output_path = ParameterString(name="OutputPath")
128+
estimator = XGBoost(
129+
framework_version="1.3-1",
130+
py_version="py3",
131+
role=sagemaker.get_execution_role(),
132+
instance_type=INSTANCE_TYPE,
133+
instance_count=1,
134+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
135+
output_path=output_path,
136+
sagemaker_session=pipeline_session,
137+
)
138+
step_args = estimator.fit(inputs=training_input)
139+
step = TrainingStep(
140+
name="MyTrainingStep",
141+
step_args=step_args,
142+
description="TrainingStep description",
143+
display_name="MyTrainingStep",
144+
)
145+
pipeline = Pipeline(
146+
name="MyPipeline",
147+
steps=[step],
148+
sagemaker_session=pipeline_session,
149+
)
150+
step_def = json.loads(pipeline.definition())["Steps"][0]
151+
assert step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] == {
152+
"Get": "Parameters.OutputPath"
153+
}
154+
155+
122156
@pytest.mark.parametrize(
123157
"estimator",
124158
[
@@ -128,23 +162,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
128162
instance_type=INSTANCE_TYPE,
129163
instance_count=1,
130164
role=sagemaker.get_execution_role(),
131-
entry_point="entry_point.py",
165+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
132166
),
133167
PyTorch(
134168
role=sagemaker.get_execution_role(),
135169
instance_type=INSTANCE_TYPE,
136170
instance_count=1,
137171
framework_version="1.8.0",
138172
py_version="py36",
139-
entry_point="entry_point.py",
173+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
140174
),
141175
TensorFlow(
142176
role=sagemaker.get_execution_role(),
143177
instance_type=INSTANCE_TYPE,
144178
instance_count=1,
145179
framework_version="2.0",
146180
py_version="py3",
147-
entry_point="entry_point.py",
181+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
148182
),
149183
HuggingFace(
150184
transformers_version="4.6",
@@ -153,23 +187,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
153187
instance_type="ml.p3.2xlarge",
154188
instance_count=1,
155189
py_version="py36",
156-
entry_point="entry_point.py",
190+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
157191
),
158192
XGBoost(
159193
framework_version="1.3-1",
160194
py_version="py3",
161195
role=sagemaker.get_execution_role(),
162196
instance_type=INSTANCE_TYPE,
163197
instance_count=1,
164-
entry_point="entry_point.py",
198+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
165199
),
166200
MXNet(
167201
framework_version="1.4.1",
168202
py_version="py3",
169203
role=sagemaker.get_execution_role(),
170204
instance_type=INSTANCE_TYPE,
171205
instance_count=1,
172-
entry_point="entry_point.py",
206+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
173207
),
174208
RLEstimator(
175209
entry_point="cartpole.py",
@@ -182,7 +216,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
182216
),
183217
Chainer(
184218
role=sagemaker.get_execution_role(),
185-
entry_point="entry_point.py",
219+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
186220
use_mpi=True,
187221
num_processes=4,
188222
framework_version="5.0.0",

0 commit comments

Comments
 (0)