Skip to content

Commit 51c5344

Browse files
feature: Enable step caching (#2130)
* add cache config and unit test * modify integ test * modify unit test to enable cache to True * fix errors * fix: address requested changes * add optional validator to expire_after and change docstring * infra: lower test TPS for experiment analytics (#2145) * fix: make expire_after optional * fix unit test * Update docstring Co-authored-by: Eric Johnson <[email protected]> Co-authored-by: Eric Johnson <[email protected]>
1 parent e8ccb03 commit 51c5344

File tree

3 files changed

+78
-8
lines changed

3 files changed

+78
-8
lines changed

src/sagemaker/workflow/steps.py

+59-2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,38 @@ def ref(self) -> Dict[str, str]:
9292
return {"Name": self.name}
9393

9494

95+
@attr.s
class CacheConfig:
    """Step-caching settings for a pipeline workflow step.

    With caching enabled, the pipeline looks for an earlier execution of the step that
    was invoked with the same arguments. Only successful executions are considered.
    When one is found, its values are propagated instead of recomputing the step; if
    several successful executions fall within the timeout period, the most recent
    successful one is used.

    Attributes:
        enable_caching (bool): Whether step caching is turned on. Defaults to `False`.
        expire_after (str): If step caching is enabled, a timeout also needs to be defined.
            It defines how old a previous execution can be to be considered for reuse.
            Value should be an ISO 8601 duration string. Defaults to `None`.
    """

    enable_caching: bool = attr.ib(default=False)
    expire_after = attr.ib(
        validator=attr.validators.optional(attr.validators.instance_of(str)),
        default=None,
    )

    @property
    def config(self):
        """Return the cache settings as a request fragment keyed by ``CacheConfig``.

        ``ExpireAfter`` is only included when ``expire_after`` has been set.
        """
        settings = {"Enabled": self.enable_caching}
        if self.expire_after is None:
            return {"CacheConfig": settings}
        settings["ExpireAfter"] = self.expire_after
        return {"CacheConfig": settings}
125+
126+
95127
class TrainingStep(Step):
96128
"""Training step for workflow."""
97129

@@ -100,6 +132,7 @@ def __init__(
100132
name: str,
101133
estimator: EstimatorBase,
102134
inputs: TrainingInput = None,
135+
cache_config: CacheConfig = None,
103136
):
104137
"""Construct a TrainingStep, given an `EstimatorBase` instance.
105138
@@ -110,14 +143,15 @@ def __init__(
110143
name (str): The name of the training step.
111144
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
112145
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
146+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
113147
"""
114148
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING)
115149
self.estimator = estimator
116150
self.inputs = inputs
117-
118151
self._properties = Properties(
119152
path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse"
120153
)
154+
self.cache_config = cache_config
121155

122156
@property
123157
def arguments(self) -> RequestType:
@@ -144,6 +178,14 @@ def properties(self):
144178
"""A Properties object representing the DescribeTrainingJobResponse data model."""
145179
return self._properties
146180

181+
def to_request(self) -> RequestType:
    """Build the step request and merge in the cache configuration, if one is set."""
    request = super().to_request()
    cache_config = self.cache_config
    if not cache_config:
        # No cache configuration supplied: return the parent request unchanged.
        return request
    request.update(cache_config.config)
    return request
188+
147189

148190
class CreateModelStep(Step):
149191
"""CreateModel step for workflow."""
@@ -207,6 +249,7 @@ def __init__(
207249
name: str,
208250
transformer: Transformer,
209251
inputs: TransformInput,
252+
cache_config: CacheConfig = None,
210253
):
211254
"""Constructs a TransformStep, given an `Transformer` instance.
212255
@@ -217,11 +260,12 @@ def __init__(
217260
name (str): The name of the transform step.
218261
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
219262
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
263+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
220264
"""
221265
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM)
222266
self.transformer = transformer
223267
self.inputs = inputs
224-
268+
self.cache_config = cache_config
225269
self._properties = Properties(
226270
path=f"Steps.{name}", shape_name="DescribeTransformJobResponse"
227271
)
@@ -257,6 +301,14 @@ def properties(self):
257301
"""A Properties object representing the DescribeTransformJobResponse data model."""
258302
return self._properties
259303

304+
def to_request(self) -> RequestType:
    """Build the step request and merge in the cache configuration, if one is set."""
    request = super().to_request()
    cache_config = self.cache_config
    if not cache_config:
        # No cache configuration supplied: return the parent request unchanged.
        return request
    request.update(cache_config.config)
    return request
311+
260312

