Skip to content

Commit 1a39422

Browse files
ivannotesYi Li
and
Yi Li
authored
feature: do not expand estimator role when it is pipeline parameter (#3416)
Currently when dump the pipeline definition if estimator role is assigned with a pipeline parameter, the SDK will raise error as pipeline parameter expected as a valid input. This change is to support pass pipeline parameter to the role and don't expand it when load the job config. Co-authored-by: Yi Li <[email protected]>
1 parent 7669263 commit 1a39422

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/sagemaker/job.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
6868
input_config = _Job._format_inputs_to_input_config(inputs, validate_uri)
6969
role = (
7070
estimator.sagemaker_session.expand_role(estimator.role)
71-
if expand_role
71+
if (expand_role and not is_pipeline_variable(estimator.role))
7272
else estimator.role
7373
)
7474
output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key)

tests/unit/test_job.py

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.instance_group import InstanceGroup
2424
from sagemaker.job import _Job
2525
from sagemaker.model import FrameworkModel
26+
from sagemaker.workflow.parameters import ParameterString
2627

2728
BUCKET_NAME = "s3://mybucket/train"
2829
S3_OUTPUT_PATH = "s3://bucket/prefix"
@@ -218,6 +219,15 @@ def test_load_config_with_code_channel_no_code_uri(framework):
218219
assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE
219220

220221

222+
def test_load_config_with_role_as_pipeline_parameter(estimator):
223+
inputs = TrainingInput(BUCKET_NAME)
224+
estimator.role = ParameterString(name="Role")
225+
226+
config = _Job._load_config(inputs, estimator)
227+
228+
assert config["role"] == estimator.role
229+
230+
221231
def test_format_inputs_none():
222232
channels = _Job._format_inputs_to_input_config(inputs=None)
223233

0 commit comments

Comments
 (0)