diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 655209e763..2ddfd3f433 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -695,26 +695,45 @@ def _stage_user_code_in_s3(self) -> str: Returns: S3 URI """ - local_mode = self.output_path.startswith("file://") - - if self.code_location is None and local_mode: - code_bucket = self.sagemaker_session.default_bucket() - code_s3_prefix = "{}/{}".format(self._current_job_name, "source") - kms_key = None - elif self.code_location is None: - code_bucket, _ = parse_s3_url(self.output_path) - code_s3_prefix = "{}/{}".format(self._current_job_name, "source") - kms_key = self.output_kms_key - elif local_mode: - code_bucket, key_prefix = parse_s3_url(self.code_location) - code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) - kms_key = None + if is_pipeline_variable(self.output_path): + if self.code_location is None: + code_bucket = self.sagemaker_session.default_bucket() + code_s3_prefix = "{}/{}".format(self._current_job_name, "source") + kms_key = None + else: + code_bucket, key_prefix = parse_s3_url(self.code_location) + code_s3_prefix = "/".join( + filter(None, [key_prefix, self._current_job_name, "source"]) + ) + + output_bucket = self.sagemaker_session.default_bucket() + kms_key = self.output_kms_key if code_bucket == output_bucket else None else: - code_bucket, key_prefix = parse_s3_url(self.code_location) - code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) + local_mode = self.output_path.startswith("file://") + if local_mode: + if self.code_location is None: + code_bucket = self.sagemaker_session.default_bucket() + code_s3_prefix = "{}/{}".format(self._current_job_name, "source") + kms_key = None + else: + code_bucket, key_prefix = parse_s3_url(self.code_location) + code_s3_prefix = "/".join( + filter(None, [key_prefix, self._current_job_name, "source"]) + ) + kms_key = None + else: + if self.code_location is None: + code_bucket, _ = parse_s3_url(self.output_path) + code_s3_prefix = "{}/{}".format(self._current_job_name, "source") + kms_key = self.output_kms_key + else: + code_bucket, key_prefix = parse_s3_url(self.code_location) + code_s3_prefix = "/".join( + filter(None, [key_prefix, self._current_job_name, "source"]) + ) - output_bucket, _ = parse_s3_url(self.output_path) - kms_key = self.output_kms_key if code_bucket == output_bucket else None + output_bucket, _ = parse_s3_url(self.output_path) + kms_key = self.output_kms_key if code_bucket == output_bucket else None return tar_and_upload_dir( session=self.sagemaker_session.boto_session, diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 8840d5bfa4..7137230d95 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -13,6 +13,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import os import json import pytest @@ -20,6 +21,7 @@ import warnings from sagemaker.workflow.pipeline_context import PipelineSession +from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.steps import TrainingStep from sagemaker.workflow.pipeline import Pipeline @@ -46,6 +48,7 @@ from sagemaker.amazon.ntm import NTM from sagemaker.amazon.object2vec import Object2Vec +from tests.integ import DATA_DIR from sagemaker.inputs import TrainingInput from tests.unit.sagemaker.workflow.helpers import CustomStep @@ -53,6 +56,7 @@ REGION = "us-west-2" IMAGE_URI = "fakeimage" MODEL_NAME = "gisele" +DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py" DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/" INSTANCE_TYPE = "ml.m4.xlarge" @@ -122,6 +126,36 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} +def test_estimator_with_parameterized_output(pipeline_session, training_input): + output_path = ParameterString(name="OutputPath") + estimator = XGBoost( + framework_version="1.3-1", + py_version="py3", + role=sagemaker.get_execution_role(), + instance_type=INSTANCE_TYPE, + instance_count=1, + entry_point=DUMMY_LOCAL_SCRIPT_PATH, + output_path=output_path, + sagemaker_session=pipeline_session, + ) + step_args = estimator.fit(inputs=training_input) + step = TrainingStep( + name="MyTrainingStep", + step_args=step_args, + description="TrainingStep description", + display_name="MyTrainingStep", + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, + ) + step_def = json.loads(pipeline.definition())["Steps"][0] + assert step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] == { + "Get": "Parameters.OutputPath" + } + + @pytest.mark.parametrize( "estimator", [ @@ -131,7 +165,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar instance_type=INSTANCE_TYPE, instance_count=1, role=sagemaker.get_execution_role(), - entry_point="entry_point.py", + entry_point=DUMMY_LOCAL_SCRIPT_PATH, ), PyTorch( role=sagemaker.get_execution_role(), @@ -139,7 +173,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar instance_count=1, framework_version="1.8.0", py_version="py36", - entry_point="entry_point.py", + entry_point=DUMMY_LOCAL_SCRIPT_PATH, ), TensorFlow( role=sagemaker.get_execution_role(), @@ -147,7 +181,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar instance_count=1, framework_version="2.0", py_version="py3", - entry_point="entry_point.py", + entry_point=DUMMY_LOCAL_SCRIPT_PATH, ), HuggingFace( transformers_version="4.6", @@ -156,7 +190,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar instance_type="ml.p3.2xlarge", instance_count=1, py_version="py36", - entry_point="entry_point.py", + entry_point=DUMMY_LOCAL_SCRIPT_PATH, ), XGBoost( framework_version="1.3-1", @@ -164,7 +198,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar role=sagemaker.get_execution_role(), instance_type=INSTANCE_TYPE, instance_count=1, - entry_point="entry_point.py", + entry_point=DUMMY_LOCAL_SCRIPT_PATH, ), MXNet( framework_version="1.4.1", @@ -172,7 +206,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar role=sagemaker.get_execution_role(), instance_type=INSTANCE_TYPE, instance_count=1, - entry_point="entry_point.py", + entry_point=DUMMY_LOCAL_SCRIPT_PATH, ), RLEstimator( entry_point="cartpole.py", @@ -185,7 +219,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar ), Chainer( role=sagemaker.get_execution_role(), - entry_point="entry_point.py", + entry_point=DUMMY_LOCAL_SCRIPT_PATH, use_mpi=True, num_processes=4, framework_version="5.0.0",