14
14
from __future__ import absolute_import
15
15
16
16
import abc
17
-
18
17
from enum import Enum
19
18
from typing import Dict , List , Union
20
19
21
20
import attr
22
21
23
22
from sagemaker .estimator import EstimatorBase , _TrainingJob
24
- from sagemaker .inputs import CreateModelInput , TrainingInput , TransformInput , FileSystemInput
23
+ from sagemaker .inputs import (
24
+ CompilationInput ,
25
+ CreateModelInput ,
26
+ FileSystemInput ,
27
+ TrainingInput ,
28
+ TransformInput ,
29
+ )
25
30
from sagemaker .model import Model
26
31
from sagemaker .processing import (
27
32
ProcessingInput ,
31
36
)
32
37
from sagemaker .transformer import Transformer , _TransformJob
33
38
from sagemaker .tuner import HyperparameterTuner , _TuningJob
34
- from sagemaker .workflow .entities import (
35
- DefaultEnumMeta ,
36
- Entity ,
37
- RequestType ,
38
- )
39
- from sagemaker .workflow .properties import (
40
- PropertyFile ,
41
- Properties ,
42
- )
39
+ from sagemaker .workflow .entities import DefaultEnumMeta , Entity , RequestType
43
40
from sagemaker .workflow .functions import Join
41
+ from sagemaker .workflow .properties import Properties , PropertyFile
44
42
from sagemaker .workflow .retry import RetryPolicy
45
43
46
44
@@ -55,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
55
53
TRANSFORM = "Transform"
56
54
CALLBACK = "Callback"
57
55
TUNING = "Tuning"
56
+ COMPILATION = "Compilation"
58
57
LAMBDA = "Lambda"
59
58
60
59
@@ -681,3 +680,81 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
681
680
"output/model.tar.gz" ,
682
681
],
683
682
)
683
+
684
+
685
+ class CompilationStep (ConfigurableRetryStep ):
686
+ """Compilation step for workflow."""
687
+
688
+ def __init__ (
689
+ self ,
690
+ name : str ,
691
+ estimator : EstimatorBase ,
692
+ model : Model ,
693
+ inputs : CompilationInput = None ,
694
+ job_arguments : List [str ] = None ,
695
+ depends_on : Union [List [str ], List [Step ]] = None ,
696
+ retry_policies : List [RetryPolicy ] = None ,
697
+ display_name : str = None ,
698
+ description : str = None ,
699
+ cache_config : CacheConfig = None ,
700
+ ):
701
+ """Construct a CompilationStep.
702
+
703
+ Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.
704
+
705
+ In addition to the estimator and Model instances, the other arguments are those that are
706
+ supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.
707
+
708
+ Args:
709
+ name (str): The name of the compilation step.
710
+ estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
711
+ model (Model): A `sagemaker.model.Model` instance.
712
+ inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
713
+ Defaults to `None`.
714
+ job_arguments (List[str]): A list of strings to be passed into the processing job.
715
+ Defaults to `None`.
716
+ depends_on (List[str] or List[Step]): A list of step names or step instances
717
+ this `sagemaker.workflow.steps.CompilationStep` depends on
718
+ retry_policies (List[RetryPolicy]): A list of retry policy
719
+ display_name (str): The display name of the compilation step.
720
+ description (str): The description of the compilation step.
721
+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
722
+ """
723
+ super (CompilationStep , self ).__init__ (
724
+ name , StepTypeEnum .COMPILATION , display_name , description , depends_on , retry_policies
725
+ )
726
+ self .estimator = estimator
727
+ self .model = model
728
+ self .inputs = inputs
729
+ self .job_arguments = job_arguments
730
+ self ._properties = Properties (
731
+ path = f"Steps.{ name } " , shape_name = "DescribeCompilationJobResponse"
732
+ )
733
+ self .cache_config = cache_config
734
+
735
+ @property
736
+ def arguments (self ) -> RequestType :
737
+ """The arguments dict that is used to call `create_compilation_job`.
738
+
739
+ NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
740
+ The TrainingJobName and ExperimentConfig attributes cannot be included.
741
+ """
742
+
743
+ compilation_args = self .model ._get_compilation_args (self .estimator , self .inputs )
744
+ request_dict = self .model .sagemaker_session ._get_compilation_request (** compilation_args )
745
+ request_dict .pop ("CompilationJobName" )
746
+
747
+ return request_dict
748
+
749
+ @property
750
+ def properties (self ):
751
+ """A Properties object representing the DescribeTrainingJobResponse data model."""
752
+ return self ._properties
753
+
754
+ def to_request (self ) -> RequestType :
755
+ """Updates the dictionary with cache configuration."""
756
+ request_dict = super ().to_request ()
757
+ if self .cache_config :
758
+ request_dict .update (self .cache_config .config )
759
+
760
+ return request_dict
0 commit comments