Skip to content

Commit 952890f

Browse files
committed
formatted files
1 parent 19aba72 commit 952890f

File tree

2 files changed

+166
-55
lines changed

2 files changed

+166
-55
lines changed

src/sagemaker/experiments/run.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ def __init__(
205205
)
206206

207207
if not _TrialComponent._trial_component_is_associated_to_trial(
208-
self._trial_component.trial_component_name, self._trial.trial_name, sagemaker_session
208+
self._trial_component.trial_component_name,
209+
self._trial.trial_name,
210+
sagemaker_session,
209211
):
210212
self._trial.add_trial_component(self._trial_component)
211213

@@ -340,7 +342,9 @@ def log_precision_recall(
340342
if positive_label is not None:
341343
kwargs["pos_label"] = positive_label
342344

343-
precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs)
345+
precision, recall, _ = precision_recall_curve(
346+
y_true, predicted_probabilities, **kwargs
347+
)
344348

345349
kwargs["average"] = "micro"
346350
ap = average_precision_score(y_true, predicted_probabilities, **kwargs)
@@ -560,7 +564,9 @@ def _is_input_valid(input_type, field_name, field_value) -> bool:
560564
field_name (str): The name of the field to be checked.
561565
field_value (str or int or float): The value of the field to be checked.
562566
"""
563-
if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)):
567+
if isinstance(field_value, Number) and (
568+
isnan(field_value) or isinf(field_value)
569+
):
564570
logger.warning(
565571
"Failed to log %s %s. Received invalid value: %s.",
566572
input_type,
@@ -622,10 +628,14 @@ def _verify_trial_component_artifacts_length(self, is_output):
622628
err_msg_template = "Cannot add more than {} {}_artifacts under run"
623629
if is_output:
624630
if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
625-
raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output"))
631+
raise ValueError(
632+
err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output")
633+
)
626634
else:
627635
if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
628-
raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input"))
636+
raise ValueError(
637+
err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input")
638+
)
629639

630640
@staticmethod
631641
def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
@@ -646,20 +656,28 @@ def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
646656
"""
647657
buffer = 1 # leave length buffers for delimiters
648658
max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - buffer
649-
err_msg_template = "The {} (length: {}) must have length less than or equal to {}"
659+
err_msg_template = (
660+
"The {} (length: {}) must have length less than or equal to {}"
661+
)
650662
if len(run_name) > max_len:
651-
raise ValueError(err_msg_template.format("run_name", len(run_name), max_len))
663+
raise ValueError(
664+
err_msg_template.format("run_name", len(run_name), max_len)
665+
)
652666
if len(experiment_name) > max_len:
653667
raise ValueError(
654-
err_msg_template.format("experiment_name", len(experiment_name), max_len)
668+
err_msg_template.format(
669+
"experiment_name", len(experiment_name), max_len
670+
)
655671
)
656672
trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name)
657673
# due to mixed-case concerns on the backend
658674
trial_component_name = trial_component_name.lower()
659675
return trial_component_name
660676

661677
@staticmethod
662-
def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str:
678+
def _extract_run_name_from_tc_name(
679+
trial_component_name: str, experiment_name: str
680+
) -> str:
663681
"""Extract the user supplied run name from a trial component name.
664682
665683
Args:
@@ -676,7 +694,9 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
676694
)
677695

678696
@staticmethod
679-
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
697+
def _append_run_tc_label_to_tags(
698+
tags: Optional[List[Dict[str, str]]] = None
699+
) -> list:
680700
"""Append the run trial component label to tags used to create a trial component.
681701
682702
Args:

0 commit comments

Comments
 (0)