Skip to content

Commit ab1ac37

Browse files
committed
Formatted files
1 parent b3276e9 commit ab1ac37

File tree

2 files changed

+39
-121
lines changed

2 files changed

+39
-121
lines changed

src/sagemaker/experiments/run.py

+9-27
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,7 @@ def log_precision_recall(
342342
if positive_label is not None:
343343
kwargs["pos_label"] = positive_label
344344

345-
precision, recall, _ = precision_recall_curve(
346-
y_true, predicted_probabilities, **kwargs
347-
)
345+
precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs)
348346

349347
kwargs["average"] = "micro"
350348
ap = average_precision_score(y_true, predicted_probabilities, **kwargs)
@@ -564,9 +562,7 @@ def _is_input_valid(input_type, field_name, field_value) -> bool:
564562
field_name (str): The name of the field to be checked.
565563
field_value (str or int or float): The value of the field to be checked.
566564
"""
567-
if isinstance(field_value, Number) and (
568-
isnan(field_value) or isinf(field_value)
569-
):
565+
if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)):
570566
logger.warning(
571567
"Failed to log %s %s. Received invalid value: %s.",
572568
input_type,
@@ -628,14 +624,10 @@ def _verify_trial_component_artifacts_length(self, is_output):
628624
err_msg_template = "Cannot add more than {} {}_artifacts under run"
629625
if is_output:
630626
if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
631-
raise ValueError(
632-
err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output")
633-
)
627+
raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output"))
634628
else:
635629
if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
636-
raise ValueError(
637-
err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input")
638-
)
630+
raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input"))
639631

640632
@staticmethod
641633
def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
@@ -656,28 +648,20 @@ def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
656648
"""
657649
buffer = 1 # leave length buffers for delimiters
658650
max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - buffer
659-
err_msg_template = (
660-
"The {} (length: {}) must have length less than or equal to {}"
661-
)
651+
err_msg_template = "The {} (length: {}) must have length less than or equal to {}"
662652
if len(run_name) > max_len:
663-
raise ValueError(
664-
err_msg_template.format("run_name", len(run_name), max_len)
665-
)
653+
raise ValueError(err_msg_template.format("run_name", len(run_name), max_len))
666654
if len(experiment_name) > max_len:
667655
raise ValueError(
668-
err_msg_template.format(
669-
"experiment_name", len(experiment_name), max_len
670-
)
656+
err_msg_template.format("experiment_name", len(experiment_name), max_len)
671657
)
672658
trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name)
673659
# due to mixed-case concerns on the backend
674660
trial_component_name = trial_component_name.lower()
675661
return trial_component_name
676662

677663
@staticmethod
678-
def _extract_run_name_from_tc_name(
679-
trial_component_name: str, experiment_name: str
680-
) -> str:
664+
def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str:
681665
"""Extract the user supplied run name from a trial component name.
682666
683667
Args:
@@ -694,9 +678,7 @@ def _extract_run_name_from_tc_name(
694678
)
695679

696680
@staticmethod
697-
def _append_run_tc_label_to_tags(
698-
tags: Optional[List[Dict[str, str]]] = None
699-
) -> list:
681+
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
700682
"""Append the run trial component label to tags used to create a trial component.
701683
702684
Args:

tests/unit/sagemaker/experiments/test_run.py

+30-94
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,8 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
150150
sagemaker_session=sagemaker_session,
151151
)
152152

