Commit 636f74b (parent: 2beb91e)
Author: Payton Staub

fix: Set ProcessingStep upload locations deterministically to avoid cache misses on pipeline upsert. Add a warning to cache-enabled TrainingSteps with profiling enabled
File tree: 3 files changed, +248 −26 lines
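Background for the fix (an illustrative sketch, not code from this commit): SageMaker Pipelines decides cache hits by matching a step's request arguments against prior executions, so any argument that embeds a timestamp changes on every upsert and defeats caching even when nothing meaningful changed. The helper names below are hypothetical.

import hashlib
import time


def timestamped_code_key(step_name: str) -> str:
    # Old behavior: the upload path embeds the current time, so the step's
    # arguments differ on every pipeline upsert and the cache always misses.
    return f"{step_name}-{int(time.time())}/code.py"


def content_addressed_code_key(step_name: str, script_path: str) -> str:
    # Fixed behavior: the path is derived from the script's bytes, so it stays
    # stable across upserts until the script itself changes.
    with open(script_path, "rb") as f:
        digest = hashlib.md5(f.read()).hexdigest()
    return f"{step_name}-{digest}/code.py"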

src/sagemaker/workflow/steps.py (+32)
@@ -14,8 +14,10 @@
 from __future__ import absolute_import

 import abc
+import warnings
 from enum import Enum
 from typing import Dict, List, Union
+from urllib.parse import urlparse

 import attr

@@ -270,6 +272,16 @@ def __init__(
         )
         self.cache_config = cache_config

+        if self.cache_config is not None and not self.estimator.disable_profiler:
+            msg = (
+                "Profiling is enabled on the provided estimator. "
+                "The default profiler rule includes a timestamp "
+                "which will change each time the pipeline is "
+                "upserted, causing cache misses. If profiling "
+                "is not needed, set disable_profiler to True on the estimator."
+            )
+            warnings.warn(msg)
+
     @property
     def arguments(self) -> RequestType:
         """The arguments dict that is used to call `create_training_job`.

@@ -498,6 +510,7 @@ def __init__(
         self.job_arguments = job_arguments
         self.code = code
         self.property_files = property_files
+        self.job_name = None

         # Examine why run method in sagemaker.processing.Processor mutates the processor instance
         # by setting the instance's arguments attribute. Refactor Processor.run, if possible.

@@ -508,6 +521,17 @@ def __init__(
         )
         self.cache_config = cache_config

+        if code:
+            code_url = urlparse(code)
+            if code_url.scheme == "" or code_url.scheme == "file":
+                # By default, Processor will upload the local code to an S3 path
+                # containing a timestamp. This causes cache misses whenever a
+                # pipeline is updated, even if the underlying script hasn't changed.
+                # To avoid this, hash the contents of the script and include it
+                # in the job_name passed to the Processor, which will be used
+                # instead of the timestamped path.
+                self.job_name = self._generate_code_upload_path()
+
     @property
     def arguments(self) -> RequestType:
         """The arguments dict that is used to call `create_processing_job`.

@@ -516,6 +540,7 @@ def arguments(self) -> RequestType:
         ProcessingJobName and ExperimentConfig cannot be included in the arguments.
         """
         normalized_inputs, normalized_outputs = self.processor._normalize_args(
+            job_name=self.job_name,
             arguments=self.job_arguments,
             inputs=self.inputs,
             outputs=self.outputs,

@@ -546,6 +571,13 @@ def to_request(self) -> RequestType:
         ]
         return request_dict

+    def _generate_code_upload_path(self) -> str:
+        """Generate an upload path for a local processing script based on its contents."""
+        from sagemaker.workflow.utilities import hash_file
+
+        code_hash = hash_file(self.code)
+        return f"{self.name}-{code_hash}"[:1024]
+

 class TuningStep(ConfigurableRetryStep):
     """Tuning step for workflow."""

src/sagemaker/workflow/utilities.py (+21)
@@ -14,6 +14,7 @@
 from __future__ import absolute_import

 from typing import List, Sequence, Union
+import hashlib

 from sagemaker.workflow.entities import (
     Entity,

@@ -37,3 +38,23 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[R
     elif isinstance(entity, StepCollection):
         request_dicts.extend(entity.request_dicts())
     return request_dicts
+
+
+def hash_file(path: str) -> str:
+    """Get the MD5 hash of a file.
+
+    Args:
+        path (str): The local path for the file.
+    Returns:
+        str: The MD5 hash of the file.
+    """
+    BUF_SIZE = 65536  # read in 64KiB chunks
+    md5 = hashlib.md5()
+    with open(path, "rb") as f:
+        while True:
+            data = f.read(BUF_SIZE)
+            if not data:
+                break
+            md5.update(data)
+
+    return md5.hexdigest()
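hash_file reads in fixed-size chunks, so large scripts are hashed without being loaded fully into memory, and the digest is deterministic: the same bytes always produce the same result. A small usage check (assumes a local script.py exists):

from sagemaker.workflow.utilities import hash_file

first = hash_file("script.py")
second = hash_file("script.py")
assert first == second  # same bytes, same digest -- unlike a timestamped name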

tests/unit/sagemaker/workflow/test_steps.py (+195 −26)
@@ -16,6 +16,7 @@
 import pytest
 import sagemaker
 import os
+import warnings

 from mock import (
     Mock,

@@ -63,8 +64,7 @@
 )
 from tests.unit import DATA_DIR

-SCRIPT_FILE = "dummy_script.py"
-SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
+DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")

 REGION = "us-west-2"
 BUCKET = "my-bucket"

@@ -129,6 +129,31 @@ def sagemaker_session(boto_session, client):
     )


+@pytest.fixture
+def script_processor(sagemaker_session):
+    return ScriptProcessor(
+        role=ROLE,
+        image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
+        command=["python3"],
+        instance_type="ml.m4.xlarge",
+        instance_count=1,
+        volume_size_in_gb=100,
+        volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
+        output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
+        max_runtime_in_seconds=3600,
+        base_job_name="my_sklearn_processor",
+        env={"my_env_variable": "my_env_variable_value"},
+        tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
+        network_config=NetworkConfig(
+            subnets=["my_subnet_id"],
+            security_group_ids=["my_security_group_id"],
+            enable_network_isolation=True,
+            encrypt_inter_container_traffic=True,
+        ),
+        sagemaker_session=sagemaker_session,
+    )
+
+
 def test_custom_step():
     step = CustomStep(
         name="MyStep", display_name="CustomStepDisplayName", description="CustomStepDescription"

@@ -326,7 +351,7 @@ def test_training_step_tensorflow(sagemaker_session):
     training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
     training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500)
     estimator = TensorFlow(
-        entry_point=os.path.join(DATA_DIR, SCRIPT_FILE),
+        entry_point=DUMMY_SCRIPT_PATH,
         role=ROLE,
         model_dir=False,
         image_uri=IMAGE_URI,

@@ -403,6 +428,101 @@ def test_training_step_tensorflow(sagemaker_session):
     assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}


+def test_training_step_profiler_warning(sagemaker_session):
+    estimator = TensorFlow(
+        entry_point=DUMMY_SCRIPT_PATH,
+        role=ROLE,
+        model_dir=False,
+        image_uri=IMAGE_URI,
+        source_dir="s3://mybucket/source",
+        framework_version="2.4.1",
+        py_version="py37",
+        disable_profiler=False,
+        instance_count=1,
+        instance_type="ml.p3.16xlarge",
+        sagemaker_session=sagemaker_session,
+        hyperparameters={
+            "batch-size": 500,
+            "epochs": 5,
+        },
+        debugger_hook_config=False,
+        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+    )
+
+    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
+    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
+    with warnings.catch_warnings(record=True) as w:
+        TrainingStep(
+            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
+        )
+        assert len(w) == 1
+        assert issubclass(w[-1].category, UserWarning)
+        assert "Profiling is enabled on the provided estimator" in str(w[-1].message)
+
+
+def test_training_step_no_profiler_warning(sagemaker_session):
+    estimator = TensorFlow(
+        entry_point=DUMMY_SCRIPT_PATH,
+        role=ROLE,
+        model_dir=False,
+        image_uri=IMAGE_URI,
+        source_dir="s3://mybucket/source",
+        framework_version="2.4.1",
+        py_version="py37",
+        disable_profiler=True,
+        instance_count=1,
+        instance_type="ml.p3.16xlarge",
+        sagemaker_session=sagemaker_session,
+        hyperparameters={
+            "batch-size": 500,
+            "epochs": 5,
+        },
+        debugger_hook_config=False,
+        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+    )
+
+    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
+    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
+    with warnings.catch_warnings(record=True) as w:
+        # profiler disabled, cache config not None
+        TrainingStep(
+            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
+        )
+        assert len(w) == 0
+
+    with warnings.catch_warnings(record=True) as w:
+        # profiler enabled, cache config is None
+        estimator.disable_profiler = False
+        TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=None)
+        assert len(w) == 0
+
+
+def test_training_step_profiler_not_explicitly_enabled(sagemaker_session):
+    estimator = TensorFlow(
+        entry_point=DUMMY_SCRIPT_PATH,
+        role=ROLE,
+        model_dir=False,
+        image_uri=IMAGE_URI,
+        source_dir="s3://mybucket/source",
+        framework_version="2.4.1",
+        py_version="py37",
+        instance_count=1,
+        instance_type="ml.p3.16xlarge",
+        sagemaker_session=sagemaker_session,
+        hyperparameters={
+            "batch-size": 500,
+            "epochs": 5,
+        },
+        debugger_hook_config=False,
+        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+    )
+
+    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
+    step = TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs)
+    step_request = step.to_request()
+    assert step_request["Arguments"]["ProfilerRuleConfigurations"] is None
+
+
 def test_processing_step(sagemaker_session):
     processing_input_data_uri_parameter = ParameterString(
         name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest"

@@ -473,28 +593,42 @@ def test_processing_step(sagemaker_session):


 @patch("sagemaker.processing.ScriptProcessor._normalize_args")
-def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session):
-    processor = ScriptProcessor(
-        role=ROLE,
-        image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
-        command=["python3"],
-        instance_type="ml.m4.xlarge",
-        instance_count=1,
-        volume_size_in_gb=100,
-        volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
-        output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
-        max_runtime_in_seconds=3600,
-        base_job_name="my_sklearn_processor",
-        env={"my_env_variable": "my_env_variable_value"},
-        tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
-        network_config=NetworkConfig(
-            subnets=["my_subnet_id"],
-            security_group_ids=["my_security_group_id"],
-            enable_network_isolation=True,
-            encrypt_inter_container_traffic=True,
-        ),
-        sagemaker_session=sagemaker_session,
+def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, script_processor):
+    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
+    inputs = [
+        ProcessingInput(
+            source=f"s3://{BUCKET}/processing_manifest",
+            destination="processing_manifest",
+        )
+    ]
+    outputs = [
+        ProcessingOutput(
+            source=f"s3://{BUCKET}/processing_manifest",
+            destination="processing_manifest",
+        )
+    ]
+    step = ProcessingStep(
+        name="MyProcessingStep",
+        processor=script_processor,
+        code=DUMMY_SCRIPT_PATH,
+        inputs=inputs,
+        outputs=outputs,
+        job_arguments=["arg1", "arg2"],
+        cache_config=cache_config,
+    )
+    mock_normalize_args.return_value = [step.inputs, step.outputs]
+    step.to_request()
+    mock_normalize_args.assert_called_with(
+        job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db",
+        arguments=step.job_arguments,
+        inputs=step.inputs,
+        outputs=step.outputs,
+        code=step.code,
     )
+
+
+@patch("sagemaker.processing.ScriptProcessor._normalize_args")
+def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, script_processor):
     cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
     inputs = [
         ProcessingInput(

@@ -510,8 +644,8 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
     ]
     step = ProcessingStep(
         name="MyProcessingStep",
-        processor=processor,
-        code="foo.py",
+        processor=script_processor,
+        code="s3://foo",
         inputs=inputs,
         outputs=outputs,
         job_arguments=["arg1", "arg2"],

@@ -520,13 +654,48 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
     mock_normalize_args.return_value = [step.inputs, step.outputs]
     step.to_request()
     mock_normalize_args.assert_called_with(
+        job_name=None,
         arguments=step.job_arguments,
         inputs=step.inputs,
         outputs=step.outputs,
         code=step.code,
     )


+@patch("sagemaker.processing.ScriptProcessor._normalize_args")
+def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, script_processor):
+    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
+    inputs = [
+        ProcessingInput(
+            source=f"s3://{BUCKET}/processing_manifest",
+            destination="processing_manifest",
+        )
+    ]
+    outputs = [
+        ProcessingOutput(
+            source=f"s3://{BUCKET}/processing_manifest",
+            destination="processing_manifest",
+        )
+    ]
+    step = ProcessingStep(
+        name="MyProcessingStep",
+        processor=script_processor,
+        inputs=inputs,
+        outputs=outputs,
+        job_arguments=["arg1", "arg2"],
+        cache_config=cache_config,
+    )
+    mock_normalize_args.return_value = [step.inputs, step.outputs]
+    step.to_request()
+    mock_normalize_args.assert_called_with(
+        job_name=None,
+        arguments=step.job_arguments,
+        inputs=step.inputs,
+        outputs=step.outputs,
+        code=None,
+    )
+
+
 def test_create_model_step(sagemaker_session):
     model = Model(
         image_uri=IMAGE_URI,
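For reference, the literal job_name asserted in test_processing_step_normalizes_args_with_local_code can be reproduced with hashlib alone (a sketch; the path to the test's dummy script is assumed):

import hashlib

with open("tests/data/dummy_script.py", "rb") as f:  # assumed path behind DUMMY_SCRIPT_PATH
    digest = hashlib.md5(f.read()).hexdigest()
print(f"MyProcessingStep-{digest}")
# expected: MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db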
