@@ -150,9 +150,8 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
150
150
sagemaker_session = sagemaker_session ,
151
151
)
152
152
153
- assert (
154
- f"The run_name (length: { MAX_NAME_LEN_IN_BACKEND } ) must have length less than"
155
- in str (err )
153
+ assert f"The run_name (length: { MAX_NAME_LEN_IN_BACKEND } ) must have length less than" in str (
154
+ err
156
155
)
157
156
158
157
@@ -224,9 +223,7 @@ def test_run_load_no_run_name_and_in_train_job(
224
223
}
225
224
]
226
225
}
227
- expmock = MagicMock (
228
- return_value = Experiment (experiment_name = TEST_EXP_NAME , tags = expected_tags )
229
- )
226
+ expmock = MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME , tags = expected_tags ))
230
227
with patch ("sagemaker.experiments.run.Experiment._load_or_create" , expmock ):
231
228
with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
232
229
assert run_obj ._in_load
@@ -239,12 +236,8 @@ def test_run_load_no_run_name_and_in_train_job(
239
236
assert run_obj .experiment_name == TEST_EXP_NAME
240
237
assert run_obj ._experiment
241
238
assert run_obj .experiment_config == exp_config
242
- assert (
243
- run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
244
- )
245
- assert (
246
- run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
247
- )
239
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
240
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
248
241
assert run_obj ._experiment .tags == expected_tags
249
242
250
243
client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
@@ -269,9 +262,7 @@ def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg(
269
262
with load_run (sagemaker_session = sagemaker_session ):
270
263
pass
271
264
272
- assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str (
273
- err
274
- )
265
+ assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str (err )
275
266
276
267
277
268
def test_run_load_no_run_name_and_not_in_train_job (run_obj , sagemaker_session ):
@@ -291,9 +282,7 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(
291
282
292
283
# experiment_name is given but is not supplied along with the run_name so it's ignored.
293
284
with pytest .raises (RuntimeError ) as err :
294
- with load_run (
295
- experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session
296
- ):
285
+ with load_run (experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session ):
297
286
pass
298
287
299
288
assert "Failed to load a Run object" in str (err )
@@ -621,17 +610,12 @@ def test_log_output_artifact(run_obj):
621
610
with run_obj :
622
611
run_obj .log_file ("foo.txt" , "name" , "whizz/bang" )
623
612
run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
624
- assert (
625
- "whizz/bang" == run_obj ._trial_component .output_artifacts ["name" ].media_type
626
- )
613
+ assert "whizz/bang" == run_obj ._trial_component .output_artifacts ["name" ].media_type
627
614
628
615
run_obj .log_file ("foo.txt" )
629
616
run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
630
617
assert "foo.txt" in run_obj ._trial_component .output_artifacts
631
- assert (
632
- "text/plain"
633
- == run_obj ._trial_component .output_artifacts ["foo.txt" ].media_type
634
- )
618
+ assert "text/plain" == run_obj ._trial_component .output_artifacts ["foo.txt" ].media_type
635
619
636
620
637
621
def test_log_input_artifact_outside_run_context (run_obj ):
@@ -648,51 +632,36 @@ def test_log_input_artifact(run_obj):
648
632
with run_obj :
649
633
run_obj .log_file ("foo.txt" , "name" , "whizz/bang" , is_output = False )
650
634
run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
651
- assert (
652
- "whizz/bang" == run_obj ._trial_component .input_artifacts ["name" ].media_type
653
- )
635
+ assert "whizz/bang" == run_obj ._trial_component .input_artifacts ["name" ].media_type
654
636
655
637
run_obj .log_file ("foo.txt" , is_output = False )
656
638
run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
657
639
assert "foo.txt" in run_obj ._trial_component .input_artifacts
658
- assert (
659
- "text/plain"
660
- == run_obj ._trial_component .input_artifacts ["foo.txt" ].media_type
661
- )
640
+ assert "text/plain" == run_obj ._trial_component .input_artifacts ["foo.txt" ].media_type
662
641
663
642
664
643
def test_log_multiple_inputs (run_obj ):
665
644
with run_obj :
666
645
for index in range (0 , MAX_RUN_TC_ARTIFACTS_LEN ):
667
646
file_path = "foo" + str (index ) + ".txt"
668
647
run_obj ._trial_component .input_artifacts [file_path ] = {
669
- "foo" : TrialComponentArtifact (
670
- value = "baz" + str (index ), media_type = "text/text"
671
- )
648
+ "foo" : TrialComponentArtifact (value = "baz" + str (index ), media_type = "text/text" )
672
649
}
673
650
with pytest .raises (ValueError ) as error :
674
651
run_obj .log_artifact ("foo.txt" , "name" , "whizz/bang" , False )
675
- assert (
676
- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts"
677
- in str (error )
678
- )
652
+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts" in str (error )
679
653
680
654
681
655
def test_log_multiple_outputs (run_obj ):
682
656
with run_obj :
683
657
for index in range (0 , MAX_RUN_TC_ARTIFACTS_LEN ):
684
658
file_path = "foo" + str (index ) + ".txt"
685
659
run_obj ._trial_component .output_artifacts [file_path ] = {
686
- "foo" : TrialComponentArtifact (
687
- value = "baz" + str (index ), media_type = "text/text"
688
- )
660
+ "foo" : TrialComponentArtifact (value = "baz" + str (index ), media_type = "text/text" )
689
661
}
690
662
with pytest .raises (ValueError ) as error :
691
663
run_obj .log_artifact ("foo.txt" , "name" , "whizz/bang" )
692
- assert (
693
- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts"
694
- in str (error )
695
- )
664
+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts" in str (error )
696
665
697
666
698
667
def test_log_multiple_input_artifacts (run_obj ):
@@ -722,10 +691,7 @@ def test_log_multiple_input_artifacts(run_obj):
722
691
# log an extra input artifact, should raise exception
723
692
with pytest .raises (ValueError ) as error :
724
693
run_obj .log_file ("foo.txt" , "name" , "whizz/bang" , is_output = False )
725
- assert (
726
- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts"
727
- in str (error )
728
- )
694
+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts" in str (error )
729
695
730
696
731
697
def test_log_multiple_output_artifacts (run_obj ):
@@ -750,10 +716,7 @@ def test_log_multiple_output_artifacts(run_obj):
750
716
# log an extra output artifact, should raise exception
751
717
with pytest .raises (ValueError ) as error :
752
718
run_obj .log_file ("foo.txt" , "name" , "whizz/bang" )
753
- assert (
754
- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts"
755
- in str (error )
756
- )
719
+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts" in str (error )
757
720
758
721
759
722
def test_log_precision_recall_outside_run_context (run_obj ):
@@ -820,10 +783,7 @@ def test_log_precision_recall_invalid_input(run_obj):
820
783
no_skill = no_skill ,
821
784
is_output = False ,
822
785
)
823
- assert (
824
- "Lengths mismatch between true labels and predicted probabilities"
825
- in str (error )
826
- )
786
+ assert "Lengths mismatch between true labels and predicted probabilities" in str (error )
827
787
828
788
829
789
def test_log_confusion_matrix_outside_run_context (run_obj ):
@@ -921,9 +881,7 @@ def test_log_roc_curve_invalid_input(run_obj):
921
881
922
882
with run_obj :
923
883
with pytest .raises (ValueError ) as error :
924
- run_obj .log_roc_curve (
925
- y_true , y_scores , title = "TestROCCurve" , is_output = False
926
- )
884
+ run_obj .log_roc_curve (y_true , y_scores , title = "TestROCCurve" , is_output = False )
927
885
assert "Lengths mismatch between true labels and predicted scores" in str (error )
928
886
929
887
@@ -940,18 +898,10 @@ def test_log_roc_curve_invalid_input(run_obj):
940
898
@patch ("sagemaker.experiments.run._TrialComponent.list" )
941
899
@patch ("sagemaker.experiments.run._TrialComponent.search" )
942
900
def test_list (mock_tc_search , mock_tc_list , mock_tc_load , run_obj , sagemaker_session ):
943
- start_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (
944
- hours = 1
945
- )
946
- end_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (
947
- hours = 2
948
- )
949
- creation_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (
950
- hours = 3
951
- )
952
- last_modified_time = datetime .datetime .now (
953
- datetime .timezone .utc
954
- ) + datetime .timedelta (hours = 4 )
901
+ start_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 1 )
902
+ end_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 2 )
903
+ creation_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 3 )
904
+ last_modified_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 4 )
955
905
tc_list_len = 20
956
906
tc_list_len_half = int (tc_list_len / 2 )
957
907
mock_tc_search .side_effect = [
@@ -1039,9 +989,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
1039
989
assert run ._experiment
1040
990
assert run ._trial
1041
991
assert isinstance (run ._trial_component , _TrialComponent )
1042
- assert (
1043
- run ._trial_component .trial_component_name
1044
- == Run ._generate_trial_component_name ("a" + str (i ), TEST_EXP_NAME )
992
+ assert run ._trial_component .trial_component_name == Run ._generate_trial_component_name (
993
+ "a" + str (i ), TEST_EXP_NAME
1045
994
)
1046
995
assert run ._in_load is False
1047
996
assert run ._inside_load_context is False
@@ -1054,9 +1003,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
1054
1003
@patch ("sagemaker.experiments.run._TrialComponent.list" )
1055
1004
def test_list_empty (mock_tc_list , sagemaker_session ):
1056
1005
mock_tc_list .return_value = []
1057
- assert [] == list_runs (
1058
- experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session
1059
- )
1006
+ assert [] == list_runs (experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session )
1060
1007
1061
1008
1062
1009
@patch (
@@ -1122,10 +1069,7 @@ def test_exit_fail(sagemaker_session, run_obj):
1122
1069
except ValueError :
1123
1070
pass
1124
1071
1125
- assert (
1126
- run_obj ._trial_component .status .primary_status
1127
- == _TrialComponentStatusType .Failed .value
1128
- )
1072
+ assert run_obj ._trial_component .status .primary_status == _TrialComponentStatusType .Failed .value
1129
1073
assert run_obj ._trial_component .status .message
1130
1074
assert isinstance (run_obj ._trial_component .end_time , datetime .datetime )
1131
1075
@@ -1182,9 +1126,7 @@ def _verify_tc_status_before_enter_init(trial_component):
1182
1126
assert not trial_component .status
1183
1127
1184
1128
1185
- def _verify_tc_status_when_entering (
1186
- trial_component , init_start_time = None , has_completed = False
1187
- ):
1129
+ def _verify_tc_status_when_entering (trial_component , init_start_time = None , has_completed = False ):
1188
1130
if not init_start_time :
1189
1131
assert isinstance (trial_component .start_time , datetime .datetime )
1190
1132
now = datetime .datetime .now (dateutil .tz .tzlocal ())
@@ -1194,17 +1136,11 @@ def _verify_tc_status_when_entering(
1194
1136
1195
1137
if not has_completed :
1196
1138
assert not trial_component .end_time
1197
- assert (
1198
- trial_component .status .primary_status
1199
- == _TrialComponentStatusType .InProgress .value
1200
- )
1139
+ assert trial_component .status .primary_status == _TrialComponentStatusType .InProgress .value
1201
1140
1202
1141
1203
1142
def _verify_tc_status_when_successfully_exit (trial_component , old_end_time = None ):
1204
- assert (
1205
- trial_component .status .primary_status
1206
- == _TrialComponentStatusType .Completed .value
1207
- )
1143
+ assert trial_component .status .primary_status == _TrialComponentStatusType .Completed .value
1208
1144
assert isinstance (trial_component .start_time , datetime .datetime )
1209
1145
assert isinstance (trial_component .end_time , datetime .datetime )
1210
1146
if old_end_time :
0 commit comments