153-
assert (
154-
f"The run_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than"
155-
in str(err)
153+
assert f"The run_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than" in str(
154+
err
156155
)
157156

158157

@@ -224,9 +223,7 @@ def test_run_load_no_run_name_and_in_train_job(
224223
}
225224
]
226225
}
227-
expmock = MagicMock(
228-
return_value=Experiment(experiment_name=TEST_EXP_NAME, tags=expected_tags)
229-
)
226+
expmock = MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME, tags=expected_tags))
230227
with patch("sagemaker.experiments.run.Experiment._load_or_create", expmock):
231228
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
232229
assert run_obj._in_load
@@ -239,12 +236,8 @@ def test_run_load_no_run_name_and_in_train_job(
239236
assert run_obj.experiment_name == TEST_EXP_NAME
240237
assert run_obj._experiment
241238
assert run_obj.experiment_config == exp_config
242-
assert (
243-
run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
244-
)
245-
assert (
246-
run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
247-
)
239+
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
240+
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
248241
assert run_obj._experiment.tags == expected_tags
249242

250243
client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
@@ -269,9 +262,7 @@ def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg(
269262
with load_run(sagemaker_session=sagemaker_session):
270263
pass
271264

272-
assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str(
273-
err
274-
)
265+
assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str(err)
275266

276267

277268
def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
@@ -291,9 +282,7 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(
291282

292283
# experiment_name is given but is not supplied along with the run_name so it's ignored.
293284
with pytest.raises(RuntimeError) as err:
294-
with load_run(
295-
experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session
296-
):
285+
with load_run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
297286
pass
298287

299288
assert "Failed to load a Run object" in str(err)
@@ -621,17 +610,12 @@ def test_log_output_artifact(run_obj):
621610
with run_obj:
622611
run_obj.log_file("foo.txt", "name", "whizz/bang")
623612
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
624-
assert (
625-
"whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type
626-
)
613+
assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type
627614

628615
run_obj.log_file("foo.txt")
629616
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
630617
assert "foo.txt" in run_obj._trial_component.output_artifacts
631-
assert (
632-
"text/plain"
633-
== run_obj._trial_component.output_artifacts["foo.txt"].media_type
634-
)
618+
assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type
635619

636620

637621
def test_log_input_artifact_outside_run_context(run_obj):
@@ -648,51 +632,36 @@ def test_log_input_artifact(run_obj):
648632
with run_obj:
649633
run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
650634
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
651-
assert (
652-
"whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type
653-
)
635+
assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type
654636

655637
run_obj.log_file("foo.txt", is_output=False)
656638
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
657639
assert "foo.txt" in run_obj._trial_component.input_artifacts
658-
assert (
659-
"text/plain"
660-
== run_obj._trial_component.input_artifacts["foo.txt"].media_type
661-
)
640+
assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type
662641

663642

664643
def test_log_multiple_inputs(run_obj):
665644
with run_obj:
666645
for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN):
667646
file_path = "foo" + str(index) + ".txt"
668647
run_obj._trial_component.input_artifacts[file_path] = {
669-
"foo": TrialComponentArtifact(
670-
value="baz" + str(index), media_type="text/text"
671-
)
648+
"foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text")
672649
}
673650
with pytest.raises(ValueError) as error:
674651
run_obj.log_artifact("foo.txt", "name", "whizz/bang", False)
675-
assert (
676-
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts"
677-
in str(error)
678-
)
652+
assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error)
679653

680654

681655
def test_log_multiple_outputs(run_obj):
682656
with run_obj:
683657
for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN):
684658
file_path = "foo" + str(index) + ".txt"
685659
run_obj._trial_component.output_artifacts[file_path] = {
686-
"foo": TrialComponentArtifact(
687-
value="baz" + str(index), media_type="text/text"
688-
)
660+
"foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text")
689661
}
690662
with pytest.raises(ValueError) as error:
691663
run_obj.log_artifact("foo.txt", "name", "whizz/bang")
692-
assert (
693-
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts"
694-
in str(error)
695-
)
664+
assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error)
696665

697666

698667
def test_log_multiple_input_artifacts(run_obj):
@@ -722,10 +691,7 @@ def test_log_multiple_input_artifacts(run_obj):
722691
# log an extra input artifact, should raise exception
723692
with pytest.raises(ValueError) as error:
724693
run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
725-
assert (
726-
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts"
727-
in str(error)
728-
)
694+
assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error)
729695

730696

731697
def test_log_multiple_output_artifacts(run_obj):
@@ -750,10 +716,7 @@ def test_log_multiple_output_artifacts(run_obj):
750716
# log an extra output artifact, should raise exception
751717
with pytest.raises(ValueError) as error:
752718
run_obj.log_file("foo.txt", "name", "whizz/bang")
753-
assert (
754-
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts"
755-
in str(error)
756-
)
719+
assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error)
757720

758721

759722
def test_log_precision_recall_outside_run_context(run_obj):
@@ -820,10 +783,7 @@ def test_log_precision_recall_invalid_input(run_obj):
820783
no_skill=no_skill,
821784
is_output=False,
822785
)
823-
assert (
824-
"Lengths mismatch between true labels and predicted probabilities"
825-
in str(error)
826-
)
786+
assert "Lengths mismatch between true labels and predicted probabilities" in str(error)
827787

