Skip to content

Commit d61dec6

Browse files
fix: support estimator output path parameterization (#3111)
1 parent 19b5560 commit d61dec6

File tree

2 files changed: +78 −25 lines changed

src/sagemaker/estimator.py

+37-18
Original file line number | Diff line number | Diff line change
@@ -695,26 +695,45 @@ def _stage_user_code_in_s3(self) -> str:
695695
696696
Returns: S3 URI
697697
"""
698-
local_mode = self.output_path.startswith("file://")
699-
700-
if self.code_location is None and local_mode:
701-
code_bucket = self.sagemaker_session.default_bucket()
702-
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
703-
kms_key = None
704-
elif self.code_location is None:
705-
code_bucket, _ = parse_s3_url(self.output_path)
706-
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
707-
kms_key = self.output_kms_key
708-
elif local_mode:
709-
code_bucket, key_prefix = parse_s3_url(self.code_location)
710-
code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
711-
kms_key = None
698+
if is_pipeline_variable(self.output_path):
699+
if self.code_location is None:
700+
code_bucket = self.sagemaker_session.default_bucket()
701+
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
702+
kms_key = None
703+
else:
704+
code_bucket, key_prefix = parse_s3_url(self.code_location)
705+
code_s3_prefix = "/".join(
706+
filter(None, [key_prefix, self._current_job_name, "source"])
707+
)
708+
709+
output_bucket = self.sagemaker_session.default_bucket()
710+
kms_key = self.output_kms_key if code_bucket == output_bucket else None
712711
else:
713-
code_bucket, key_prefix = parse_s3_url(self.code_location)
714-
code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
712+
local_mode = self.output_path.startswith("file://")
713+
if local_mode:
714+
if self.code_location is None:
715+
code_bucket = self.sagemaker_session.default_bucket()
716+
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
717+
kms_key = None
718+
else:
719+
code_bucket, key_prefix = parse_s3_url(self.code_location)
720+
code_s3_prefix = "/".join(
721+
filter(None, [key_prefix, self._current_job_name, "source"])
722+
)
723+
kms_key = None
724+
else:
725+
if self.code_location is None:
726+
code_bucket, _ = parse_s3_url(self.output_path)
727+
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
728+
kms_key = self.output_kms_key
729+
else:
730+
code_bucket, key_prefix = parse_s3_url(self.code_location)
731+
code_s3_prefix = "/".join(
732+
filter(None, [key_prefix, self._current_job_name, "source"])
733+
)
715734

716-
output_bucket, _ = parse_s3_url(self.output_path)
717-
kms_key = self.output_kms_key if code_bucket == output_bucket else None
735+
output_bucket, _ = parse_s3_url(self.output_path)
736+
kms_key = self.output_kms_key if code_bucket == output_bucket else None
718737

719738
return tar_and_upload_dir(
720739
session=self.sagemaker_session.boto_session,

tests/unit/sagemaker/workflow/test_training_step.py

+41-7
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,15 @@
1313
# language governing permissions and limitations under the License.
1414
from __future__ import absolute_import
1515

16+
import os
1617
import json
1718

1819
import pytest
1920
import sagemaker
2021
import warnings
2122

2223
from sagemaker.workflow.pipeline_context import PipelineSession
24+
from sagemaker.workflow.parameters import ParameterString
2325

2426
from sagemaker.workflow.steps import TrainingStep
2527
from sagemaker.workflow.pipeline import Pipeline
@@ -46,13 +48,15 @@
4648
from sagemaker.amazon.ntm import NTM
4749
from sagemaker.amazon.object2vec import Object2Vec
4850

51+
from tests.integ import DATA_DIR
4952

5053
from sagemaker.inputs import TrainingInput
5154
from tests.unit.sagemaker.workflow.helpers import CustomStep
5255

5356
REGION = "us-west-2"
5457
IMAGE_URI = "fakeimage"
5558
MODEL_NAME = "gisele"
59+
DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
5660
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
5761
DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
5862
INSTANCE_TYPE = "ml.m4.xlarge"
@@ -122,6 +126,36 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
122126
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
123127

124128

129+
def test_estimator_with_parameterized_output(pipeline_session, training_input):
    """Check that a parameterized ``output_path`` survives into the step definition.

    An XGBoost estimator is created with a ``ParameterString`` as its
    ``output_path``; the arguments returned by ``fit`` are wrapped in a
    ``TrainingStep`` inside a ``Pipeline``. The rendered pipeline definition
    must expose ``S3OutputPath`` as a ``Get`` reference to that parameter.
    """
    parameterized_output = ParameterString(name="OutputPath")

    xgb_estimator = XGBoost(
        framework_version="1.3-1",
        py_version="py3",
        role=sagemaker.get_execution_role(),
        instance_type=INSTANCE_TYPE,
        instance_count=1,
        entry_point=DUMMY_LOCAL_SCRIPT_PATH,
        output_path=parameterized_output,
        sagemaker_session=pipeline_session,
    )

    # Under a pipeline session, fit() yields step arguments that are handed
    # to the TrainingStep below.
    fit_args = xgb_estimator.fit(inputs=training_input)
    training_step = TrainingStep(
        name="MyTrainingStep",
        step_args=fit_args,
        description="TrainingStep description",
        display_name="MyTrainingStep",
    )
    training_pipeline = Pipeline(
        name="MyPipeline",
        steps=[training_step],
        sagemaker_session=pipeline_session,
    )

    rendered_definition = json.loads(training_pipeline.definition())
    first_step = rendered_definition["Steps"][0]
    s3_output_path = first_step["Arguments"]["OutputDataConfig"]["S3OutputPath"]

    expected_reference = {"Get": "Parameters.OutputPath"}
    assert s3_output_path == expected_reference
157+
158+
125159
@pytest.mark.parametrize(
126160
"estimator",
127161
[
@@ -131,23 +165,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
131165
instance_type=INSTANCE_TYPE,
132166
instance_count=1,
133167
role=sagemaker.get_execution_role(),
134-
entry_point="entry_point.py",
168+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
135169
),
136170
PyTorch(
137171
role=sagemaker.get_execution_role(),
138172
instance_type=INSTANCE_TYPE,
139173
instance_count=1,
140174
framework_version="1.8.0",
141175
py_version="py36",
142-
entry_point="entry_point.py",
176+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
143177
),
144178
TensorFlow(
145179
role=sagemaker.get_execution_role(),
146180
instance_type=INSTANCE_TYPE,
147181
instance_count=1,
148182
framework_version="2.0",
149183
py_version="py3",
150-
entry_point="entry_point.py",
184+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
151185
),
152186
HuggingFace(
153187
transformers_version="4.6",
@@ -156,23 +190,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
156190
instance_type="ml.p3.2xlarge",
157191
instance_count=1,
158192
py_version="py36",
159-
entry_point="entry_point.py",
193+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
160194
),
161195
XGBoost(
162196
framework_version="1.3-1",
163197
py_version="py3",
164198
role=sagemaker.get_execution_role(),
165199
instance_type=INSTANCE_TYPE,
166200
instance_count=1,
167-
entry_point="entry_point.py",
201+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
168202
),
169203
MXNet(
170204
framework_version="1.4.1",
171205
py_version="py3",
172206
role=sagemaker.get_execution_role(),
173207
instance_type=INSTANCE_TYPE,
174208
instance_count=1,
175-
entry_point="entry_point.py",
209+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
176210
),
177211
RLEstimator(
178212
entry_point="cartpole.py",
@@ -185,7 +219,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
185219
),
186220
Chainer(
187221
role=sagemaker.get_execution_role(),
188-
entry_point="entry_point.py",
222+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
189223
use_mpi=True,
190224
num_processes=4,
191225
framework_version="5.0.0",

0 commit comments

Comments
 (0)