Skip to content

Commit b39049c

Browse files
committed
Supporting tbac in load_run
1 parent 219ad24 commit b39049c

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

src/sagemaker/experiments/run.py

+5
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,7 @@ def load_run(
771771
sagemaker_session: Optional["Session"] = None,
772772
artifact_bucket: Optional[str] = None,
773773
artifact_prefix: Optional[str] = None,
774+
tags: Optional[List[Dict[str, str]]] = None,
774775
) -> Run:
775776
"""Load an existing run.
776777
@@ -839,6 +840,8 @@ def load_run(
839840
will be used.
840841
artifact_prefix (str): The S3 key prefix used to generate the S3 path
841842
to upload the artifact to (default: "trial-component-artifacts").
843+
tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
844+
e.g. to create an experiment, a run group, etc. (default: None).
842845
843846
Returns:
844847
Run: The loaded Run object.
@@ -860,6 +863,7 @@ def load_run(
860863
sagemaker_session=sagemaker_session or _utils.default_session(),
861864
artifact_bucket=artifact_bucket,
862865
artifact_prefix=artifact_prefix,
866+
tags=tags,
863867
)
864868
elif _RunContext.get_current_run():
865869
run_instance = _RunContext.get_current_run()
@@ -879,6 +883,7 @@ def load_run(
879883
sagemaker_session=sagemaker_session or _utils.default_session(),
880884
artifact_bucket=artifact_bucket,
881885
artifact_prefix=artifact_prefix,
886+
tags=tags,
882887
)
883888
else:
884889
raise RuntimeError(

tests/unit/sagemaker/experiments/test_run.py

+22-19
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
TEST_RUN_DISPLAY_NAME,
5656
TEST_ARTIFACT_BUCKET,
5757
TEST_ARTIFACT_PREFIX,
58+
TEST_TAGS
5859
)
5960

6061

@@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
155156

156157

157158
@pytest.mark.parametrize(
158-
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
159+
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix","expected_tags"),
159160
[
160-
({}, None, _DEFAULT_ARTIFACT_PREFIX),
161+
({}, None, _DEFAULT_ARTIFACT_PREFIX, None),
161162
(
162163
{
163164
"artifact_bucket": TEST_ARTIFACT_BUCKET,
164165
"artifact_prefix": TEST_ARTIFACT_PREFIX,
166+
"tags": TEST_TAGS
165167
},
166168
TEST_ARTIFACT_BUCKET,
167169
TEST_ARTIFACT_PREFIX,
170+
TEST_TAGS
168171
),
169172
],
170173
)
171174
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
172-
@patch(
173-
"sagemaker.experiments.run.Experiment._load_or_create",
174-
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
175-
)
176175
@patch(
177176
"sagemaker.experiments.run._Trial._load_or_create",
178177
MagicMock(side_effect=mock_trial_load_or_create_func),
@@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
189188
kwargs,
190189
expected_artifact_bucket,
191190
expected_artifact_prefix,
191+
expected_tags
192192
):
193193
client = sagemaker_session.sagemaker_client
194194
job_name = "my-train-job"
@@ -220,19 +220,22 @@ def test_run_load_no_run_name_and_in_train_job(
220220
}
221221
]
222222
}
223-
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
224-
assert run_obj._in_load
225-
assert not run_obj._inside_init_context
226-
assert run_obj._inside_load_context
227-
assert run_obj.run_name == TEST_RUN_NAME
228-
assert run_obj._trial_component.trial_component_name == expected_tc_name
229-
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
230-
assert run_obj._trial
231-
assert run_obj.experiment_name == TEST_EXP_NAME
232-
assert run_obj._experiment
233-
assert run_obj.experiment_config == exp_config
234-
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
235-
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
223+
expmock = MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME,tags=expected_tags))
224+
with patch("sagemaker.experiments.run.Experiment._load_or_create", expmock):
225+
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
226+
assert run_obj._in_load
227+
assert not run_obj._inside_init_context
228+
assert run_obj._inside_load_context
229+
assert run_obj.run_name == TEST_RUN_NAME
230+
assert run_obj._trial_component.trial_component_name == expected_tc_name
231+
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
232+
assert run_obj._trial
233+
assert run_obj.experiment_name == TEST_EXP_NAME
234+
assert run_obj._experiment
235+
assert run_obj.experiment_config == exp_config
236+
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
237+
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
238+
assert run_obj._experiment.tags == expected_tags
236239

237240
client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
238241
run_obj._trial.add_trial_component.assert_not_called()

0 commit comments

Comments
 (0)