14
14
from __future__ import absolute_import
15
15
16
16
import abc
17
-
18
17
from enum import Enum
19
18
from typing import Dict , List , Union
20
19
21
20
import attr
22
21
23
22
from sagemaker .estimator import EstimatorBase , _TrainingJob
24
- from sagemaker .inputs import CreateModelInput , TrainingInput , TransformInput , FileSystemInput
23
+ from sagemaker .inputs import (
24
+ CompilationInput ,
25
+ CreateModelInput ,
26
+ FileSystemInput ,
27
+ TrainingInput ,
28
+ TransformInput ,
29
+ )
25
30
from sagemaker .model import Model
26
31
from sagemaker .processing import (
27
32
ProcessingInput ,
31
36
)
32
37
from sagemaker .transformer import Transformer , _TransformJob
33
38
from sagemaker .tuner import HyperparameterTuner , _TuningJob
34
- from sagemaker .workflow .entities import (
35
- DefaultEnumMeta ,
36
- Entity ,
37
- RequestType ,
38
- )
39
- from sagemaker .workflow .properties import (
40
- PropertyFile ,
41
- Properties ,
42
- )
39
+ from sagemaker .workflow .entities import DefaultEnumMeta , Entity , RequestType
43
40
from sagemaker .workflow .functions import Join
41
+ from sagemaker .workflow .properties import Properties , PropertyFile
44
42
from sagemaker .workflow .retry import RetryPolicy
45
43
46
44
@@ -55,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
55
53
TRANSFORM = "Transform"
56
54
CALLBACK = "Callback"
57
55
TUNING = "Tuning"
56
+ COMPILATION = "Compilation"
58
57
LAMBDA = "Lambda"
59
58
EMR = "EMR"
60
59
@@ -682,3 +681,81 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
682
681
"output/model.tar.gz" ,
683
682
],
684
683
)
684
+
685
+
686
+ class CompilationStep (ConfigurableRetryStep ):
687
+ """Compilation step for workflow."""
688
+
689
+ def __init__ (
690
+ self ,
691
+ name : str ,
692
+ estimator : EstimatorBase ,
693
+ model : Model ,
694
+ inputs : CompilationInput = None ,
695
+ job_arguments : List [str ] = None ,
696
+ depends_on : Union [List [str ], List [Step ]] = None ,
697
+ retry_policies : List [RetryPolicy ] = None ,
698
+ display_name : str = None ,
699
+ description : str = None ,
700
+ cache_config : CacheConfig = None ,
701
+ ):
702
+ """Construct a CompilationStep.
703
+
704
+ Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.
705
+
706
+ In addition to the estimator and Model instances, the other arguments are those that are
707
+ supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.
708
+
709
+ Args:
710
+ name (str): The name of the compilation step.
711
+ estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
712
+ model (Model): A `sagemaker.model.Model` instance.
713
+ inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
714
+ Defaults to `None`.
715
+ job_arguments (List[str]): A list of strings to be passed into the processing job.
716
+ Defaults to `None`.
717
+ depends_on (List[str] or List[Step]): A list of step names or step instances
718
+ this `sagemaker.workflow.steps.CompilationStep` depends on
719
+ retry_policies (List[RetryPolicy]): A list of retry policy
720
+ display_name (str): The display name of the compilation step.
721
+ description (str): The description of the compilation step.
722
+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
723
+ """
724
+ super (CompilationStep , self ).__init__ (
725
+ name , StepTypeEnum .COMPILATION , display_name , description , depends_on , retry_policies
726
+ )
727
+ self .estimator = estimator
728
+ self .model = model
729
+ self .inputs = inputs
730
+ self .job_arguments = job_arguments
731
+ self ._properties = Properties (
732
+ path = f"Steps.{ name } " , shape_name = "DescribeCompilationJobResponse"
733
+ )
734
+ self .cache_config = cache_config
735
+
736
+ @property
737
+ def arguments (self ) -> RequestType :
738
+ """The arguments dict that is used to call `create_compilation_job`.
739
+
740
+ NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
741
+ The TrainingJobName and ExperimentConfig attributes cannot be included.
742
+ """
743
+
744
+ compilation_args = self .model ._get_compilation_args (self .estimator , self .inputs )
745
+ request_dict = self .model .sagemaker_session ._get_compilation_request (** compilation_args )
746
+ request_dict .pop ("CompilationJobName" )
747
+
748
+ return request_dict
749
+
750
+ @property
751
+ def properties (self ):
752
+ """A Properties object representing the DescribeTrainingJobResponse data model."""
753
+ return self ._properties
754
+
755
+ def to_request (self ) -> RequestType :
756
+ """Updates the dictionary with cache configuration."""
757
+ request_dict = super ().to_request ()
758
+ if self .cache_config :
759
+ request_dict .update (self .cache_config .config )
760
+
761
+ return request_dict
0 commit comments