@@ -158,12 +158,18 @@ def init(
158
158
sagemaker_session = sagemaker_session ,
159
159
)
160
160
161
- run_tc = _TrialComponent ._load_or_create (
161
+ run_tc , is_existed = _TrialComponent ._load_or_create (
162
162
trial_component_name = trial_component_name ,
163
163
display_name = run_display_name ,
164
164
tags = Run ._append_run_tc_label_to_tags (tags ),
165
165
sagemaker_session = sagemaker_session ,
166
166
)
167
+ if is_existed :
168
+ logger .warning (
169
+ "The Run (%s) under experiment (%s) already exists. Loading it." ,
170
+ run_name ,
171
+ experiment_name ,
172
+ )
167
173
168
174
trial .add_trial_component (run_tc )
169
175
@@ -184,12 +190,12 @@ def load(
184
190
experiment_name : Optional [str ] = None ,
185
191
sagemaker_session : Optional ["Session" ] = None ,
186
192
):
187
- """Load a Run Trial Component by the run name or from the job environment.
193
+ """Load a Run by the run name or from the job environment.
188
194
189
195
Args:
190
196
run_name (str): The name of the Run to be loaded (default: None).
191
197
If it is None, the `RunName` in the `ExperimentConfig` of the job will be
192
- fetched to load the Run Trial Component .
198
+ fetched to load the Run.
193
199
experiment_name (str): The name of the Experiment that the to be loaded Run
194
200
is associated with (default: None).
195
201
Note: the experiment_name must be supplied along with a valid run_name.
@@ -253,7 +259,7 @@ def _experiment_config(self):
253
259
254
260
@validate_invoked_inside_run_context
255
261
def log_parameter (self , name , value ):
256
- """Record a single parameter value for this run trial component .
262
+ """Record a single parameter value for this run.
257
263
258
264
Overwrites any previous value recorded for the specified parameter name.
259
265
@@ -266,7 +272,7 @@ def log_parameter(self, name, value):
266
272
267
273
@validate_invoked_inside_run_context
268
274
def log_parameters (self , parameters ):
269
- """Record a collection of parameter values for this run trial component .
275
+ """Record a collection of parameter values for this run.
270
276
271
277
Args:
272
278
parameters (dict[str, str or numbers.Number]): The parameters to record.
@@ -280,7 +286,7 @@ def log_parameters(self, parameters):
280
286
281
287
@validate_invoked_inside_run_context
282
288
def log_metric (self , name , value , timestamp = None , step = None ):
283
- """Record a custom scalar metric value for this run trial component .
289
+ """Record a custom scalar metric value for this run.
284
290
285
291
Note:
286
292
1. This method is for manual custom metrics, for automatic metrics see the
@@ -313,9 +319,9 @@ def log_precision_recall(
313
319
"""Create and log a precision recall graph artifact for Studio UI to render.
314
320
315
321
The artifact is stored in S3 and represented as a lineage artifact
316
- with an association with the run trial component .
322
+ with an association with the run.
317
323
318
- You can view the artifact in the charts tab of the Trial Component UI.
324
+ You can view the artifact in the UI.
319
325
If your job is created by a pipeline execution you can view the artifact
320
326
by selecting the corresponding step in the pipelines UI.
321
327
See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_
@@ -329,7 +335,7 @@ def log_precision_recall(
329
335
positive_label (str or int): Label of the positive class (default: None).
330
336
title (str): Title of the graph (default: None).
331
337
is_output (bool): Determines direction of association to the
332
- trial component . Defaults to True (output artifact).
338
+ run . Defaults to True (output artifact).
333
339
If set to False then represented as input association.
334
340
no_skill (int): The precision threshold under which the classifier cannot discriminate
335
341
between the classes and would predict a random class or a constant class in
@@ -378,9 +384,9 @@ def log_roc_curve(
378
384
"""Create and log a receiver operating characteristic (ROC curve) artifact.
379
385
380
386
The artifact is stored in S3 and represented as a lineage artifact
381
- with an association with the run trial component .
387
+ with an association with the run.
382
388
383
- You can view the artifact in the charts tab of the Trial Component UI.
389
+ You can view the artifact in the UI.
384
390
If your job is created by a pipeline execution you can view the artifact
385
391
by selecting the corresponding step in the pipelines UI.
386
392
See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_
@@ -393,7 +399,7 @@ def log_roc_curve(
393
399
y_score (list or array): Estimated/predicted probabilities.
394
400
title (str): Title of the graph (default: None).
395
401
is_output (bool): Determines direction of association to the
396
- trial component . Defaults to True (output artifact).
402
+ run . Defaults to True (output artifact).
397
403
If set to False then represented as input association.
398
404
"""
399
405
verify_length_of_true_and_predicted (
@@ -430,9 +436,9 @@ def log_confusion_matrix(
430
436
"""Create and log a confusion matrix artifact.
431
437
432
438
The artifact is stored in S3 and represented as a lineage artifact
433
- with an association with the run trial component .
439
+ with an association with the run.
434
440
435
- You can view the artifact in the charts tab of the Trial Component UI.
441
+ You can view the artifact in the UI.
436
442
If your job is created by a pipeline execution you can view the
437
443
artifact by selecting the corresponding step in the pipelines UI.
438
444
See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_
@@ -444,7 +450,7 @@ def log_confusion_matrix(
444
450
y_pred (list or array): Predicted labels.
445
451
title (str): Title of the graph (default: None).
446
452
is_output (bool): Determines direction of association to the
447
- trial component . Defaults to True (output artifact).
453
+ run . Defaults to True (output artifact).
448
454
If set to False then represented as input association.
449
455
"""
450
456
verify_length_of_true_and_predicted (
@@ -468,7 +474,7 @@ def log_confusion_matrix(
468
474
469
475
@validate_invoked_inside_run_context
470
476
def log_output (self , name , value , media_type = None ):
471
- """Record a single output artifact for this run trial component .
477
+ """Record a single output artifact for this run.
472
478
473
479
Overwrites any previous value recorded for the specified output name.
474
480
@@ -484,7 +490,7 @@ def log_output(self, name, value, media_type=None):
484
490
485
491
@validate_invoked_inside_run_context
486
492
def log_input (self , name , value , media_type = None ):
487
- """Record a single input artifact for this run trial component .
493
+ """Record a single input artifact for this run.
488
494
489
495
Overwrites any previous value recorded for the specified input name.
490
496
@@ -500,7 +506,7 @@ def log_input(self, name, value, media_type=None):
500
506
501
507
@validate_invoked_inside_run_context
502
508
def log_artifact_file (self , file_path , name = None , media_type = None , is_output = True ):
503
- """Upload a file to s3 and store it as an input/output artifact in this trial component .
509
+ """Upload a file to s3 and store it as an input/output artifact in this run .
504
510
505
511
Args:
506
512
file_path (str): The path of the local file to upload.
@@ -509,7 +515,7 @@ def log_artifact_file(self, file_path, name=None, media_type=None, is_output=Tru
509
515
If not specified, this library will attempt to infer the media type
510
516
from the file extension of `file_path`.
511
517
is_output (bool): Determines direction of association to the
512
- trial component . Defaults to True (output artifact).
518
+ run . Defaults to True (output artifact).
513
519
If set to False then represented as input association.
514
520
"""
515
521
self ._verify_trial_component_artifacts_length (is_output )
@@ -527,7 +533,7 @@ def log_artifact_file(self, file_path, name=None, media_type=None, is_output=Tru
527
533
528
534
@validate_invoked_inside_run_context
529
535
def log_artifact_directory (self , directory , media_type = None , is_output = True ):
530
- """Upload files under directory to s3 and log as artifacts in this trial component .
536
+ """Upload files under directory to s3 and log as artifacts in this run .
531
537
532
538
The file name is used as the artifact name
533
539
@@ -537,7 +543,7 @@ def log_artifact_directory(self, directory, media_type=None, is_output=True):
537
543
If not specified, this library will attempt to infer the media type
538
544
from the file extension of `file_path`.
539
545
is_output (bool): Determines direction of association to the
540
- trial component . Defaults to True (output artifact).
546
+ run . Defaults to True (output artifact).
541
547
If set to False then represented as input association.
542
548
"""
543
549
for dir_file in os .listdir (directory ):
@@ -549,7 +555,7 @@ def log_artifact_directory(self, directory, media_type=None, is_output=True):
549
555
550
556
@validate_invoked_inside_run_context
551
557
def log_lineage_artifact (self , file_path , name = None , media_type = None , is_output = True ):
552
- """Upload a file to S3 and creates a lineage Artifact associated with this trial component .
558
+ """Upload a file to S3 and creates a lineage Artifact associated with this run .
553
559
554
560
Args:
555
561
file_path (str): The path of the local file to upload.
@@ -558,7 +564,7 @@ def log_lineage_artifact(self, file_path, name=None, media_type=None, is_output=
558
564
If not specified, this library will attempt to infer the media type
559
565
from the file extension of `file_path`.
560
566
is_output (bool): Determines direction of association to the
561
- trial component . Defaults to True (output artifact).
567
+ run . Defaults to True (output artifact).
562
568
If set to False then represented as input association.
563
569
"""
564
570
media_type = media_type or guess_media_type (file_path )
@@ -588,11 +594,11 @@ def list(
588
594
"""Return a list of `Run` objects matching the given criteria.
589
595
590
596
Args:
591
- experiment_name (str): Only trial components related to the specified experiment
597
+ experiment_name (str): Only Run objects related to the specified experiment
592
598
are returned.
593
- created_before (datetime.datetime): Return trial components created before this instant
599
+ created_before (datetime.datetime): Return Run objects created before this instant
594
600
(default: None).
595
- created_after (datetime.datetime): Return trial components created after this instant
601
+ created_after (datetime.datetime): Return Run objects created after this instant
596
602
(default: None).
597
603
sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime'
598
604
(default: 'CreationTime').
@@ -601,7 +607,7 @@ def list(
601
607
manages interactions with Amazon SageMaker APIs and any other
602
608
AWS services needed. If not specified, one is created using the
603
609
default AWS configuration chain.
604
- max_results (int): maximum number of trial components to retrieve (default: None).
610
+ max_results (int): maximum number of Run objects to retrieve (default: None).
605
611
next_token (str): token for next page of results (default: None).
606
612
607
613
Returns:
@@ -715,7 +721,7 @@ def _verify_trial_component_artifacts_length(self, is_output):
715
721
Raises:
716
722
ValueError: If the length of trial component artifacts exceeds the limit.
717
723
"""
718
- err_msg_template = "Cannot add more than {} {}_artifacts under run trial_component "
724
+ err_msg_template = "Cannot add more than {} {}_artifacts under run"
719
725
if is_output :
720
726
if len (self ._trial_component .output_artifacts ) >= MAX_RUN_TC_ARTIFACTS_LEN :
721
727
raise ValueError (err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "output" ))
@@ -777,11 +783,19 @@ def _get_tc_and_exp_config_from_job_env(
777
783
callable_func = lambda : sagemaker_session .describe_training_job (job_name ),
778
784
num_attempts = 4 ,
779
785
)
780
- else : # environment.environment_type == _EnvironmentType.SageMakerProcessingJob
786
+ elif environment .environment_type == _EnvironmentType .SageMakerProcessingJob :
781
787
job_response = retry_with_backoff (
782
788
callable_func = lambda : sagemaker_session .describe_processing_job (job_name ),
783
789
num_attempts = 4 ,
784
790
)
791
+ else : # environment.environment_type == _EnvironmentType.SageMakerTransformJob
792
+ raise RuntimeError (
793
+ "Failed to load the Run as loading experiment config "
794
+ "from transform job environment is not currently supported. "
795
+ "As a workaround, please explicitly pass in "
796
+ "the experiment_name and run_name in Run.load."
797
+ )
798
+
785
799
job_exp_config = job_response .get ("ExperimentConfig" , dict ())
786
800
if job_exp_config .get (RUN_NAME , None ):
787
801
# The run with RunName has been created outside of the job env.
@@ -867,7 +881,7 @@ def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) ->
867
881
return tags
868
882
869
883
def __enter__ (self ):
870
- """Updates the start time of the tracked trial component .
884
+ """Updates the start time of the run .
871
885
872
886
Returns:
873
887
object: self.
@@ -897,7 +911,7 @@ def __enter__(self):
897
911
return self
898
912
899
913
def __exit__ (self , exc_type , exc_value , exc_traceback ):
900
- """Updates the end time of the tracked trial component .
914
+ """Updates the end time of the run .
901
915
902
916
Args:
903
917
exc_type (str): The exception type.
0 commit comments