828788

829789
def test_log_confusion_matrix_outside_run_context(run_obj):
@@ -921,9 +881,7 @@ def test_log_roc_curve_invalid_input(run_obj):
921881

922882
with run_obj:
923883
with pytest.raises(ValueError) as error:
924-
run_obj.log_roc_curve(
925-
y_true, y_scores, title="TestROCCurve", is_output=False
926-
)
884+
run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False)
927885
assert "Lengths mismatch between true labels and predicted scores" in str(error)
928886

929887

@@ -940,18 +898,10 @@ def test_log_roc_curve_invalid_input(run_obj):
940898
@patch("sagemaker.experiments.run._TrialComponent.list")
941899
@patch("sagemaker.experiments.run._TrialComponent.search")
942900
def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_session):
943-
start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
944-
hours=1
945-
)
946-
end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
947-
hours=2
948-
)
949-
creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
950-
hours=3
951-
)
952-
last_modified_time = datetime.datetime.now(
953-
datetime.timezone.utc
954-
) + datetime.timedelta(hours=4)
901+
start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
902+
end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2)
903+
creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3)
904+
last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4)
955905
tc_list_len = 20
956906
tc_list_len_half = int(tc_list_len / 2)
957907
mock_tc_search.side_effect = [
@@ -1039,9 +989,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
1039989
assert run._experiment
1040990
assert run._trial
1041991
assert isinstance(run._trial_component, _TrialComponent)
1042-
assert (
1043-
run._trial_component.trial_component_name
1044-
== Run._generate_trial_component_name("a" + str(i), TEST_EXP_NAME)
992+
assert run._trial_component.trial_component_name == Run._generate_trial_component_name(
993+
"a" + str(i), TEST_EXP_NAME
1045994
)
1046995
assert run._in_load is False
1047996
assert run._inside_load_context is False
@@ -1054,9 +1003,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
10541003
@patch("sagemaker.experiments.run._TrialComponent.list")
10551004
def test_list_empty(mock_tc_list, sagemaker_session):
10561005
mock_tc_list.return_value = []
1057-
assert [] == list_runs(
1058-
experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session
1059-
)
1006+
assert [] == list_runs(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session)
10601007

10611008

10621009
@patch(
@@ -1122,10 +1069,7 @@ def test_exit_fail(sagemaker_session, run_obj):
11221069
except ValueError:
11231070
pass
11241071

1125-
assert (
1126-
run_obj._trial_component.status.primary_status
1127-
== _TrialComponentStatusType.Failed.value
1128-
)
1072+
assert run_obj._trial_component.status.primary_status == _TrialComponentStatusType.Failed.value
11291073
assert run_obj._trial_component.status.message
11301074
assert isinstance(run_obj._trial_component.end_time, datetime.datetime)
11311075

@@ -1182,9 +1126,7 @@ def _verify_tc_status_before_enter_init(trial_component):
11821126
assert not trial_component.status
11831127

11841128

1185-
def _verify_tc_status_when_entering(
1186-
trial_component, init_start_time=None, has_completed=False
1187-
):
1129+
def _verify_tc_status_when_entering(trial_component, init_start_time=None, has_completed=False):
11881130
if not init_start_time:
11891131
assert isinstance(trial_component.start_time, datetime.datetime)
11901132
now = datetime.datetime.now(dateutil.tz.tzlocal())
@@ -1194,17 +1136,11 @@ def _verify_tc_status_when_entering(
11941136

11951137
if not has_completed:
11961138
assert not trial_component.end_time
1197-
assert (
1198-
trial_component.status.primary_status
1199-
== _TrialComponentStatusType.InProgress.value
1200-
)
1139+
assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value
12011140

12021141

12031142
def _verify_tc_status_when_successfully_exit(trial_component, old_end_time=None):
1204-
assert (
1205-
trial_component.status.primary_status
1206-
== _TrialComponentStatusType.Completed.value
1207-
)
1143+
assert trial_component.status.primary_status == _TrialComponentStatusType.Completed.value
12081144
assert isinstance(trial_component.start_time, datetime.datetime)
12091145
assert isinstance(trial_component.end_time, datetime.datetime)
12101146
if old_end_time:

0 commit comments

Comments
 (0)