From e1d0d45fa67a31e8b20fe8cef11b60987ae8e17a Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Thu, 21 Jan 2021 15:04:03 -0800 Subject: [PATCH 01/10] add cache config and unit test --- src/sagemaker/workflow/steps.py | 47 ++++++++++++- tests/integ/test_workflow.py | 78 +++++++++++++++++++++ tests/unit/sagemaker/workflow/test_steps.py | 16 +++-- 3 files changed, 133 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 3327bb988f..7545390358 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -64,6 +64,7 @@ class Step(Entity): Attributes: name (str): The name of the step. step_type (StepTypeEnum): The type of the step. + """ name: str = attr.ib(factory=str) @@ -93,6 +94,26 @@ def ref(self) -> Dict[str, str]: return {"Name": self.name} +@attr.s +class CacheConfig: + """Step to cache pipeline workflow. + + Attributes: + enable_caching (bool): To enable step caching. Off by default. + expire_after (str): If step caching is enabled, a timeout also needs to defined. + It defines how old a previous execution can be to be considered for reuse. + Needs to be ISO 8601 duration string. + """ + + enable_caching: bool = attr.ib(default=False) + expire_after: str = attr.ib(factory=str) + + @property + def config(self): + """Enables caching in pipeline steps.""" + return {"CacheConfig": {"Enabled": self.enable_caching, "ExpireAfter": self.expire_after}} + + class TrainingStep(Step): """Training step for workflow.""" @@ -101,6 +122,7 @@ def __init__( name: str, estimator: EstimatorBase, inputs: TrainingInput = None, + cache_config: CacheConfig = None, ): """Construct a TrainingStep, given an `EstimatorBase` instance. @@ -111,14 +133,15 @@ def __init__( name (str): The name of the training step. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`. + cache_config (CacheConfig): An instance to enable caching. """ super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING) self.estimator = estimator self.inputs = inputs - self._properties = Properties( path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse" ) + self.cache_config = cache_config @property def arguments(self) -> RequestType: @@ -145,6 +168,13 @@ def properties(self): """A Properties object representing the DescribeTrainingJobResponse data model.""" return self._properties + def to_request(self) -> RequestType: + """Updates the dictionary with cache configuration.""" + request_dict = super().to_request() + request_dict.update(self.cache_config.config) + + return request_dict + class CreateModelStep(Step): """CreateModel step for workflow.""" @@ -208,6 +238,7 @@ def __init__( name: str, transformer: Transformer, inputs: TransformInput, + cache_config: CacheConfig = None, ): """Constructs a TransformStep, given an `Transformer` instance. @@ -218,11 +249,12 @@ def __init__( name (str): The name of the transform step. transformer (Transformer): A `sagemaker.transformer.Transformer` instance. inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance. + cache_config (CacheConfig): An instance to enable caching. """ super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM) self.transformer = transformer self.inputs = inputs - + self.cache_config = cache_config self._properties = Properties( path=f"Steps.{name}", shape_name="DescribeTransformJobResponse" ) @@ -258,6 +290,13 @@ def properties(self): """A Properties object representing the DescribeTransformJobResponse data model.""" return self._properties + def to_request(self) -> RequestType: + """Updates the dictionary with cache configuration.""" + request_dict = super().to_request() + request_dict.update(self.cache_config.config) + + return request_dict + class ProcessingStep(Step): """Processing step for workflow.""" @@ -271,6 +310,7 @@ def __init__( job_arguments: List[str] = None, code: str = None, property_files: List[PropertyFile] = None, + cache_config: CacheConfig = None, ): """Construct a ProcessingStep, given a `Processor` instance. @@ -290,6 +330,7 @@ def __init__( script to run. Defaults to `None`. property_files (List[PropertyFile]): A list of property files that workflow looks for and resolves from the configured processing output list. + cache_config (CacheConfig): An instance to enable caching. """ super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING) self.processor = processor @@ -306,6 +347,7 @@ def __init__( self._properties = Properties( path=f"Steps.{name}", shape_name="DescribeProcessingJobResponse" ) + self.cache_config = cache_config @property def arguments(self) -> RequestType: @@ -336,6 +378,7 @@ def properties(self): def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" request_dict = super(ProcessingStep, self).to_request() + request_dict.update(self.cache_config.config) if self.property_files: request_dict["PropertyFiles"] = [ property_file.expr for property_file in self.property_files diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index a07a849330..cb5d868f90 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -48,6 +48,7 @@ CreateModelStep, ProcessingStep, TrainingStep, + CacheConfig ) from sagemaker.workflow.step_collections import RegisterModel from sagemaker.workflow.pipeline import Pipeline @@ -551,3 +552,80 @@ def test_training_job_with_debugger( pipeline.delete() except Exception: pass + + +def test_cache_hit_expired_entry( + sagemaker_session, + workflow_session, + region_name, + role, + script_dir, + pipeline_name, +): + + instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + + + estimator = + + step_train = TrainingStep( + name="my-train", + estimator=sklearn_train, + inputs=TrainingInput( + s3_data=step_process.properties.ProcessingOutputConfig.Outputs[ + "train_data" + ].S3Output.S3Uri + ), + cache_config= + ) + pipeline = Pipeline( + name=pipeline_name, + parameters=[instance_type, instance_count], + steps=[step_train], + sagemaker_session=workflow_session, + ) + + try: + # NOTE: We should exercise the case when role used in the pipeline execution is + # different than that required of the steps in the pipeline itself. The role in + # the pipeline definition needs to create training and processing jobs and other + # sagemaker entities. However, the jobs created in the steps themselves execute + # under a potentially different role, often requiring access to S3 and other + # artifacts not required to during creation of the jobs in the pipeline steps. + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + + pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] + response = pipeline.update(role) + update_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}", + update_arn, + ) + + execution = pipeline.start(parameters={}) + assert re.match( + fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}/execution/", + execution.arn, + ) + + response = execution.describe() + assert response["PipelineArn"] == create_arn + + try: + execution.wait(delay=30, max_attempts=3) + except WaiterError: + pass + execution_steps = execution.list_steps() + assert len(execution_steps) == 1 + assert execution_steps[0]["StepName"] == "sklearn-process" + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index a28fbf4878..9db2067bb6 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -37,6 +37,7 @@ TrainingStep, TransformStep, CreateModelStep, + CacheConfig, ) REGION = "us-west-2" @@ -114,10 +115,9 @@ def test_training_step(sagemaker_session): sagemaker_session=sagemaker_session, ) inputs = TrainingInput(f"s3://{BUCKET}/train_manifest") + cache_config = CacheConfig(enable_caching=False, expire_after="PT1H") step = TrainingStep( - name="MyTrainingStep", - estimator=estimator, - inputs=inputs, + name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config ) assert step.to_request() == { "Name": "MyTrainingStep", @@ -145,6 +145,7 @@ def test_training_step(sagemaker_session): "RoleArn": ROLE, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, }, + "CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"}, } assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} @@ -163,11 +164,13 @@ def test_processing_step(sagemaker_session): destination="processing_manifest", ) ] + cache_config = CacheConfig(enable_caching=False, expire_after="PT1H") step = ProcessingStep( name="MyProcessingStep", processor=processor, inputs=inputs, outputs=[], + cache_config=cache_config, ) assert step.to_request() == { "Name": "MyProcessingStep", @@ -197,6 +200,7 @@ def test_processing_step(sagemaker_session): }, "RoleArn": "DummyRole", }, + "CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"}, } assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" @@ -238,10 +242,9 @@ def test_transform_step(sagemaker_session): sagemaker_session=sagemaker_session, ) inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") + cache_config = CacheConfig(enable_caching=False, expire_after="PT1H") step = TransformStep( - name="MyTransformStep", - transformer=transformer, - inputs=inputs, + name="MyTransformStep", transformer=transformer, inputs=inputs, cache_config=cache_config ) assert step.to_request() == { "Name": "MyTransformStep", @@ -262,6 +265,7 @@ def test_transform_step(sagemaker_session): "InstanceType": "c4.4xlarge", }, }, + "CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"}, } assert step.properties.TransformJobName.expr == { "Get": "Steps.MyTransformStep.TransformJobName" From 1c2875b4f304ed1b47692fcd8c598b32b64afdc5 Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Thu, 4 Feb 2021 07:34:12 -0800 Subject: [PATCH 02/10] modify integ test --- tests/integ/test_workflow.py | 218 +++++++++++++++++++++++++++-------- 1 file changed, 173 insertions(+), 45 deletions(-) diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index cb5d868f90..963eb40555 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -44,12 +44,7 @@ ParameterInteger, ParameterString, ) -from sagemaker.workflow.steps import ( - CreateModelStep, - ProcessingStep, - TrainingStep, - CacheConfig -) +from sagemaker.workflow.steps import CreateModelStep, ProcessingStep, TrainingStep, CacheConfig from sagemaker.workflow.step_collections import RegisterModel from sagemaker.workflow.pipeline import Pipeline from tests.integ import DATA_DIR @@ -554,76 +549,209 @@ def test_training_job_with_debugger( pass -def test_cache_hit_expired_entry( - sagemaker_session, - workflow_session, - region_name, - role, - script_dir, - pipeline_name, +def test_cache_hit( + sagemaker_session, + workflow_session, + region_name, + role, + script_dir, + pipeline_name, + athena_dataset_definition, ): + cache_config = CacheConfig(enable_caching=True, expire_after="T30m") + + framework_version = "0.20.0" instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") instance_count = ParameterInteger(name="InstanceCount", default_value=1) + input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv" - estimator = + sklearn_processor = SKLearnProcessor( + framework_version=framework_version, + instance_type=instance_type, + instance_count=instance_count, + base_job_name="test-sklearn", + sagemaker_session=sagemaker_session, + role=role, + ) - step_train = TrainingStep( - name="my-train", - estimator=sklearn_train, - inputs=TrainingInput( - s3_data=step_process.properties.ProcessingOutputConfig.Outputs[ - "train_data" - ].S3Output.S3Uri - ), - cache_config= + step_process = ProcessingStep( + name="my-cache-test", + processor=sklearn_processor, + inputs=[ + ProcessingInput(source=input_data, destination="/opt/ml/processing/input"), + ProcessingInput(dataset_definition=athena_dataset_definition), + ], + outputs=[ + ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"), + ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"), + ], + code=os.path.join(script_dir, "preprocessing.py"), + cache_config=cache_config, ) + pipeline = Pipeline( name=pipeline_name, - parameters=[instance_type, instance_count], - steps=[step_train], + parameters=[instance_count, instance_type], + steps=[step_process], sagemaker_session=workflow_session, ) try: - # NOTE: We should exercise the case when role used in the pipeline execution is - # different than that required of the steps in the pipeline itself. The role in - # the pipeline definition needs to create training and processing jobs and other - # sagemaker entities. However, the jobs created in the steps themselves execute - # under a potentially different role, often requiring access to S3 and other - # artifacts not required to during creation of the jobs in the pipeline steps. response = pipeline.create(role) create_arn = response["PipelineArn"] + pytest.set_trace() + assert re.match( - fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}", + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn, ) - pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] - response = pipeline.update(role) - update_arn = response["PipelineArn"] + # Run pipeline for the first time to get an entry in the cache + execution1 = pipeline.start(parameters={}) assert re.match( - fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}", - update_arn, + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", + execution1.arn, ) - execution = pipeline.start(parameters={}) + response = execution1.describe() + assert response["PipelineArn"] == create_arn + + try: + execution1.wait(delay=30, max_attempts=10) + except WaiterError: + pass + execution1_steps = execution1.list_steps() + assert len(execution1_steps) == 1 + assert execution1_steps[0]["StepName"] == "my-cache-test" + + # Run pipeline for the second time and expect cache hit + execution2 = pipeline.start(parameters={}) assert re.match( - fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}/execution/", - execution.arn, + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", + execution2.arn, ) - response = execution.describe() + response = execution2.describe() assert response["PipelineArn"] == create_arn try: - execution.wait(delay=30, max_attempts=3) + execution2.wait(delay=30, max_attempts=10) except WaiterError: pass - execution_steps = execution.list_steps() - assert len(execution_steps) == 1 - assert execution_steps[0]["StepName"] == "sklearn-process" + execution2_steps = execution2.list_steps() + assert len(execution2_steps) == 1 + assert execution2_steps[0]["StepName"] == "my-cache-test" + + assert execution1_steps[0] == execution2_steps[0] + + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_cache_expiry( + sagemaker_session, + workflow_session, + region_name, + role, + script_dir, + pipeline_name, + athena_dataset_definition, +): + + cache_config = CacheConfig(enable_caching=True, expire_after="T1m") + + framework_version = "0.20.0" + instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + + input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv" + + sklearn_processor = SKLearnProcessor( + framework_version=framework_version, + instance_type=instance_type, + instance_count=instance_count, + base_job_name="test-sklearn", + sagemaker_session=sagemaker_session, + role=role, + ) + + step_process = ProcessingStep( + name="my-cache-test-expiry", + processor=sklearn_processor, + inputs=[ + ProcessingInput(source=input_data, destination="/opt/ml/processing/input"), + ProcessingInput(dataset_definition=athena_dataset_definition), + ], + outputs=[ + ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"), + ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"), + ], + code=os.path.join(script_dir, "preprocessing.py"), + cache_config=cache_config, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[instance_count, instance_type], + steps=[step_process], + sagemaker_session=workflow_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + + # Run pipeline for the first time to get an entry in the cache + execution1 = pipeline.start(parameters={}) + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", + execution1.arn, + ) + + response = execution1.describe() + assert response["PipelineArn"] == create_arn + + try: + execution1.wait(delay=30, max_attempts=3) + except WaiterError: + pass + execution1_steps = execution1.list_steps() + assert len(execution1_steps) == 1 + assert execution1_steps[0]["StepName"] == "my-cache-test-expiry" + + # wait 1 minute for cache to expire + time.sleep(60) + + # Run pipeline for the second time and expect cache miss + execution2 = pipeline.start(parameters={}) + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", + execution2.arn, + ) + + response = execution2.describe() + assert response["PipelineArn"] == create_arn + + try: + execution2.wait(delay=30, max_attempts=3) + except WaiterError: + pass + execution2_steps = execution2.list_steps() + assert len(execution2_steps) == 1 + assert execution2_steps[0]["StepName"] == "my-cache-test-expiry" + + assert execution1_steps[0] != execution2_steps[0] + finally: try: pipeline.delete() From a0d3d7baeeaf6483160229556d19b03faa91683d Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Mon, 8 Feb 2021 15:19:14 -0800 Subject: [PATCH 03/10] modify unit test to enable cache to True --- tests/integ/test_workflow.py | 218 +------------------- tests/unit/sagemaker/workflow/test_steps.py | 12 +- 2 files changed, 14 insertions(+), 216 deletions(-) diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 963eb40555..9fb26209d4 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -289,6 +289,8 @@ def test_one_step_sklearn_processing_pipeline( ProcessingInput(dataset_definition=athena_dataset_definition), ] + cache_config = CacheConfig(enable_caching=True, expire_after="T30m") + sklearn_processor = SKLearnProcessor( framework_version=sklearn_latest_version, role=role, @@ -304,6 +306,7 @@ def test_one_step_sklearn_processing_pipeline( processor=sklearn_processor, inputs=inputs, code=script_path, + cache_config=cache_config, ) pipeline = Pipeline( name=pipeline_name, @@ -343,6 +346,11 @@ def test_one_step_sklearn_processing_pipeline( response = execution.describe() assert response["PipelineArn"] == create_arn + # Check CacheConfig + response = json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0]["CacheConfig"] + assert response["Enabled"] == cache_config.enable_caching + assert response["ExpireAfter"] == cache_config.expire_after + try: execution.wait(delay=30, max_attempts=3) except WaiterError: @@ -547,213 +555,3 @@ def test_training_job_with_debugger( pipeline.delete() except Exception: pass - - -def test_cache_hit( - sagemaker_session, - workflow_session, - region_name, - role, - script_dir, - pipeline_name, - athena_dataset_definition, -): - - cache_config = CacheConfig(enable_caching=True, expire_after="T30m") - - framework_version = "0.20.0" - instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") - instance_count = ParameterInteger(name="InstanceCount", default_value=1) - - input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv" - - sklearn_processor = SKLearnProcessor( - framework_version=framework_version, - instance_type=instance_type, - instance_count=instance_count, - base_job_name="test-sklearn", - sagemaker_session=sagemaker_session, - role=role, - ) - - step_process = ProcessingStep( - name="my-cache-test", - processor=sklearn_processor, - inputs=[ - ProcessingInput(source=input_data, destination="/opt/ml/processing/input"), - ProcessingInput(dataset_definition=athena_dataset_definition), - ], - outputs=[ - ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"), - ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"), - ], - code=os.path.join(script_dir, "preprocessing.py"), - cache_config=cache_config, - ) - - pipeline = Pipeline( - name=pipeline_name, - parameters=[instance_count, instance_type], - steps=[step_process], - sagemaker_session=workflow_session, - ) - - try: - response = pipeline.create(role) - create_arn = response["PipelineArn"] - pytest.set_trace() - - assert re.match( - fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", - create_arn, - ) - - # Run pipeline for the first time to get an entry in the cache - execution1 = pipeline.start(parameters={}) - assert re.match( - fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", - execution1.arn, - ) - - response = execution1.describe() - assert response["PipelineArn"] == create_arn - - try: - execution1.wait(delay=30, max_attempts=10) - except WaiterError: - pass - execution1_steps = execution1.list_steps() - assert len(execution1_steps) == 1 - assert execution1_steps[0]["StepName"] == "my-cache-test" - - # Run pipeline for the second time and expect cache hit - execution2 = pipeline.start(parameters={}) - assert re.match( - fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", - execution2.arn, - ) - - response = execution2.describe() - assert response["PipelineArn"] == create_arn - - try: - execution2.wait(delay=30, max_attempts=10) - except WaiterError: - pass - execution2_steps = execution2.list_steps() - assert len(execution2_steps) == 1 - assert execution2_steps[0]["StepName"] == "my-cache-test" - - assert execution1_steps[0] == execution2_steps[0] - - finally: - try: - pipeline.delete() - except Exception: - pass - - -def test_cache_expiry( - sagemaker_session, - workflow_session, - region_name, - role, - script_dir, - pipeline_name, - athena_dataset_definition, -): - - cache_config = CacheConfig(enable_caching=True, expire_after="T1m") - - framework_version = "0.20.0" - instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") - instance_count = ParameterInteger(name="InstanceCount", default_value=1) - - input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv" - - sklearn_processor = SKLearnProcessor( - framework_version=framework_version, - instance_type=instance_type, - instance_count=instance_count, - base_job_name="test-sklearn", - sagemaker_session=sagemaker_session, - role=role, - ) - - step_process = ProcessingStep( - name="my-cache-test-expiry", - processor=sklearn_processor, - inputs=[ - ProcessingInput(source=input_data, destination="/opt/ml/processing/input"), - ProcessingInput(dataset_definition=athena_dataset_definition), - ], - outputs=[ - ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"), - ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"), - ], - code=os.path.join(script_dir, "preprocessing.py"), - cache_config=cache_config, - ) - - pipeline = Pipeline( - name=pipeline_name, - parameters=[instance_count, instance_type], - steps=[step_process], - sagemaker_session=workflow_session, - ) - - try: - response = pipeline.create(role) - create_arn = response["PipelineArn"] - - assert re.match( - fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", - create_arn, - ) - - # Run pipeline for the first time to get an entry in the cache - execution1 = pipeline.start(parameters={}) - assert re.match( - fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", - execution1.arn, - ) - - response = execution1.describe() - assert response["PipelineArn"] == create_arn - - try: - execution1.wait(delay=30, max_attempts=3) - except WaiterError: - pass - execution1_steps = execution1.list_steps() - assert len(execution1_steps) == 1 - assert execution1_steps[0]["StepName"] == "my-cache-test-expiry" - - # wait 1 minute for cache to expire - time.sleep(60) - - # Run pipeline for the second time and expect cache miss - execution2 = pipeline.start(parameters={}) - assert re.match( - fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", - execution2.arn, - ) - - response = execution2.describe() - assert response["PipelineArn"] == create_arn - - try: - execution2.wait(delay=30, max_attempts=3) - except WaiterError: - pass - execution2_steps = execution2.list_steps() - assert len(execution2_steps) == 1 - assert execution2_steps[0]["StepName"] == "my-cache-test-expiry" - - assert execution1_steps[0] != execution2_steps[0] - - finally: - try: - pipeline.delete() - except Exception: - pass diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 9db2067bb6..6bb2586a7c 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -115,7 +115,7 @@ def test_training_step(sagemaker_session): sagemaker_session=sagemaker_session, ) inputs = TrainingInput(f"s3://{BUCKET}/train_manifest") - cache_config = CacheConfig(enable_caching=False, expire_after="PT1H") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") step = TrainingStep( name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config ) @@ -145,7 +145,7 @@ def test_training_step(sagemaker_session): "RoleArn": ROLE, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, }, - "CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"}, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} @@ -164,7 +164,7 @@ def test_processing_step(sagemaker_session): destination="processing_manifest", ) ] - cache_config = CacheConfig(enable_caching=False, expire_after="PT1H") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") step = ProcessingStep( name="MyProcessingStep", processor=processor, @@ -200,7 +200,7 @@ def test_processing_step(sagemaker_session): }, "RoleArn": "DummyRole", }, - "CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"}, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" @@ -242,7 +242,7 @@ def test_transform_step(sagemaker_session): sagemaker_session=sagemaker_session, ) inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") - cache_config = CacheConfig(enable_caching=False, expire_after="PT1H") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") step = TransformStep( name="MyTransformStep", transformer=transformer, inputs=inputs, cache_config=cache_config ) @@ -265,7 +265,7 @@ def test_transform_step(sagemaker_session): "InstanceType": "c4.4xlarge", }, }, - "CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"}, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } assert step.properties.TransformJobName.expr == { "Get": "Steps.MyTransformStep.TransformJobName" From f58941a6568d8ce3d21a3f701f1f0e6642abd4d5 Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Tue, 9 Feb 2021 09:18:36 -0800 Subject: [PATCH 04/10] fix errors --- src/sagemaker/workflow/steps.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 7545390358..28d7010cb3 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -64,7 +64,6 @@ class Step(Entity): Attributes: name (str): The name of the step. step_type (StepTypeEnum): The type of the step. - """ name: str = attr.ib(factory=str) @@ -133,7 +132,7 @@ def __init__( name (str): The name of the training step. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`. - cache_config (CacheConfig): An instance to enable caching. + cache_config (CacheConfig): A `sagemaker.steps.CacheConfig` instance to enable caching. """ super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING) self.estimator = estimator @@ -171,7 +170,8 @@ def properties(self): def to_request(self) -> RequestType: """Updates the dictionary with cache configuration.""" request_dict = super().to_request() - request_dict.update(self.cache_config.config) + if self.cache_config: + request_dict.update(self.cache_config.config) return request_dict @@ -293,7 +293,8 @@ def properties(self): def to_request(self) -> RequestType: """Updates the dictionary with cache configuration.""" request_dict = super().to_request() - request_dict.update(self.cache_config.config) + if self.cache_config: + request_dict.update(self.cache_config.config) return request_dict @@ -378,7 +379,8 @@ def properties(self): def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" request_dict = super(ProcessingStep, self).to_request() - request_dict.update(self.cache_config.config) + if self.cache_config: + request_dict.update(self.cache_config.config) if self.property_files: request_dict["PropertyFiles"] = [ property_file.expr for property_file in self.property_files From ccf4c1e848c5c3fa5d2f7335b78d2dab396c9401 Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Wed, 10 Feb 2021 13:05:27 -0800 Subject: [PATCH 05/10] fix: address requested changes --- src/sagemaker/workflow/steps.py | 28 ++++++++++++++++++++-------- tests/integ/test_workflow.py | 7 ++++++- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 28d7010cb3..366912d3cc 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -95,21 +95,33 @@ def ref(self) -> Dict[str, str]: @attr.s class CacheConfig: - """Step to cache pipeline workflow. + """Configure steps to enable cache in pipeline workflow. + + If caching is enabled, the pipeline attempts to find a previous execution of a step. + If a successful previous execution is found, the pipeline propagates the values + from previous execution rather than recomputing the step. + Attributes: - enable_caching (bool): To enable step caching. Off by default. + enable_caching (bool): To enable step caching. Defaults to `False`. expire_after (str): If step caching is enabled, a timeout also needs to defined. It defines how old a previous execution can be to be considered for reuse. - Needs to be ISO 8601 duration string. + Value should be an ISO 8601 duration string. + If step caching is disabled, it defaults to an empty string. """ enable_caching: bool = attr.ib(default=False) - expire_after: str = attr.ib(factory=str) + expire_after = attr.ib(default="") + + @expire_after.validator + def validate_expire_after(self, enable_caching, expire_after): + """Validates ISO 8601 duration string.""" + if enable_caching and expire_after == "": + raise ValueError("expire_after must be an ISO 8601 duration string") @property def config(self): - """Enables caching in pipeline steps.""" + """Configures caching in pipeline steps.""" return {"CacheConfig": {"Enabled": self.enable_caching, "ExpireAfter": self.expire_after}} @@ -132,7 +144,7 @@ def __init__( name (str): The name of the training step. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`. - cache_config (CacheConfig): A `sagemaker.steps.CacheConfig` instance to enable caching. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. """ super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING) self.estimator = estimator @@ -249,7 +261,7 @@ def __init__( name (str): The name of the transform step. transformer (Transformer): A `sagemaker.transformer.Transformer` instance. inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance. - cache_config (CacheConfig): An instance to enable caching. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. """ super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM) self.transformer = transformer @@ -331,7 +343,7 @@ def __init__( script to run. Defaults to `None`. property_files (List[PropertyFile]): A list of property files that workflow looks for and resolves from the configured processing output list. - cache_config (CacheConfig): An instance to enable caching. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. """ super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING) self.processor = processor diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 9fb26209d4..3a49aaac46 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -44,7 +44,12 @@ ParameterInteger, ParameterString, ) -from sagemaker.workflow.steps import CreateModelStep, ProcessingStep, TrainingStep, CacheConfig +from sagemaker.workflow.steps import ( + CreateModelStep, + ProcessingStep, + TrainingStep, + CacheConfig, +) from sagemaker.workflow.step_collections import RegisterModel from sagemaker.workflow.pipeline import Pipeline from tests.integ import DATA_DIR From ccf04311b32d4da852f9df2e47ba3a579400ad95 Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Thu, 11 Feb 2021 16:42:57 -0800 Subject: [PATCH 06/10] add optional validator to expire_after and change docstring --- src/sagemaker/workflow/steps.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 366912d3cc..884aff4e89 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -97,27 +97,24 @@ def ref(self) -> Dict[str, str]: class CacheConfig: """Configure steps to enable cache in pipeline workflow. - If caching is enabled, the pipeline attempts to find a previous execution of a step. + If caching is enabled, the pipeline attempts to find a previous execution of a step + that was called with the same arguments. Step caching only considers successful execution. If a successful previous execution is found, the pipeline propagates the values - from previous execution rather than recomputing the step. + from previous execution rather than recomputing the step. When multiple successful executions + exist within the timeout period, it uses the result for the most recent successful execution. Attributes: enable_caching (bool): To enable step caching. Defaults to `False`. expire_after (str): If step caching is enabled, a timeout also needs to defined. It defines how old a previous execution can be to be considered for reuse. - Value should be an ISO 8601 duration string. - If step caching is disabled, it defaults to an empty string. + Value should be an ISO 8601 duration string. Defaults to `None`. """ enable_caching: bool = attr.ib(default=False) - expire_after = attr.ib(default="") - - @expire_after.validator - def validate_expire_after(self, enable_caching, expire_after): - """Validates ISO 8601 duration string.""" - if enable_caching and expire_after == "": - raise ValueError("expire_after must be an ISO 8601 duration string") + expire_after = attr.ib( + default=None, validator=attr.validators.optional(attr.validators.instance_of(str)) + ) @property def config(self): From d7d7ad619cda2dddba34ffa158ea0d91e10e6028 Mon Sep 17 00:00:00 2001 From: Eric Johnson <65414824+metrizable@users.noreply.github.com> Date: Mon, 15 Feb 2021 13:10:49 -0800 Subject: [PATCH 07/10] infra: lower test TPS for experiment analytics (#2145) --- tests/integ/test_experiments_analytics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integ/test_experiments_analytics.py b/tests/integ/test_experiments_analytics.py index 7eff0a1746..e66dd71dd7 100644 --- a/tests/integ/test_experiments_analytics.py +++ b/tests/integ/test_experiments_analytics.py @@ -35,6 +35,7 @@ def experiment(sagemaker_session): sm.associate_trial_component( TrialComponentName=trial_component_name, TrialName=trial_name ) + time.sleep(1) time.sleep(15) # wait for search to get updated @@ -87,6 +88,7 @@ def experiment_with_artifacts(sagemaker_session): sm.associate_trial_component( TrialComponentName=trial_component_name, TrialName=trial_name ) + time.sleep(1) time.sleep(15) # wait for search to get updated From 5edac4b67a8fbc00c7f08e1de02cfa65b0d11513 Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Mon, 15 Feb 2021 14:05:25 -0800 Subject: [PATCH 08/10] fix: make expire_after optional --- src/sagemaker/workflow/steps.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 884aff4e89..154b574bd6 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -119,7 +119,10 @@ class CacheConfig: @property def config(self): """Configures caching in pipeline steps.""" - return {"CacheConfig": {"Enabled": self.enable_caching, "ExpireAfter": self.expire_after}} + config = {'Enabled': self.enable_caching} + if self.expire_after is not None: + config['ExpireAfter'] = self.expire_after + return {'CacheConfig': config} class TrainingStep(Step): From c5d37a7eaa7a5bd0985863deb8d563f60dcfecd8 Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Mon, 15 Feb 2021 14:22:04 -0800 Subject: [PATCH 09/10] fix unit test --- src/sagemaker/workflow/steps.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 154b574bd6..ed6a9b928d 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -119,10 +119,10 @@ class CacheConfig: @property def config(self): """Configures caching in pipeline steps.""" - config = {'Enabled': self.enable_caching} + config = {"Enabled": self.enable_caching} if self.expire_after is not None: - config['ExpireAfter'] = self.expire_after - return {'CacheConfig': config} + config["ExpireAfter"] = self.expire_after + return {"CacheConfig": config} class TrainingStep(Step): From b7d24e3760a25c37401ccff133a1a8406ed4a4fc Mon Sep 17 00:00:00 2001 From: Ahsan Khan Date: Wed, 17 Feb 2021 15:15:32 -0500 Subject: [PATCH 10/10] Update docstring Co-authored-by: Eric Johnson <65414824+metrizable@users.noreply.github.com> --- src/sagemaker/workflow/steps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index ed6a9b928d..af4eb79a9b 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -95,7 +95,7 @@ def ref(self) -> Dict[str, str]: @attr.s class CacheConfig: - """Configure steps to enable cache in pipeline workflow. + """Configuration class to enable caching in pipeline workflow. If caching is enabled, the pipeline attempts to find a previous execution of a step that was called with the same arguments. Step caching only considers successful execution.