Skip to content

Commit 48ca435

Browse files
committed
Fixes container env generation for S3 URI, update test
1 parent b9f90dc commit 48ca435

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/sagemaker/model.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
466466
)
467467

468468
def _script_mode_env_vars(self):
469-
"""Placeholder docstring"""
469+
"""Returns a mapping of environment variables for script mode execution"""
470470
script_name = None
471471
dir_name = None
472472
if self.uploaded_code:
@@ -476,9 +476,12 @@ def _script_mode_env_vars(self):
476476
else:
477477
dir_name = self.uploaded_code.s3_prefix
478478
elif self.entry_point is not None:
479-
script_name = self.entry_point
480479
if self.source_dir is not None:
481-
dir_name = "file://" + self.source_dir
480+
dir_name = (
481+
self.source_dir
482+
if self.source_dir.startswith("s3://")
483+
else "file://" + self.source_dir
484+
)
482485

483486
return {
484487
SCRIPT_PARAM_NAME.upper(): script_name or str(),

tests/unit/sagemaker/model/test_framework_model.py

+18
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MODEL_IMAGE = "mi"
2626
ENTRY_POINT = "blah.py"
2727
ROLE = "some-role"
28+
S3_SOURCE_DIR = "s3://somebucket/sourcedir.tar.gz"
2829

2930
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
3031
SCRIPT_NAME = "dummy_script.py"
@@ -111,6 +112,23 @@ def test_prepare_container_def(time, sagemaker_session):
111112
"ModelDataUrl": MODEL_DATA,
112113
}
113114

115+
@patch("shutil.rmtree", MagicMock())
116+
@patch("tarfile.open", MagicMock())
117+
@patch("os.listdir", MagicMock(return_value=["blah.py"]))
118+
@patch("time.strftime", return_value=TIMESTAMP)
119+
def test_prepare_container_def_s3_src(time, sagemaker_session):
120+
model = DummyFrameworkModel(sagemaker_session, source_dir=S3_SOURCE_DIR)
121+
assert model.prepare_container_def(INSTANCE_TYPE) == {
122+
"Environment": {
123+
"SAGEMAKER_PROGRAM": ENTRY_POINT,
124+
"SAGEMAKER_SUBMIT_DIRECTORY": "s3://somebucket/sourcedir.tar.gz",
125+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
126+
"SAGEMAKER_REGION": REGION,
127+
},
128+
"Image": MODEL_IMAGE,
129+
"ModelDataUrl": MODEL_DATA,
130+
}
131+
S3_SOURCE_DIR
114132

115133
@patch("shutil.rmtree", MagicMock())
116134
@patch("tarfile.open", MagicMock())

0 commit comments

Comments
 (0)