
Commit d53aad6

Namrata Madan authored and committed
fix: fix TrainingStep cache misses due to timestamp based job name
1 parent e5198e3 commit d53aad6

File tree

6 files changed: +179 -10 lines

src/sagemaker/estimator.py

+1 -1
@@ -457,7 +457,7 @@ def __init__(
         self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
         self.code_location = code_location
         self.entry_point = entry_point
-        self.dependencies = dependencies
+        self.dependencies = dependencies or []
         self.uploaded_code = None
         self.tags = add_jumpstart_tags(
             tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
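
A note on the one-line change above: code added later in this commit concatenates estimator.dependencies onto a list of paths before hashing, which would fail if the attribute could still be None. A minimal standalone sketch (plain Python, no SageMaker imports) of the behavior the new default guarantees:

    # List concatenation raises with None but works with an empty list.
    dependencies = None
    try:
        ["train.py"] + dependencies
    except TypeError:
        pass  # what the old default (None) would have caused

    dependencies = None or []  # the new default stores [] instead
    assert ["train.py"] + dependencies == ["train.py"]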

src/sagemaker/workflow/steps.py

+31 -1
@@ -287,6 +287,16 @@ def __init__(
             )
             warnings.warn(msg)

+        self.job_name = None
+        if estimator.source_dir or estimator.entry_point:
+            # By default, `Estimator` will upload the local code to an S3 path
+            # containing a timestamp. This causes cache misses whenever a
+            # pipeline is updated, even if the underlying script hasn't changed.
+            # To avoid this, hash the contents of the training script and include it
+            # in the `job_name` passed to the `Estimator`, which will be used
+            # instead of the timestamped path.
+            self.job_name = self._generate_code_upload_path()
+
     @property
     def arguments(self) -> RequestType:
         """The arguments dictionary that is used to call `create_training_job`.
@@ -295,7 +305,7 @@ def arguments(self) -> RequestType:
         The `TrainingJobName` and `ExperimentConfig` attributes cannot be included.
         """

-        self.estimator._prepare_for_training()
+        self.estimator._prepare_for_training(self.job_name)
         train_args = _TrainingJob._get_train_args(
             self.estimator, self.inputs, experiment_config=dict()
         )
@@ -319,6 +329,26 @@ def to_request(self) -> RequestType:

         return request_dict

+    def _generate_code_upload_path(self) -> str or None:
+        """Generate an upload path for local training scripts based on their content."""
+        from sagemaker.workflow.utilities import hash_files_or_dirs
+
+        if self.estimator.source_dir:
+            source_dir_url = urlparse(self.estimator.source_dir)
+            if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
+                code_hash = hash_files_or_dirs(
+                    [self.estimator.source_dir] + self.estimator.dependencies
+                )
+                return f"{self.name}-{code_hash}"[:1024]
+        elif self.estimator.entry_point:
+            entry_point_url = urlparse(self.estimator.entry_point)
+            if entry_point_url.scheme == "" or entry_point_url.scheme == "file":
+                code_hash = hash_files_or_dirs(
+                    [self.estimator.entry_point] + self.estimator.dependencies
+                )
+                return f"{self.name}-{code_hash}"[:1024]
+        return None
+

 class CreateModelStep(ConfigurableRetryStep):
     """`CreateModelStep` for SageMaker Pipelines Workflows."""

src/sagemaker/workflow/utilities.py

+71 -5
@@ -13,8 +13,10 @@
 """Utilities to support workflow."""
 from __future__ import absolute_import

+from pathlib import Path
 from typing import List, Sequence, Union
 import hashlib
+from _hashlib import HASH as Hash
 from urllib.parse import unquote, urlparse

 from sagemaker.workflow.entities import (
@@ -23,6 +25,8 @@
 )
 from sagemaker.workflow.step_collections import StepCollection

+BUF_SIZE = 65536  # 64KiB
+

 def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]:
     """Get the request structure for list of entities.
@@ -49,15 +53,77 @@ def hash_file(path: str) -> str:
     Returns:
         str: The MD5 hash of the file.
     """
-    BUF_SIZE = 65536  # read in 64KiB chunks
+    return _hash_file(path, hashlib.md5()).hexdigest()
+
+
+def hash_files_or_dirs(paths: List[str]) -> str:
+    """Get the MD5 hash of the contents of a list of files or directories.
+
+    Hash is changed if:
+       * input list is changed
+       * new nested directories/files are added to any directory in the input list
+       * nested directory/file names are changed for any of the inputted directories
+       * content of files is edited
+
+    Args:
+        paths: List of file or directory paths
+    Returns:
+        str: The MD5 hash of the list of files or directories.
+    """
     md5 = hashlib.md5()
-    if path.lower().startswith("file://"):
+    for path in sorted(paths):
+        md5 = _hash_file_or_dir(path, md5)
+    return md5.hexdigest()
+
+
+def _hash_file_or_dir(path: str, md5: Hash) -> Hash:
+    """Updates the inputted Hash with the contents of the current path
+    Args:
+        path: path of file or directory
+    Returns:
+        str: The MD5 hash of the file or directory
+    """
+    if isinstance(path, str) and path.lower().startswith("file://"):
         path = unquote(urlparse(path).path)
-    with open(path, "rb") as f:
+    md5.update(path.encode())
+    if Path(path).is_dir():
+        md5 = _hash_dir(path, md5)
+    elif Path(path).is_file():
+        md5 = _hash_file(path, md5)
+    return md5
+
+
+def _hash_dir(directory: Union[str, Path], md5: Hash) -> Hash:
+    """Updates the inputted Hash with the contents of the current path
+    Args:
+        directory: path of the directory
+    Returns:
+        str: The MD5 hash of the directory
+    """
+    assert Path(directory).is_dir()
+    for path in sorted(Path(directory).iterdir()):
+        md5.update(path.name.encode())
+        if path.is_file():
+            md5 = _hash_file(path, md5)
+        elif path.is_dir():
+            md5 = _hash_dir(path, md5)
+    return md5
+
+
+def _hash_file(file: Union[str, Path], md5: Hash) -> Hash:
+    """Updates the inputted Hash with the contents of the current path
+    Args:
+        file: path of the file
+    Returns:
+        str: The MD5 hash of the file
+    """
+    if isinstance(file, str) and file.lower().startswith("file://"):
+        file = unquote(urlparse(file).path)
+    assert Path(file).is_file()
+    with open(file, "rb") as f:
         while True:
             data = f.read(BUF_SIZE)
             if not data:
                 break
             md5.update(data)
-
-    return md5.hexdigest()
+    return md5
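
As the hash_files_or_dirs docstring above states, the helper folds the sorted input paths, nested file names, and file contents into a single MD5, so the digest is stable across repeated calls and changes only when the listed code changes. A quick illustrative call, assuming a local src/ directory and a requirements.txt file exist:

    from sagemaker.workflow.utilities import hash_files_or_dirs

    # Assumed local paths -- any mix of files and directories is accepted.
    digest = hash_files_or_dirs(["src", "requirements.txt"])
    print(digest)  # 32-char hex MD5; identical on a second call if nothing changed

TrainingStep._generate_code_upload_path embeds this digest in the job name via f"{self.name}-{code_hash}"[:1024], which is what replaces the timestamped upload prefix.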

tests/unit/sagemaker/workflow/test_utilities.py

+70 -1
@@ -14,7 +14,8 @@
 from __future__ import absolute_import

 import tempfile
-from sagemaker.workflow.utilities import hash_file
+from sagemaker.workflow.utilities import hash_file, hash_files_or_dirs
+from pathlib import Path


 def test_hash_file():
@@ -29,3 +30,71 @@ def test_hash_file_uri():
         tmp.write("hashme".encode())
         hash = hash_file(f"file:///{tmp.name}")
         assert hash == "d41d8cd98f00b204e9800998ecf8427e"
+
+
+def test_hash_files_or_dirs_with_file():
+    with tempfile.NamedTemporaryFile() as tmp:
+        tmp.write("hashme".encode())
+        hash1 = hash_files_or_dirs([f"file:///{tmp.name}"])
+        # compute hash again with no change to file
+        hash2 = hash_files_or_dirs([f"file:///{tmp.name}"])
+        assert hash1 == hash2
+
+
+def test_hash_files_or_dirs_with_directory():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        temp_dir = Path(tmpdirname)
+        file_name = temp_dir / "test.txt"
+        file_name.write_text("foo bar")
+        hash1 = hash_files_or_dirs([tmpdirname])
+        # compute hash again with no change to directory
+        hash2 = hash_files_or_dirs([tmpdirname])
+        assert hash1 == hash2
+
+
+def test_hash_files_or_dirs_change_file_content():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        temp_dir = Path(tmpdirname)
+        file_name = temp_dir / "test.txt"
+        file_name.write_text("foo bar")
+        hash1 = hash_files_or_dirs([tmpdirname])
+        # change file content
+        file_name.write_text("new text")
+        hash2 = hash_files_or_dirs([tmpdirname])
+        assert hash1 != hash2
+
+
+def test_hash_files_or_dirs_rename_file():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        temp_dir = Path(tmpdirname)
+        file_name = temp_dir / "test.txt"
+        file_name.write_text("foo bar")
+        hash1 = hash_files_or_dirs([tmpdirname])
+        # rename file
+        file_name.rename(temp_dir / "test1.txt")
+        hash2 = hash_files_or_dirs([tmpdirname])
+        assert hash1 != hash2
+        # rename it back
+
+
+def test_hash_files_or_dirs_add_new_file():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        temp_dir = Path(tmpdirname)
+        file_name = temp_dir / "test.txt"
+        file_name.write_text("foo bar")
+        hash1 = hash_files_or_dirs([tmpdirname])
+        # add new file
+        file_name2 = temp_dir / "test2.txt"
+        file_name2.write_text("test test")
+        hash2 = hash_files_or_dirs([tmpdirname])
+        assert hash1 != hash2
+
+
+def test_hash_files_or_dirs_unsorted_input_list():
+    with tempfile.NamedTemporaryFile() as tmp1:
+        tmp1.write("hashme".encode())
+        with tempfile.NamedTemporaryFile() as tmp2:
+            tmp2.write("hashme".encode())
+            hash1 = hash_files_or_dirs([tmp1.name, tmp2.name])
+            hash2 = hash_files_or_dirs([tmp2.name, tmp1.name])
+            assert hash1 == hash2

tests/unit/sagemaker/workflow/test_utils.py

+4 -0
@@ -120,6 +120,10 @@ def test_repack_model_step(estimator):
     assert hyperparameters["inference_script"] == '"dummy_script.py"'
     assert hyperparameters["model_archive"] == '"s3://my-bucket/model.tar.gz"'
     assert hyperparameters["sagemaker_program"] == '"_repack_model.py"'
+    assert (
+        hyperparameters["sagemaker_submit_directory"]
+        == '"s3://my-bucket/MyRepackModelStep-1be10316814854973ed1b445db3ef84e/source/sourcedir.tar.gz"'
+    )

     del request_dict["Arguments"]["HyperParameters"]
     del request_dict["Arguments"]["AlgorithmSpecification"]["TrainingImage"]

tests/unit/test_estimator.py

+2 -2
@@ -1598,7 +1598,7 @@ def test_git_support_with_branch_and_commit_succeed(git_clone_repo, sagemaker_se
     git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir=None, dependencies=None: {
         "entry_point": "/tmp/repo_dir/entry_point",
         "source_dir": None,
-        "dependencies": None,
+        "dependencies": [],
     }
     git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
     entry_point = "entry_point"
@@ -3448,7 +3448,7 @@ def test_git_support_with_branch_and_commit_succeed_estimator_class(
         image_uri=IMAGE_URI,
     )
     fw.fit()
-    git_clone_repo.assert_called_once_with(git_config, entry_point, None, None)
+    git_clone_repo.assert_called_once_with(git_config, entry_point, None, [])


 @patch("sagemaker.estimator.Estimator._stage_user_code_in_s3")