261313
class ProcessingStep(Step):
262314
"""Processing step for workflow."""
@@ -270,6 +322,7 @@ def __init__(
270322
job_arguments: List[str] = None,
271323
code: str = None,
272324
property_files: List[PropertyFile] = None,
325+
cache_config: CacheConfig = None,
273326
):
274327
"""Construct a ProcessingStep, given a `Processor` instance.
275328
@@ -289,6 +342,7 @@ def __init__(
289342
script to run. Defaults to `None`.
290343
property_files (List[PropertyFile]): A list of property files that workflow looks
291344
for and resolves from the configured processing output list.
345+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
292346
"""
293347
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING)
294348
self.processor = processor
@@ -305,6 +359,7 @@ def __init__(
305359
self._properties = Properties(
306360
path=f"Steps.{name}", shape_name="DescribeProcessingJobResponse"
307361
)
362+
self.cache_config = cache_config
308363

309364
@property
310365
def arguments(self) -> RequestType:
@@ -335,6 +390,8 @@ def properties(self):
335390
def to_request(self) -> RequestType:
336391
"""Get the request structure for workflow service calls."""
337392
request_dict = super(ProcessingStep, self).to_request()
393+
if self.cache_config:
394+
request_dict.update(self.cache_config.config)
338395
if self.property_files:
339396
request_dict["PropertyFiles"] = [
340397
property_file.expr for property_file in self.property_files

tests/integ/test_workflow.py

+9
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
CreateModelStep,
4747
ProcessingStep,
4848
TrainingStep,
49+
CacheConfig,
4950
)
5051
from sagemaker.workflow.step_collections import RegisterModel
5152
from sagemaker.workflow.pipeline import Pipeline
@@ -274,6 +275,8 @@ def test_one_step_sklearn_processing_pipeline(
274275
ProcessingInput(dataset_definition=athena_dataset_definition),
275276
]
276277

278+
cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
279+
277280
sklearn_processor = SKLearnProcessor(
278281
framework_version=sklearn_latest_version,
279282
role=role,
@@ -289,6 +292,7 @@ def test_one_step_sklearn_processing_pipeline(
289292
processor=sklearn_processor,
290293
inputs=inputs,
291294
code=script_path,
295+
cache_config=cache_config,
292296
)
293297
pipeline = Pipeline(
294298
name=pipeline_name,
@@ -328,6 +332,11 @@ def test_one_step_sklearn_processing_pipeline(
328332
response = execution.describe()
329333
assert response["PipelineArn"] == create_arn
330334

335+
# Check CacheConfig
336+
response = json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0]["CacheConfig"]
337+
assert response["Enabled"] == cache_config.enable_caching
338+
assert response["ExpireAfter"] == cache_config.expire_after
339+
331340
try:
332341
execution.wait(delay=30, max_attempts=3)
333342
except WaiterError:

tests/unit/sagemaker/workflow/test_steps.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
TrainingStep,
3838
TransformStep,
3939
CreateModelStep,
40+
CacheConfig,
4041
)
4142

4243
REGION = "us-west-2"
@@ -114,10 +115,9 @@ def test_training_step(sagemaker_session):
114115
sagemaker_session=sagemaker_session,
115116
)
116117
inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
118+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
117119
step = TrainingStep(
118-
name="MyTrainingStep",
119-
estimator=estimator,
120-
inputs=inputs,
120+
name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
121121
)
122122
assert step.to_request() == {
123123
"Name": "MyTrainingStep",
@@ -145,6 +145,7 @@ def test_training_step(sagemaker_session):
145145
"RoleArn": ROLE,
146146
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
147147
},
148+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
148149
}
149150
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
150151

@@ -163,11 +164,13 @@ def test_processing_step(sagemaker_session):
163164
destination="processing_manifest",
164165
)
165166
]
167+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
166168
step = ProcessingStep(
167169
name="MyProcessingStep",
168170
processor=processor,
169171
inputs=inputs,
170172
outputs=[],
173+
cache_config=cache_config,
171174
)
172175
assert step.to_request() == {
173176
"Name": "MyProcessingStep",
@@ -197,6 +200,7 @@ def test_processing_step(sagemaker_session):
197200
},
198201
"RoleArn": "DummyRole",
199202
},
203+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
200204
}
201205
assert step.properties.ProcessingJobName.expr == {
202206
"Get": "Steps.MyProcessingStep.ProcessingJobName"
@@ -238,10 +242,9 @@ def test_transform_step(sagemaker_session):
238242
sagemaker_session=sagemaker_session,
239243
)
240244
inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
245+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
241246
step = TransformStep(
242-
name="MyTransformStep",
243-
transformer=transformer,
244-
inputs=inputs,
247+
name="MyTransformStep", transformer=transformer, inputs=inputs, cache_config=cache_config
245248
)
246249
assert step.to_request() == {
247250
"Name": "MyTransformStep",
@@ -262,6 +265,7 @@ def test_transform_step(sagemaker_session):
262265
"InstanceType": "c4.4xlarge",
263266
},
264267
},
268+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
265269
}
266270
assert step.properties.TransformJobName.expr == {
267271
"Get": "Steps.MyTransformStep.TransformJobName"

0 commit comments

Comments
 (0)