@@ -205,7 +205,9 @@ def __init__(
205
205
)
206
206
207
207
if not _TrialComponent ._trial_component_is_associated_to_trial (
208
- self ._trial_component .trial_component_name , self ._trial .trial_name , sagemaker_session
208
+ self ._trial_component .trial_component_name ,
209
+ self ._trial .trial_name ,
210
+ sagemaker_session ,
209
211
):
210
212
self ._trial .add_trial_component (self ._trial_component )
211
213
@@ -340,7 +342,9 @@ def log_precision_recall(
340
342
if positive_label is not None :
341
343
kwargs ["pos_label" ] = positive_label
342
344
343
- precision , recall , _ = precision_recall_curve (y_true , predicted_probabilities , ** kwargs )
345
+ precision , recall , _ = precision_recall_curve (
346
+ y_true , predicted_probabilities , ** kwargs
347
+ )
344
348
345
349
kwargs ["average" ] = "micro"
346
350
ap = average_precision_score (y_true , predicted_probabilities , ** kwargs )
@@ -560,7 +564,9 @@ def _is_input_valid(input_type, field_name, field_value) -> bool:
560
564
field_name (str): The name of the field to be checked.
561
565
field_value (str or int or float): The value of the field to be checked.
562
566
"""
563
- if isinstance (field_value , Number ) and (isnan (field_value ) or isinf (field_value )):
567
+ if isinstance (field_value , Number ) and (
568
+ isnan (field_value ) or isinf (field_value )
569
+ ):
564
570
logger .warning (
565
571
"Failed to log %s %s. Received invalid value: %s." ,
566
572
input_type ,
@@ -622,10 +628,14 @@ def _verify_trial_component_artifacts_length(self, is_output):
622
628
err_msg_template = "Cannot add more than {} {}_artifacts under run"
623
629
if is_output :
624
630
if len (self ._trial_component .output_artifacts ) >= MAX_RUN_TC_ARTIFACTS_LEN :
625
- raise ValueError (err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "output" ))
631
+ raise ValueError (
632
+ err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "output" )
633
+ )
626
634
else :
627
635
if len (self ._trial_component .input_artifacts ) >= MAX_RUN_TC_ARTIFACTS_LEN :
628
- raise ValueError (err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "input" ))
636
+ raise ValueError (
637
+ err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "input" )
638
+ )
629
639
630
640
@staticmethod
631
641
def _generate_trial_component_name (run_name : str , experiment_name : str ) -> str :
@@ -646,20 +656,28 @@ def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
646
656
"""
647
657
buffer = 1 # leave length buffers for delimiters
648
658
max_len = int (MAX_NAME_LEN_IN_BACKEND / 2 ) - buffer
649
- err_msg_template = "The {} (length: {}) must have length less than or equal to {}"
659
+ err_msg_template = (
660
+ "The {} (length: {}) must have length less than or equal to {}"
661
+ )
650
662
if len (run_name ) > max_len :
651
- raise ValueError (err_msg_template .format ("run_name" , len (run_name ), max_len ))
663
+ raise ValueError (
664
+ err_msg_template .format ("run_name" , len (run_name ), max_len )
665
+ )
652
666
if len (experiment_name ) > max_len :
653
667
raise ValueError (
654
- err_msg_template .format ("experiment_name" , len (experiment_name ), max_len )
668
+ err_msg_template .format (
669
+ "experiment_name" , len (experiment_name ), max_len
670
+ )
655
671
)
656
672
trial_component_name = "{}{}{}" .format (experiment_name , DELIMITER , run_name )
657
673
# due to mixed-case concerns on the backend
658
674
trial_component_name = trial_component_name .lower ()
659
675
return trial_component_name
660
676
661
677
@staticmethod
662
- def _extract_run_name_from_tc_name (trial_component_name : str , experiment_name : str ) -> str :
678
+ def _extract_run_name_from_tc_name (
679
+ trial_component_name : str , experiment_name : str
680
+ ) -> str :
663
681
"""Extract the user supplied run name from a trial component name.
664
682
665
683
Args:
@@ -676,7 +694,9 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
676
694
)
677
695
678
696
@staticmethod
679
- def _append_run_tc_label_to_tags (tags : Optional [List [Dict [str , str ]]] = None ) -> list :
697
+ def _append_run_tc_label_to_tags (
698
+ tags : Optional [List [Dict [str , str ]]] = None
699
+ ) -> list :
680
700
"""Append the run trial component label to tags used to create a trial component.
681
701
682
702
Args:
0 commit comments