feature: Enable step caching #2130


Merged
merged 23 commits on Feb 18, 2021
Changes from 9 commits (23 commits total)

Commits
e1d0d45
add cache config and unit test
ahsan-z-khan Jan 21, 2021
1c2875b
modify integ test
ahsan-z-khan Feb 4, 2021
a0d3d7b
modify unit test to enable cache to True
ahsan-z-khan Feb 8, 2021
0a8ab6a
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 9, 2021
f58941a
fix errors
ahsan-z-khan Feb 9, 2021
f5c0538
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 9, 2021
ccf4c1e
fix: address requested changes
ahsan-z-khan Feb 10, 2021
ccf0431
add optional validator to expire_after and change docstring
ahsan-z-khan Feb 12, 2021
08660da
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 12, 2021
feed6c5
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 15, 2021
d7d7ad6
infra: lower test TPS for experiment analytics (#2145)
metrizable Feb 15, 2021
5edac4b
fix: make expire_after optional
ahsan-z-khan Feb 15, 2021
fc70d20
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 15, 2021
c5d37a7
fix unit test
ahsan-z-khan Feb 15, 2021
d2305ca
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 15, 2021
486362b
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 16, 2021
0ac104d
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 17, 2021
b7d24e3
Update docstring
ahsan-z-khan Feb 17, 2021
9b2db2a
Merge branch 'master' into enable-step-caching
metrizable Feb 17, 2021
d7bed9a
Merge branch 'master' into enable-step-caching
metrizable Feb 17, 2021
5288315
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 18, 2021
ba2a795
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 18, 2021
fdabaca
Merge branch 'master' into enable-step-caching
ahsan-z-khan Feb 18, 2021
58 changes: 56 additions & 2 deletions src/sagemaker/workflow/steps.py
@@ -93,6 +93,35 @@ def ref(self) -> Dict[str, str]:
return {"Name": self.name}


@attr.s
class CacheConfig:
"""Configure steps to enable cache in pipeline workflow.

If caching is enabled, the pipeline attempts to find a previous execution of a step
that was called with the same arguments. Step caching only considers successful execution.
If a successful previous execution is found, the pipeline propagates the values
from previous execution rather than recomputing the step. When multiple successful executions
exist within the timeout period, it uses the result for the most recent successful execution.


Attributes:
enable_caching (bool): To enable step caching. Defaults to `False`.
expire_after (str): If step caching is enabled, a timeout also needs to defined.
It defines how old a previous execution can be to be considered for reuse.
Value should be an ISO 8601 duration string. Defaults to `None`.
"""

enable_caching: bool = attr.ib(default=False)
expire_after = attr.ib(
default=None, validator=attr.validators.optional(attr.validators.instance_of(str))
)

@property
def config(self):
"""Configures caching in pipeline steps."""
return {"CacheConfig": {"Enabled": self.enable_caching, "ExpireAfter": self.expire_after}}
Review comment (Contributor):
If ExpireAfter is None, then don't include it as a key.
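
A minimal sketch of one way to apply that suggestion, assuming the property otherwise keeps the same shape (shown only as an illustration, not necessarily the change that was merged):

@property
def config(self):
    """Configures caching in pipeline steps, omitting ExpireAfter when it is unset."""
    config = {"Enabled": self.enable_caching}
    if self.expire_after is not None:
        config["ExpireAfter"] = self.expire_after
    return {"CacheConfig": config}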



class TrainingStep(Step):
"""Training step for workflow."""

@@ -101,6 +130,7 @@ def __init__(
name: str,
estimator: EstimatorBase,
inputs: TrainingInput = None,
cache_config: CacheConfig = None,
):
"""Construct a TrainingStep, given an `EstimatorBase` instance.

@@ -111,14 +141,15 @@ def __init__(
name (str): The name of the training step.
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
"""
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING)
self.estimator = estimator
self.inputs = inputs

self._properties = Properties(
path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse"
)
self.cache_config = cache_config

@property
def arguments(self) -> RequestType:
@@ -145,6 +176,14 @@ def properties(self):
"""A Properties object representing the DescribeTrainingJobResponse data model."""
return self._properties

def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration."""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)

return request_dict
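
As a rough illustration of the merge performed by `to_request()` above, a hedged sketch with placeholder values (the image URI, role, and bucket are not taken from this PR):

from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.steps import CacheConfig, TrainingStep

estimator = Estimator(
    image_uri="placeholder-training-image",  # placeholder
    role="arn:aws:iam::123456789012:role/placeholder",  # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
)
step = TrainingStep(
    name="MyTrainingStep",
    estimator=estimator,
    inputs=TrainingInput(s3_data="s3://my-bucket/train"),  # placeholder bucket
    cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
)
# step.to_request() returns the base request (Name, Type, Arguments) plus:
#   "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}
# TransformStep and ProcessingStep apply the same merge in their own to_request().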


class CreateModelStep(Step):
"""CreateModel step for workflow."""
@@ -208,6 +247,7 @@ def __init__(
name: str,
transformer: Transformer,
inputs: TransformInput,
cache_config: CacheConfig = None,
):
"""Constructs a TransformStep, given an `Transformer` instance.

@@ -218,11 +258,12 @@ def __init__(
name (str): The name of the transform step.
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
"""
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM)
self.transformer = transformer
self.inputs = inputs

self.cache_config = cache_config
self._properties = Properties(
path=f"Steps.{name}", shape_name="DescribeTransformJobResponse"
)
@@ -258,6 +299,14 @@ def properties(self):
"""A Properties object representing the DescribeTransformJobResponse data model."""
return self._properties

def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration."""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)

return request_dict


class ProcessingStep(Step):
"""Processing step for workflow."""
@@ -271,6 +320,7 @@ def __init__(
job_arguments: List[str] = None,
code: str = None,
property_files: List[PropertyFile] = None,
cache_config: CacheConfig = None,
):
"""Construct a ProcessingStep, given a `Processor` instance.

@@ -290,6 +340,7 @@ def __init__(
script to run. Defaults to `None`.
property_files (List[PropertyFile]): A list of property files that workflow looks
for and resolves from the configured processing output list.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
"""
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING)
self.processor = processor
@@ -306,6 +357,7 @@ def __init__(
self._properties = Properties(
path=f"Steps.{name}", shape_name="DescribeProcessingJobResponse"
)
self.cache_config = cache_config

@property
def arguments(self) -> RequestType:
@@ -336,6 +388,8 @@ def properties(self):
def to_request(self) -> RequestType:
"""Get the request structure for workflow service calls."""
request_dict = super(ProcessingStep, self).to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)
if self.property_files:
request_dict["PropertyFiles"] = [
property_file.expr for property_file in self.property_files
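
Before the test changes below, a condensed end-to-end sketch of the usage this PR enables; the processor settings, role, and script name are placeholders (the integration test that follows exercises the same flow against a real pipeline):

from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import CacheConfig, ProcessingStep

cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")

sklearn_processor = SKLearnProcessor(
    framework_version="0.23-1",  # placeholder version
    role="arn:aws:iam::123456789012:role/placeholder",  # placeholder role
    instance_type="ml.m5.xlarge",
    instance_count=1,
)
step = ProcessingStep(
    name="CachedSklearnProcessing",
    processor=sklearn_processor,
    code="preprocess.py",  # placeholder script
    cache_config=cache_config,
)
pipeline = Pipeline(name="cache-config-demo", steps=[step])
# Re-running this pipeline within one hour with identical step arguments lets the
# service reuse the previous successful execution instead of recomputing the step.
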
9 changes: 9 additions & 0 deletions tests/integ/test_workflow.py
@@ -48,6 +48,7 @@
CreateModelStep,
ProcessingStep,
TrainingStep,
CacheConfig,
)
from sagemaker.workflow.step_collections import RegisterModel
from sagemaker.workflow.pipeline import Pipeline
@@ -293,6 +294,8 @@ def test_one_step_sklearn_processing_pipeline(
ProcessingInput(dataset_definition=athena_dataset_definition),
]

cache_config = CacheConfig(enable_caching=True, expire_after="T30m")

sklearn_processor = SKLearnProcessor(
framework_version=sklearn_latest_version,
role=role,
@@ -308,6 +311,7 @@
processor=sklearn_processor,
inputs=inputs,
code=script_path,
cache_config=cache_config,
)
pipeline = Pipeline(
name=pipeline_name,
@@ -347,6 +351,11 @@ def test_one_step_sklearn_processing_pipeline(
response = execution.describe()
assert response["PipelineArn"] == create_arn

# Check CacheConfig
response = json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0]["CacheConfig"]
assert response["Enabled"] == cache_config.enable_caching
assert response["ExpireAfter"] == cache_config.expire_after

try:
execution.wait(delay=30, max_attempts=3)
except WaiterError:
16 changes: 10 additions & 6 deletions tests/unit/sagemaker/workflow/test_steps.py
@@ -37,6 +37,7 @@
TrainingStep,
TransformStep,
CreateModelStep,
CacheConfig,
)

REGION = "us-west-2"
@@ -114,10 +115,9 @@ def test_training_step(sagemaker_session):
sagemaker_session=sagemaker_session,
)
inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
step = TrainingStep(
name="MyTrainingStep",
estimator=estimator,
inputs=inputs,
name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
)
assert step.to_request() == {
"Name": "MyTrainingStep",
@@ -145,6 +145,7 @@ def test_training_step(sagemaker_session):
"RoleArn": ROLE,
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
},
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
}
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}

@@ -163,11 +164,13 @@ def test_processing_step(sagemaker_session):
destination="processing_manifest",
)
]
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
step = ProcessingStep(
name="MyProcessingStep",
processor=processor,
inputs=inputs,
outputs=[],
cache_config=cache_config,
)
assert step.to_request() == {
"Name": "MyProcessingStep",
@@ -197,6 +200,7 @@
},
"RoleArn": "DummyRole",
},
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
}
assert step.properties.ProcessingJobName.expr == {
"Get": "Steps.MyProcessingStep.ProcessingJobName"
@@ -238,10 +242,9 @@ def test_transform_step(sagemaker_session):
sagemaker_session=sagemaker_session,
)
inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
step = TransformStep(
name="MyTransformStep",
transformer=transformer,
inputs=inputs,
name="MyTransformStep", transformer=transformer, inputs=inputs, cache_config=cache_config
)
assert step.to_request() == {
"Name": "MyTransformStep",
@@ -262,6 +265,7 @@
"InstanceType": "c4.4xlarge",
},
},
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
}
assert step.properties.TransformJobName.expr == {
"Get": "Steps.MyTransformStep.TransformJobName"