Skip to content

Commit 5d8ebfc

Browse files
staubhp (Payton Staub)
and
Payton Staub
authored
fix: Refactor repack_model script injection, fixes tar.gz error (#3039)
Co-authored-by: Payton Staub <[email protected]>
1 parent 8e0a37f commit 5d8ebfc

File tree

2 files changed

+76
-28
lines changed

2 files changed

+76
-28
lines changed

src/sagemaker/workflow/_utils.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@
2020
from typing import List, Union
2121
from sagemaker import image_uris
2222
from sagemaker.inputs import TrainingInput
23-
from sagemaker.s3 import (
24-
S3Downloader,
25-
S3Uploader,
26-
)
2723
from sagemaker.estimator import EstimatorBase
2824
from sagemaker.sklearn.estimator import SKLearn
2925
from sagemaker.workflow.entities import RequestType
@@ -35,6 +31,7 @@
3531
Step,
3632
ConfigurableRetryStep,
3733
)
34+
from sagemaker.utils import _save_model, download_file_from_url
3835
from sagemaker.workflow.retry import RetryPolicy
3936

4037
FRAMEWORK_VERSION = "0.23-1"
@@ -203,40 +200,36 @@ def _establish_source_dir(self):
203200
self._entry_point = self._entry_point_basename
204201

205202
def _inject_repack_script(self):
    """Place the _repack_model.py script next to the user's source code.

    Local ``source_dir``:
        The script is copied straight into that directory.

    S3 ``source_dir`` (a URI beginning with ``s3://``, case-insensitive):
        1) the existing tar.gz is downloaded into a temporary directory
        2) the archive is extracted
        3) the _repack_model.py script is copied into the extracted tree
        4) the tree is packed into a new tar.gz
        5) the new tar.gz is written back over the S3 source_dir
    """
    repack_script_path = os.path.join(os.path.dirname(__file__), REPACK_SCRIPT)

    # Local source directories only need a plain file copy.
    if not self._source_dir.lower().startswith("s3://"):
        shutil.copy2(repack_script_path, os.path.join(self._source_dir, REPACK_SCRIPT))
        return

    # Everything below lives in a scratch directory that is removed on exit.
    with tempfile.TemporaryDirectory() as workdir:
        unpacked_dir = os.path.join(workdir, "extracted")
        downloaded_archive = os.path.join(workdir, "old.tar.gz")

        download_file_from_url(self._source_dir, downloaded_archive, self.sagemaker_session)

        with tarfile.open(name=downloaded_archive, mode="r:gz") as archive:
            archive.extractall(path=unpacked_dir)

        shutil.copy2(repack_script_path, os.path.join(unpacked_dir, REPACK_SCRIPT))

        # A separate output file keeps the downloaded archive untouched
        # while the replacement is being written.
        repacked_archive = os.path.join(workdir, "new.tar.gz")
        with tarfile.open(repacked_archive, mode="w:gz") as archive:
            archive.add(unpacked_dir, arcname=os.path.sep)

        _save_model(self._source_dir, repacked_archive, self.sagemaker_session, kms_key=None)
242235

tests/unit/sagemaker/workflow/test_utils.py

+55
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.estimator import Estimator
2929
from sagemaker.workflow import Properties
3030
from sagemaker.workflow._utils import _RepackModelStep
31+
from tests.unit.test_utils import FakeS3, list_tar_files
3132
from tests.unit import DATA_DIR
3233

3334
REGION = "us-west-2"
@@ -210,3 +211,57 @@ def test_repack_model_step_with_source_dir(estimator, source_dir):
210211
assert step.properties.TrainingJobName.expr == {
211212
"Get": "Steps.MyRepackModelStep.TrainingJobName"
212213
}
214+
215+
216+
@pytest.fixture()
def tmp(tmpdir):
    """Yield pytest's ``tmpdir`` fixture value converted to a plain string."""
    tmp_dir_as_str = str(tmpdir)
    yield tmp_dir_as_str
219+
220+
221+
@pytest.fixture()
def fake_s3(tmp):
    """Build a ``FakeS3`` test double from the ``tmp`` fixture's directory string."""
    s3_double = FakeS3(tmp)
    return s3_double
224+
225+
226+
def test_inject_repack_script_s3(estimator, tmp, fake_s3):
    """Check that an S3 source_dir archive gains the repack script after injection."""
    # Lay out a small source tree on disk that will become the "S3" tarball.
    source_files = [
        "model-dir/aa",
        "model-dir/foo/inference.py",
    ]
    create_file_tree(tmp, source_files)

    model_data = Properties(path="Steps.MyStep", shape_name="DescribeModelOutput")
    s3_source_uri = "s3://fake/location"
    repack_step = _RepackModelStep(
        name="MyRepackModelStep",
        sagemaker_session=fake_s3.sagemaker_session,
        role=estimator.role,
        image_uri="foo",
        model_data=model_data,
        entry_point="inference.py",
        source_dir=s3_source_uri,
    )

    # Publish the tree to the fake bucket, then run the injection under test.
    fake_s3.tar_and_upload("model-dir", "s3://fake/location")
    repack_step._inject_repack_script()

    # The re-uploaded archive must hold the original files plus the script.
    expected_members = {
        "/aa",
        "/foo/inference.py",
        "/_repack_model.py",
    }
    assert list_tar_files(fake_s3.fake_upload_path, tmp) == expected_members
258+
259+
260+
def create_file_tree(root, tree):
    """Create files under ``root``, making parent directories as needed.

    Each file is opened in append mode and its own relative path is written
    to it, so calling this again for the same entry appends the path a second
    time instead of truncating the file.

    Args:
        root: Directory under which the files are created.
        tree: Iterable of file paths relative to ``root``.
    """
    for relative_path in tree:
        parent_dir = os.path.join(root, os.path.dirname(relative_path))
        # exist_ok tolerates directories that already exist without hiding
        # unrelated failures (or KeyboardInterrupt) as a bare except would.
        # The guard skips the call when there is no directory part at all.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(os.path.join(root, relative_path), "a") as f:
            f.write(relative_path)

0 commit comments

Comments
 (0)