Skip to content

Commit 406052b

Browse files
qidewenwhen authored and Lokiiiiii committed
fix: Fix Tensorflow default model_dir generation when output_path is pipeline variable (aws#3146)
1 parent 99cb46b commit 406052b

File tree

3 files changed

+46
-9
lines changed

3 files changed

+46
-9
lines changed

src/sagemaker/tensorflow/estimator.py

+4
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from sagemaker.transformer import Transformer
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig
29+
from sagemaker.workflow import is_pipeline_variable
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -392,6 +393,9 @@ def _default_s3_path(self, directory, mpi=False):
392393
if mpi:
393394
return "/opt/ml/model"
394395
if self._current_job_name:
396+
if is_pipeline_variable(self.output_path):
397+
output_path = "s3://{}".format(self.sagemaker_session.default_bucket())
398+
return s3.s3_path_join(output_path, self._current_job_name, directory)
395399
return s3.s3_path_join(self.output_path, self._current_job_name, directory)
396400
return None
397401

tests/integ/sagemaker/workflow/test_model_steps.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -608,6 +608,9 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
608608
)
609609
inputs = TrainingInput(s3_data=input_path)
610610
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
611+
output_path = ParameterString(
612+
name="OutputPath", default_value=f"s3://{pipeline_session.default_bucket()}"
613+
)
611614

612615
# If image_uri is not provided, the instance_type should not be a pipeline variable
613616
# since instance_type is used to retrieve image_uri in compile time (PySDK)
@@ -619,6 +622,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
619622
framework_version=tf_full_version,
620623
py_version=tf_full_py_version,
621624
sagemaker_session=pipeline_session,
625+
output_path=output_path,
622626
)
623627
train_step_args = tensorflow_estimator.fit(inputs=inputs)
624628
step_train = TrainingStep(
@@ -648,7 +652,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
648652
)
649653
pipeline = Pipeline(
650654
name=pipeline_name,
651-
parameters=[instance_count],
655+
parameters=[instance_count, output_path],
652656
steps=[step_train, step_register_model],
653657
sagemaker_session=pipeline_session,
654658
)

tests/unit/sagemaker/workflow/test_training_step.py

+37 -8
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import json
1717
from mock import Mock, PropertyMock
18+
import re
1819

1920
import pytest
2021
import warnings
@@ -163,6 +164,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
163164

164165
def test_estimator_with_parameterized_output(pipeline_session, training_input):
165166
output_path = ParameterString(name="OutputPath")
167+
# XGBoost
166168
estimator = XGBoost(
167169
framework_version="1.3-1",
168170
py_version="py3",
@@ -174,21 +176,48 @@ def test_estimator_with_parameterized_output(pipeline_session, training_input):
174176
sagemaker_session=pipeline_session,
175177
)
176178
step_args = estimator.fit(inputs=training_input)
177-
step = TrainingStep(
178-
name="MyTrainingStep",
179+
step1 = TrainingStep(
180+
name="MyTrainingStep1",
181+
step_args=step_args,
182+
description="TrainingStep description",
183+
display_name="MyTrainingStep",
184+
)
185+
186+
# TensorFlow
187+
# If model_dir is None and output_path is a pipeline variable
188+
# a default model_dir will be generated with default bucket
189+
estimator = TensorFlow(
190+
framework_version="2.4.1",
191+
py_version="py37",
192+
role=ROLE,
193+
instance_type=INSTANCE_TYPE,
194+
instance_count=1,
195+
entry_point=DUMMY_LOCAL_SCRIPT_PATH,
196+
output_path=output_path,
197+
sagemaker_session=pipeline_session,
198+
)
199+
step_args = estimator.fit(inputs=training_input)
200+
step2 = TrainingStep(
201+
name="MyTrainingStep2",
179202
step_args=step_args,
180203
description="TrainingStep description",
181204
display_name="MyTrainingStep",
182205
)
183206
pipeline = Pipeline(
184207
name="MyPipeline",
185-
steps=[step],
208+
steps=[step1, step2],
209+
parameters=[output_path],
186210
sagemaker_session=pipeline_session,
187211
)
188-
step_def = json.loads(pipeline.definition())["Steps"][0]
189-
assert step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] == {
190-
"Get": "Parameters.OutputPath"
191-
}
212+
step_defs = json.loads(pipeline.definition())["Steps"]
213+
for step_def in step_defs:
214+
assert step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] == {
215+
"Get": "Parameters.OutputPath"
216+
}
217+
if step_def["Name"] != "MyTrainingStep2":
218+
continue
219+
model_dir = step_def["Arguments"]["HyperParameters"]["model_dir"]
220+
assert re.match(rf'"s3://{BUCKET}/.*/model"', model_dir)
192221

193222

194223
@pytest.mark.parametrize(
@@ -316,7 +345,7 @@ def test_training_step_with_algorithm_base(algo_estimator, pipeline_session):
316345
sagemaker_session=pipeline_session,
317346
)
318347
data = RecordSet(
319-
"s3://{}/{}".format(pipeline_session.default_bucket(), "dummy"),
348+
"s3://{}/{}".format(BUCKET, "dummy"),
320349
num_records=1000,
321350
feature_dim=128,
322351
channel="train",

0 commit comments

Comments
 (0)