55
55
TEST_RUN_DISPLAY_NAME ,
56
56
TEST_ARTIFACT_BUCKET ,
57
57
TEST_ARTIFACT_PREFIX ,
58
+ TEST_TAGS ,
58
59
)
59
60
60
61
@@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
155
156
156
157
157
158
@pytest .mark .parametrize (
158
- ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
159
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" , "expected_tags" ),
159
160
[
160
- ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
161
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX , None ),
161
162
(
162
163
{
163
164
"artifact_bucket" : TEST_ARTIFACT_BUCKET ,
164
165
"artifact_prefix" : TEST_ARTIFACT_PREFIX ,
166
+ "tags" : TEST_TAGS ,
165
167
},
166
168
TEST_ARTIFACT_BUCKET ,
167
169
TEST_ARTIFACT_PREFIX ,
170
+ TEST_TAGS ,
168
171
),
169
172
],
170
173
)
171
174
@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
172
- @patch (
173
- "sagemaker.experiments.run.Experiment._load_or_create" ,
174
- MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME )),
175
- )
176
175
@patch (
177
176
"sagemaker.experiments.run._Trial._load_or_create" ,
178
177
MagicMock (side_effect = mock_trial_load_or_create_func ),
@@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
189
188
kwargs ,
190
189
expected_artifact_bucket ,
191
190
expected_artifact_prefix ,
191
+ expected_tags ,
192
192
):
193
193
client = sagemaker_session .sagemaker_client
194
194
job_name = "my-train-job"
@@ -213,26 +213,32 @@ def test_run_load_no_run_name_and_in_train_job(
213
213
{
214
214
"TrialComponent" : {
215
215
"Parents" : [
216
- {"ExperimentName" : TEST_EXP_NAME , "TrialName" : exp_config [TRIAL_NAME ]}
216
+ {
217
+ "ExperimentName" : TEST_EXP_NAME ,
218
+ "TrialName" : exp_config [TRIAL_NAME ],
219
+ }
217
220
],
218
221
"TrialComponentName" : expected_tc_name ,
219
222
}
220
223
}
221
224
]
222
225
}
223
- with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
224
- assert run_obj ._in_load
225
- assert not run_obj ._inside_init_context
226
- assert run_obj ._inside_load_context
227
- assert run_obj .run_name == TEST_RUN_NAME
228
- assert run_obj ._trial_component .trial_component_name == expected_tc_name
229
- assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
230
- assert run_obj ._trial
231
- assert run_obj .experiment_name == TEST_EXP_NAME
232
- assert run_obj ._experiment
233
- assert run_obj .experiment_config == exp_config
234
- assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
235
- assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
226
+ expmock = MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME , tags = expected_tags ))
227
+ with patch ("sagemaker.experiments.run.Experiment._load_or_create" , expmock ):
228
+ with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
229
+ assert run_obj ._in_load
230
+ assert not run_obj ._inside_init_context
231
+ assert run_obj ._inside_load_context
232
+ assert run_obj .run_name == TEST_RUN_NAME
233
+ assert run_obj ._trial_component .trial_component_name == expected_tc_name
234
+ assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
235
+ assert run_obj ._trial
236
+ assert run_obj .experiment_name == TEST_EXP_NAME
237
+ assert run_obj ._experiment
238
+ assert run_obj .experiment_config == exp_config
239
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
240
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
241
+ assert run_obj ._experiment .tags == expected_tags
236
242
237
243
client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
238
244
run_obj ._trial .add_trial_component .assert_not_called ()
@@ -265,7 +271,9 @@ def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
265
271
assert run_obj == run
266
272
267
273
268
- def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context (sagemaker_session ):
274
+ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context (
275
+ sagemaker_session ,
276
+ ):
269
277
with pytest .raises (RuntimeError ) as err :
270
278
with load_run (sagemaker_session = sagemaker_session ):
271
279
pass
@@ -388,7 +396,10 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
388
396
{
389
397
"TrialComponent" : {
390
398
"Parents" : [
391
- {"ExperimentName" : TEST_EXP_NAME , "TrialName" : exp_config [TRIAL_NAME ]}
399
+ {
400
+ "ExperimentName" : TEST_EXP_NAME ,
401
+ "TrialName" : exp_config [TRIAL_NAME ],
402
+ }
392
403
],
393
404
"TrialComponentName" : expected_tc_name ,
394
405
}
@@ -442,7 +453,10 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
442
453
{
443
454
"TrialComponent" : {
444
455
"Parents" : [
445
- {"ExperimentName" : TEST_EXP_NAME , "TrialName" : exp_config [TRIAL_NAME ]}
456
+ {
457
+ "ExperimentName" : TEST_EXP_NAME ,
458
+ "TrialName" : exp_config [TRIAL_NAME ],
459
+ }
446
460
],
447
461
"TrialComponentName" : expected_tc_name ,
448
462
}
@@ -589,7 +603,10 @@ def test_log_output_artifact_outside_run_context(run_obj):
589
603
590
604
591
605
def test_log_output_artifact (run_obj ):
592
- run_obj ._artifact_uploader .upload_artifact .return_value = ("s3uri_value" , "etag_value" )
606
+ run_obj ._artifact_uploader .upload_artifact .return_value = (
607
+ "s3uri_value" ,
608
+ "etag_value" ,
609
+ )
593
610
with run_obj :
594
611
run_obj .log_file ("foo.txt" , "name" , "whizz/bang" )
595
612
run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" , extra_args = None )
@@ -608,7 +625,10 @@ def test_log_input_artifact_outside_run_context(run_obj):
608
625
609
626
610
627
def test_log_input_artifact (run_obj ):
611
- run_obj ._artifact_uploader .upload_artifact .return_value = ("s3uri_value" , "etag_value" )
628
+ run_obj ._artifact_uploader .upload_artifact .return_value = (
629
+ "s3uri_value" ,
630
+ "etag_value" ,
631
+ )
612
632
with run_obj :
613
633
run_obj .log_file ("foo.txt" , "name" , "whizz/bang" , is_output = False )
614
634
run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" , extra_args = None )
@@ -653,7 +673,10 @@ def test_log_multiple_input_artifacts(run_obj):
653
673
"etag_value" + str (index ),
654
674
)
655
675
run_obj .log_file (
656
- file_path , "name" + str (index ), "whizz/bang" + str (index ), is_output = False
676
+ file_path ,
677
+ "name" + str (index ),
678
+ "whizz/bang" + str (index ),
679
+ is_output = False ,
657
680
)
658
681
run_obj ._artifact_uploader .upload_artifact .assert_called_with (
659
682
file_path , extra_args = None
@@ -757,7 +780,12 @@ def test_log_precision_recall_invalid_input(run_obj):
757
780
with run_obj :
758
781
with pytest .raises (ValueError ) as error :
759
782
run_obj .log_precision_recall (
760
- y_true , y_scores , 0 , title = "TestPrecisionRecall" , no_skill = no_skill , is_output = False
783
+ y_true ,
784
+ y_scores ,
785
+ 0 ,
786
+ title = "TestPrecisionRecall" ,
787
+ no_skill = no_skill ,
788
+ is_output = False ,
761
789
)
762
790
assert "Lengths mismatch between true labels and predicted probabilities" in str (error )
763
791
@@ -905,7 +933,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
905
933
display_name = "C" + str (i ),
906
934
source_arn = "D" + str (i ),
907
935
status = TrialComponentStatus (
908
- primary_status = _TrialComponentStatusType .InProgress .value , message = "E" + str (i )
936
+ primary_status = _TrialComponentStatusType .InProgress .value ,
937
+ message = "E" + str (i ),
909
938
),
910
939
start_time = start_time + datetime .timedelta (hours = i ),
911
940
end_time = end_time + datetime .timedelta (hours = i ),
@@ -925,7 +954,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
925
954
display_name = "C" + str (i ),
926
955
source_arn = "D" + str (i ),
927
956
status = TrialComponentStatus (
928
- primary_status = _TrialComponentStatusType .InProgress .value , message = "E" + str (i )
957
+ primary_status = _TrialComponentStatusType .InProgress .value ,
958
+ message = "E" + str (i ),
929
959
),
930
960
start_time = start_time + datetime .timedelta (hours = i ),
931
961
end_time = end_time + datetime .timedelta (hours = i ),
0 commit comments