diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index fbe6e1f90f..095fc52ef9 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -68,7 +68,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): input_config = _Job._format_inputs_to_input_config(inputs, validate_uri) role = ( estimator.sagemaker_session.expand_role(estimator.role) - if expand_role + if (expand_role and not is_pipeline_variable(estimator.role)) else estimator.role ) output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key) diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index aa3587b3b4..deff0a3cdb 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -23,6 +23,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.job import _Job from sagemaker.model import FrameworkModel +from sagemaker.workflow.parameters import ParameterString BUCKET_NAME = "s3://mybucket/train" S3_OUTPUT_PATH = "s3://bucket/prefix" @@ -218,6 +219,15 @@ def test_load_config_with_code_channel_no_code_uri(framework): assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE +def test_load_config_with_role_as_pipeline_parameter(estimator): + inputs = TrainingInput(BUCKET_NAME) + estimator.role = ParameterString(name="Role") + + config = _Job._load_config(inputs, estimator) + + assert config["role"] == estimator.role + + def test_format_inputs_none(): channels = _Job._format_inputs_to_input_config(inputs=None)