15
15
16
16
import abc
17
17
import warnings
18
+
18
19
from enum import Enum
19
20
from typing import Dict , List , Union
20
21
from urllib .parse import urlparse
21
22
22
23
import attr
23
24
24
25
from sagemaker .estimator import EstimatorBase , _TrainingJob
25
- from sagemaker .inputs import (
26
- CompilationInput ,
27
- CreateModelInput ,
28
- FileSystemInput ,
29
- TrainingInput ,
30
- TransformInput ,
31
- )
26
+ from sagemaker .inputs import CreateModelInput , TrainingInput , TransformInput , FileSystemInput
32
27
from sagemaker .model import Model
33
28
from sagemaker .pipeline import PipelineModel
34
29
from sagemaker .processing import (
39
34
)
40
35
from sagemaker .transformer import Transformer , _TransformJob
41
36
from sagemaker .tuner import HyperparameterTuner , _TuningJob
42
- from sagemaker .workflow .entities import DefaultEnumMeta , Entity , RequestType
37
+ from sagemaker .workflow .entities import (
38
+ DefaultEnumMeta ,
39
+ Entity ,
40
+ RequestType ,
41
+ )
42
+ from sagemaker .workflow .properties import (
43
+ PropertyFile ,
44
+ Properties ,
45
+ )
43
46
from sagemaker .workflow .functions import Join
44
- from sagemaker .workflow .properties import Properties , PropertyFile
45
47
from sagemaker .workflow .retry import RetryPolicy
46
48
47
49
@@ -56,7 +58,6 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
56
58
TRANSFORM = "Transform"
57
59
CALLBACK = "Callback"
58
60
TUNING = "Tuning"
59
- COMPILATION = "Compilation"
60
61
LAMBDA = "Lambda"
61
62
QUALITY_CHECK = "QualityCheck"
62
63
CLARIFY_CHECK = "ClarifyCheck"
@@ -730,81 +731,3 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
730
731
"output/model.tar.gz" ,
731
732
],
732
733
)
733
-
734
-
735
- class CompilationStep (ConfigurableRetryStep ):
736
- """Compilation step for workflow."""
737
-
738
- def __init__ (
739
- self ,
740
- name : str ,
741
- estimator : EstimatorBase ,
742
- model : Model ,
743
- inputs : CompilationInput = None ,
744
- job_arguments : List [str ] = None ,
745
- depends_on : Union [List [str ], List [Step ]] = None ,
746
- retry_policies : List [RetryPolicy ] = None ,
747
- display_name : str = None ,
748
- description : str = None ,
749
- cache_config : CacheConfig = None ,
750
- ):
751
- """Construct a CompilationStep.
752
-
753
- Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.
754
-
755
- In addition to the estimator and Model instances, the other arguments are those that are
756
- supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.
757
-
758
- Args:
759
- name (str): The name of the compilation step.
760
- estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
761
- model (Model): A `sagemaker.model.Model` instance.
762
- inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
763
- Defaults to `None`.
764
- job_arguments (List[str]): A list of strings to be passed into the processing job.
765
- Defaults to `None`.
766
- depends_on (List[str] or List[Step]): A list of step names or step instances
767
- this `sagemaker.workflow.steps.CompilationStep` depends on
768
- retry_policies (List[RetryPolicy]): A list of retry policy
769
- display_name (str): The display name of the compilation step.
770
- description (str): The description of the compilation step.
771
- cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
772
- """
773
- super (CompilationStep , self ).__init__ (
774
- name , StepTypeEnum .COMPILATION , display_name , description , depends_on , retry_policies
775
- )
776
- self .estimator = estimator
777
- self .model = model
778
- self .inputs = inputs
779
- self .job_arguments = job_arguments
780
- self ._properties = Properties (
781
- path = f"Steps.{ name } " , shape_name = "DescribeCompilationJobResponse"
782
- )
783
- self .cache_config = cache_config
784
-
785
- @property
786
- def arguments (self ) -> RequestType :
787
- """The arguments dict that is used to call `create_compilation_job`.
788
-
789
- NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
790
- The TrainingJobName and ExperimentConfig attributes cannot be included.
791
- """
792
-
793
- compilation_args = self .model ._get_compilation_args (self .estimator , self .inputs )
794
- request_dict = self .model .sagemaker_session ._get_compilation_request (** compilation_args )
795
- request_dict .pop ("CompilationJobName" )
796
-
797
- return request_dict
798
-
799
- @property
800
- def properties (self ):
801
- """A Properties object representing the DescribeTrainingJobResponse data model."""
802
- return self ._properties
803
-
804
- def to_request (self ) -> RequestType :
805
- """Updates the dictionary with cache configuration."""
806
- request_dict = super ().to_request ()
807
- if self .cache_config :
808
- request_dict .update (self .cache_config .config )
809
-
810
- return request_dict
0 commit comments