30
30
Processor ,
31
31
)
32
32
from sagemaker .transformer import Transformer , _TransformJob
33
+ from sagemaker .tuner import HyperparameterTuner , _TuningJob
33
34
from sagemaker .workflow .entities import (
34
35
DefaultEnumMeta ,
35
36
Entity ,
39
40
PropertyFile ,
40
41
Properties ,
41
42
)
43
+ from sagemaker .workflow .functions import Join
42
44
43
45
44
46
class StepTypeEnum (Enum , metaclass = DefaultEnumMeta ):
@@ -51,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
51
53
TRAINING = "Training"
52
54
TRANSFORM = "Transform"
53
55
CALLBACK = "Callback"
56
+ TUNING = "Tuning"
54
57
55
58
56
59
@attr .s
@@ -92,6 +95,7 @@ def add_depends_on(self, step_names: List[str]):
92
95
"""Add step names to the current step depends on list"""
93
96
if not step_names :
94
97
return
98
+
95
99
if not self .depends_on :
96
100
self .depends_on = []
97
101
self .depends_on .extend (step_names )
@@ -429,3 +433,132 @@ def to_request(self) -> RequestType:
429
433
property_file .expr for property_file in self .property_files
430
434
]
431
435
return request_dict
436
+
437
+
438
class TuningStep(Step):
    """Tuning step for workflow.

    Wraps a `sagemaker.tuner.HyperparameterTuner` so a hyperparameter tuning
    job can run as a pipeline step.
    """

    def __init__(
        self,
        name: str,
        tuner: HyperparameterTuner,
        inputs=None,
        job_arguments: List[str] = None,
        cache_config: CacheConfig = None,
        depends_on: List[str] = None,
    ):
        """Construct a TuningStep, given a `HyperparameterTuner` instance.

        In addition to the tuner instance, the other arguments are those that are supplied to
        the `fit` method of the `sagemaker.tuner.HyperparameterTuner`.

        Args:
            name (str): The name of the tuning step.
            tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance.
            inputs: Information about the training data. Please refer to the
                ``fit()`` method of the associated estimator, as this can take
                any of the following forms:

                * (str) - The S3 location where training data is saved.
                * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) -
                    If using multiple channels for training data, you can specify
                    a dict mapping channel names to strings or
                    :func:`~sagemaker.inputs.TrainingInput` objects.
                * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources
                    that can provide additional information about the training dataset.
                    See :func:`sagemaker.inputs.TrainingInput` for full details.
                * (sagemaker.session.FileSystemInput) - channel configuration for
                    a file system data source that can provide additional information as well as
                    the path to the training dataset.
                * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                    Amazon :class:~`Record` objects serialized and stored in S3.
                    For use with an estimator for an Amazon algorithm.
                * (sagemaker.amazon.amazon_estimator.FileSystemRecordSet) -
                    Amazon SageMaker channel configuration for a file system data source for
                    Amazon algorithms.
                * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                    :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
                    where each instance is a different channel of training data.
                * (list[sagemaker.amazon.amazon_estimator.FileSystemRecordSet]) - A list of
                    :class:~`sagemaker.amazon.amazon_estimator.FileSystemRecordSet` objects,
                    where each instance is a different channel of training data.
            job_arguments (List[str]): A list of strings to be passed into the tuning job.
                Defaults to `None`.
            cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
            depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TuningStep`
                depends on.
        """
        super(TuningStep, self).__init__(name, StepTypeEnum.TUNING, depends_on)
        self.tuner = tuner
        self.inputs = inputs
        self.job_arguments = job_arguments
        # Properties are resolved at pipeline execution time against the
        # describe/list responses of the tuning job this step creates.
        self._properties = Properties(
            path=f"Steps.{name}",
            shape_names=[
                "DescribeHyperParameterTuningJobResponse",
                "ListTrainingJobsForHyperParameterTuningJobResponse",
            ],
        )
        self.cache_config = cache_config

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to call `create_hyper_parameter_tuning_job`.

        NOTE: The CreateHyperParameterTuningJob request is not quite the
        args list that workflow needs.
        The HyperParameterTuningJobName attribute cannot be included.
        """
        # A tuner is configured with either a single estimator or a dict of
        # estimators (multi-algorithm tuning); prepare whichever is present.
        if self.tuner.estimator is not None:
            self.tuner.estimator._prepare_for_training()
        else:
            for _, estimator in self.tuner.estimator_dict.items():
                estimator._prepare_for_training()

        self.tuner._prepare_for_tuning()
        tuner_args = _TuningJob._get_tuner_args(self.tuner, self.inputs)
        request_dict = self.tuner.sagemaker_session._get_tuning_request(**tuner_args)
        # The workflow service generates the job name; it must not be pinned here.
        request_dict.pop("HyperParameterTuningJobName")

        return request_dict

    @property
    def properties(self):
        """A Properties object representing

        `DescribeHyperParameterTuningJobResponse` and
        `ListTrainingJobsForHyperParameterTuningJobResponse` data model.
        """
        return self._properties

    def to_request(self) -> RequestType:
        """Updates the dictionary with cache configuration."""
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)

        return request_dict

    def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = ""):
        """Get the model artifact s3 uri from the top performing training jobs.

        Args:
            top_k (int): the index of the top performing training job
                tuning step stores up to 50 top performing training jobs, hence
                a valid top_k value is from 0 to 49. The best training job
                model is at index 0
            s3_bucket (str): the s3 bucket to store the training job output artifact
            prefix (str): the s3 key prefix to store the training job output artifact
        """
        # "s3:/" joined on "/" with the bucket yields the "s3://bucket" scheme.
        values = ["s3:/", s3_bucket]
        # Guard against None before comparing, then skip an empty prefix so the
        # resulting uri has no empty path segment.
        if prefix is not None and prefix != "":
            values.append(prefix)

        # The training job name is only known at execution time, so the uri is
        # assembled with a pipeline Join function rather than plain string ops.
        return Join(
            on="/",
            values=values
            + [
                self.properties.TrainingJobSummaries[top_k].TrainingJobName,
                "output/model.tar.gz",
            ],
        )
0 commit comments