diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index afa512f0f7..149accd7d3 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -92,6 +92,38 @@ def ref(self) -> Dict[str, str]: return {"Name": self.name} +@attr.s +class CacheConfig: + """Configuration class to enable caching in pipeline workflow. + + If caching is enabled, the pipeline attempts to find a previous execution of a step + that was called with the same arguments. Step caching only considers successful execution. + If a successful previous execution is found, the pipeline propagates the values + from previous execution rather than recomputing the step. When multiple successful executions + exist within the timeout period, it uses the result for the most recent successful execution. + + + Attributes: + enable_caching (bool): To enable step caching. Defaults to `False`. + expire_after (str): If step caching is enabled, a timeout also needs to defined. + It defines how old a previous execution can be to be considered for reuse. + Value should be an ISO 8601 duration string. Defaults to `None`. + """ + + enable_caching: bool = attr.ib(default=False) + expire_after = attr.ib( + default=None, validator=attr.validators.optional(attr.validators.instance_of(str)) + ) + + @property + def config(self): + """Configures caching in pipeline steps.""" + config = {"Enabled": self.enable_caching} + if self.expire_after is not None: + config["ExpireAfter"] = self.expire_after + return {"CacheConfig": config} + + class TrainingStep(Step): """Training step for workflow.""" @@ -100,6 +132,7 @@ def __init__( name: str, estimator: EstimatorBase, inputs: TrainingInput = None, + cache_config: CacheConfig = None, ): """Construct a TrainingStep, given an `EstimatorBase` instance. @@ -110,14 +143,15 @@ def __init__( name (str): The name of the training step. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. """ super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING) self.estimator = estimator self.inputs = inputs - self._properties = Properties( path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse" ) + self.cache_config = cache_config @property def arguments(self) -> RequestType: @@ -144,6 +178,14 @@ def properties(self): """A Properties object representing the DescribeTrainingJobResponse data model.""" return self._properties + def to_request(self) -> RequestType: + """Updates the dictionary with cache configuration.""" + request_dict = super().to_request() + if self.cache_config: + request_dict.update(self.cache_config.config) + + return request_dict + class CreateModelStep(Step): """CreateModel step for workflow.""" @@ -207,6 +249,7 @@ def __init__( name: str, transformer: Transformer, inputs: TransformInput, + cache_config: CacheConfig = None, ): """Constructs a TransformStep, given an `Transformer` instance. @@ -217,11 +260,12 @@ def __init__( name (str): The name of the transform step. transformer (Transformer): A `sagemaker.transformer.Transformer` instance. inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. """ super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM) self.transformer = transformer self.inputs = inputs - + self.cache_config = cache_config self._properties = Properties( path=f"Steps.{name}", shape_name="DescribeTransformJobResponse" ) @@ -257,6 +301,14 @@ def properties(self): """A Properties object representing the DescribeTransformJobResponse data model.""" return self._properties + def to_request(self) -> RequestType: + """Updates the dictionary with cache configuration.""" + request_dict = super().to_request() + if self.cache_config: + request_dict.update(self.cache_config.config) + + return request_dict + class ProcessingStep(Step): """Processing step for workflow.""" @@ -270,6 +322,7 @@ def __init__( job_arguments: List[str] = None, code: str = None, property_files: List[PropertyFile] = None, + cache_config: CacheConfig = None, ): """Construct a ProcessingStep, given a `Processor` instance. @@ -289,6 +342,7 @@ def __init__( script to run. Defaults to `None`. property_files (List[PropertyFile]): A list of property files that workflow looks for and resolves from the configured processing output list. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. """ super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING) self.processor = processor @@ -305,6 +359,7 @@ def __init__( self._properties = Properties( path=f"Steps.{name}", shape_name="DescribeProcessingJobResponse" ) + self.cache_config = cache_config @property def arguments(self) -> RequestType: @@ -335,6 +390,8 @@ def properties(self): def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" request_dict = super(ProcessingStep, self).to_request() + if self.cache_config: + request_dict.update(self.cache_config.config) if self.property_files: request_dict["PropertyFiles"] = [ property_file.expr for property_file in self.property_files diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index b3044185eb..9f328bda25 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -46,6 +46,7 @@ CreateModelStep, ProcessingStep, TrainingStep, + CacheConfig, ) from sagemaker.workflow.step_collections import RegisterModel from sagemaker.workflow.pipeline import Pipeline @@ -274,6 +275,8 @@ def test_one_step_sklearn_processing_pipeline( ProcessingInput(dataset_definition=athena_dataset_definition), ] + cache_config = CacheConfig(enable_caching=True, expire_after="T30m") + sklearn_processor = SKLearnProcessor( framework_version=sklearn_latest_version, role=role, @@ -289,6 +292,7 @@ def test_one_step_sklearn_processing_pipeline( processor=sklearn_processor, inputs=inputs, code=script_path, + cache_config=cache_config, ) pipeline = Pipeline( name=pipeline_name, @@ -328,6 +332,11 @@ def test_one_step_sklearn_processing_pipeline( response = execution.describe() assert response["PipelineArn"] == create_arn + # Check CacheConfig + response = json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0]["CacheConfig"] + assert response["Enabled"] == cache_config.enable_caching + assert response["ExpireAfter"] == cache_config.expire_after + try: execution.wait(delay=30, max_attempts=3) except WaiterError: diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index a28fbf4878..6bb2586a7c 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -37,6 +37,7 @@ TrainingStep, TransformStep, CreateModelStep, + CacheConfig, ) REGION = "us-west-2" @@ -114,10 +115,9 @@ def test_training_step(sagemaker_session): sagemaker_session=sagemaker_session, ) inputs = TrainingInput(f"s3://{BUCKET}/train_manifest") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") step = TrainingStep( - name="MyTrainingStep", - estimator=estimator, - inputs=inputs, + name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config ) assert step.to_request() == { "Name": "MyTrainingStep", @@ -145,6 +145,7 @@ def test_training_step(sagemaker_session): "RoleArn": ROLE, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, }, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} @@ -163,11 +164,13 @@ def test_processing_step(sagemaker_session): destination="processing_manifest", ) ] + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") step = ProcessingStep( name="MyProcessingStep", processor=processor, inputs=inputs, outputs=[], + cache_config=cache_config, ) assert step.to_request() == { "Name": "MyProcessingStep", @@ -197,6 +200,7 @@ def test_processing_step(sagemaker_session): }, "RoleArn": "DummyRole", }, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" @@ -238,10 +242,9 @@ def test_transform_step(sagemaker_session): sagemaker_session=sagemaker_session, ) inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") step = TransformStep( - name="MyTransformStep", - transformer=transformer, - inputs=inputs, + name="MyTransformStep", transformer=transformer, inputs=inputs, cache_config=cache_config ) assert step.to_request() == { "Name": "MyTransformStep", @@ -262,6 +265,7 @@ def test_transform_step(sagemaker_session): "InstanceType": "c4.4xlarge", }, }, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } assert step.properties.TransformJobName.expr == { "Get": "Steps.MyTransformStep.TransformJobName"