Skip to content

Commit eb5e97e

Browse files
ananth102root
authored and
root
committed
feat: Supporting tbac in load_run (aws#4039)
1 parent dc22353 commit eb5e97e

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

src/sagemaker/experiments/run.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def __init__(
210210
)
211211

212212
if not _TrialComponent._trial_component_is_associated_to_trial(
213-
self._trial_component.trial_component_name, self._trial.trial_name, sagemaker_session
213+
self._trial_component.trial_component_name,
214+
self._trial.trial_name,
215+
sagemaker_session,
214216
):
215217
self._trial.add_trial_component(self._trial_component)
216218

@@ -781,6 +783,7 @@ def load_run(
781783
sagemaker_session: Optional["Session"] = None,
782784
artifact_bucket: Optional[str] = None,
783785
artifact_prefix: Optional[str] = None,
786+
tags: Optional[List[Dict[str, str]]] = None,
784787
) -> Run:
785788
"""Load an existing run.
786789
@@ -849,6 +852,8 @@ def load_run(
849852
will be used.
850853
artifact_prefix (str): The S3 key prefix used to generate the S3 path
851854
to upload the artifact to (default: "trial-component-artifacts").
855+
tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
856+
e.g. to create an experiment, a run group, etc. (default: None).
852857
853858
Returns:
854859
Run: The loaded Run object.
@@ -870,6 +875,7 @@ def load_run(
870875
sagemaker_session=sagemaker_session or _utils.default_session(),
871876
artifact_bucket=artifact_bucket,
872877
artifact_prefix=artifact_prefix,
878+
tags=tags,
873879
)
874880
elif _RunContext.get_current_run():
875881
run_instance = _RunContext.get_current_run()
@@ -889,6 +895,7 @@ def load_run(
889895
sagemaker_session=sagemaker_session or _utils.default_session(),
890896
artifact_bucket=artifact_bucket,
891897
artifact_prefix=artifact_prefix,
898+
tags=tags,
892899
)
893900
else:
894901
raise RuntimeError(

tests/unit/sagemaker/experiments/test_run.py

+59-29
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
TEST_RUN_DISPLAY_NAME,
5656
TEST_ARTIFACT_BUCKET,
5757
TEST_ARTIFACT_PREFIX,
58+
TEST_TAGS,
5859
)
5960

6061

@@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
155156

156157

157158
@pytest.mark.parametrize(
158-
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
159+
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix", "expected_tags"),
159160
[
160-
({}, None, _DEFAULT_ARTIFACT_PREFIX),
161+
({}, None, _DEFAULT_ARTIFACT_PREFIX, None),
161162
(
162163
{
163164
"artifact_bucket": TEST_ARTIFACT_BUCKET,
164165
"artifact_prefix": TEST_ARTIFACT_PREFIX,
166+
"tags": TEST_TAGS,
165167
},
166168
TEST_ARTIFACT_BUCKET,
167169
TEST_ARTIFACT_PREFIX,
170+
TEST_TAGS,
168171
),
169172
],
170173
)
171174
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
172-
@patch(
173-
"sagemaker.experiments.run.Experiment._load_or_create",
174-
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
175-
)
176175
@patch(
177176
"sagemaker.experiments.run._Trial._load_or_create",
178177
MagicMock(side_effect=mock_trial_load_or_create_func),
@@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
189188
kwargs,
190189
expected_artifact_bucket,
191190
expected_artifact_prefix,
191+
expected_tags,
192192
):
193193
client = sagemaker_session.sagemaker_client
194194
job_name = "my-train-job"
@@ -213,26 +213,32 @@ def test_run_load_no_run_name_and_in_train_job(
213213
{
214214
"TrialComponent": {
215215
"Parents": [
216-
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
216+
{
217+
"ExperimentName": TEST_EXP_NAME,
218+
"TrialName": exp_config[TRIAL_NAME],
219+
}
217220
],
218221
"TrialComponentName": expected_tc_name,
219222
}
220223
}
221224
]
222225
}
223-
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
224-
assert run_obj._in_load
225-
assert not run_obj._inside_init_context
226-
assert run_obj._inside_load_context
227-
assert run_obj.run_name == TEST_RUN_NAME
228-
assert run_obj._trial_component.trial_component_name == expected_tc_name
229-
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
230-
assert run_obj._trial
231-
assert run_obj.experiment_name == TEST_EXP_NAME
232-
assert run_obj._experiment
233-
assert run_obj.experiment_config == exp_config
234-
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
235-
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
226+
expmock = MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME, tags=expected_tags))
227+
with patch("sagemaker.experiments.run.Experiment._load_or_create", expmock):
228+
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
229+
assert run_obj._in_load
230+
assert not run_obj._inside_init_context
231+
assert run_obj._inside_load_context
232+
assert run_obj.run_name == TEST_RUN_NAME
233+
assert run_obj._trial_component.trial_component_name == expected_tc_name
234+
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
235+
assert run_obj._trial
236+
assert run_obj.experiment_name == TEST_EXP_NAME
237+
assert run_obj._experiment
238+
assert run_obj.experiment_config == exp_config
239+
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
240+
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
241+
assert run_obj._experiment.tags == expected_tags
236242

237243
client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
238244
run_obj._trial.add_trial_component.assert_not_called()
@@ -265,7 +271,9 @@ def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
265271
assert run_obj == run
266272

267273

268-
def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session):
274+
def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(
275+
sagemaker_session,
276+
):
269277
with pytest.raises(RuntimeError) as err:
270278
with load_run(sagemaker_session=sagemaker_session):
271279
pass
@@ -388,7 +396,10 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
388396
{
389397
"TrialComponent": {
390398
"Parents": [
391-
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
399+
{
400+
"ExperimentName": TEST_EXP_NAME,
401+
"TrialName": exp_config[TRIAL_NAME],
402+
}
392403
],
393404
"TrialComponentName": expected_tc_name,
394405
}
@@ -442,7 +453,10 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
442453
{
443454
"TrialComponent": {
444455
"Parents": [
445-
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
456+
{
457+
"ExperimentName": TEST_EXP_NAME,
458+
"TrialName": exp_config[TRIAL_NAME],
459+
}
446460
],
447461
"TrialComponentName": expected_tc_name,
448462
}
@@ -589,7 +603,10 @@ def test_log_output_artifact_outside_run_context(run_obj):
589603

590604

591605
def test_log_output_artifact(run_obj):
592-
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
606+
run_obj._artifact_uploader.upload_artifact.return_value = (
607+
"s3uri_value",
608+
"etag_value",
609+
)
593610
with run_obj:
594611
run_obj.log_file("foo.txt", "name", "whizz/bang")
595612
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
@@ -608,7 +625,10 @@ def test_log_input_artifact_outside_run_context(run_obj):
608625

609626

610627
def test_log_input_artifact(run_obj):
611-
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
628+
run_obj._artifact_uploader.upload_artifact.return_value = (
629+
"s3uri_value",
630+
"etag_value",
631+
)
612632
with run_obj:
613633
run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
614634
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
@@ -653,7 +673,10 @@ def test_log_multiple_input_artifacts(run_obj):
653673
"etag_value" + str(index),
654674
)
655675
run_obj.log_file(
656-
file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False
676+
file_path,
677+
"name" + str(index),
678+
"whizz/bang" + str(index),
679+
is_output=False,
657680
)
658681
run_obj._artifact_uploader.upload_artifact.assert_called_with(
659682
file_path, extra_args=None
@@ -757,7 +780,12 @@ def test_log_precision_recall_invalid_input(run_obj):
757780
with run_obj:
758781
with pytest.raises(ValueError) as error:
759782
run_obj.log_precision_recall(
760-
y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False
783+
y_true,
784+
y_scores,
785+
0,
786+
title="TestPrecisionRecall",
787+
no_skill=no_skill,
788+
is_output=False,
761789
)
762790
assert "Lengths mismatch between true labels and predicted probabilities" in str(error)
763791

@@ -905,7 +933,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
905933
display_name="C" + str(i),
906934
source_arn="D" + str(i),
907935
status=TrialComponentStatus(
908-
primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
936+
primary_status=_TrialComponentStatusType.InProgress.value,
937+
message="E" + str(i),
909938
),
910939
start_time=start_time + datetime.timedelta(hours=i),
911940
end_time=end_time + datetime.timedelta(hours=i),
@@ -925,7 +954,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
925954
display_name="C" + str(i),
926955
source_arn="D" + str(i),
927956
status=TrialComponentStatus(
928-
primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
957+
primary_status=_TrialComponentStatusType.InProgress.value,
958+
message="E" + str(i),
929959
),
930960
start_time=start_time + datetime.timedelta(hours=i),
931961
end_time=end_time + datetime.timedelta(hours=i),

0 commit comments

Comments
 (0)