Skip to content

Commit 29a7dbc

Browse files
mufaddal-rohawalajerrypeng7773
authored andcommitted
fix: repack model locally when local_code local mode (aws#3094)
1 parent b1a69b7 commit 29a7dbc

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

src/sagemaker/model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,14 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
456456
if is_pipeline_variable(self.model_data):
457457
# model is not yet there, defer repacking to later during pipeline execution
458458
return
459-
460-
bucket = self.bucket or self.sagemaker_session.default_bucket()
461-
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
459+
if local_code and self.model_data.startswith("file://"):
460+
repacked_model_data = self.model_data
461+
else:
462+
bucket = self.bucket or self.sagemaker_session.default_bucket()
463+
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
464+
self.uploaded_code = fw_utils.UploadedCode(
465+
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
466+
)
462467

463468
utils.repack_model(
464469
inference_script=self.entry_point,
@@ -471,9 +476,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
471476
)
472477

473478
self.repacked_model_data = repacked_model_data
474-
self.uploaded_code = fw_utils.UploadedCode(
475-
s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point)
476-
)
477479

478480
def _script_mode_env_vars(self):
479481
"""Returns a mapping of environment variables for script mode execution"""

tests/integ/test_local_mode.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import tests.integ.lock as lock
2626
from tests.integ import DATA_DIR
27+
from mock import Mock, ANY
2728

2829
from sagemaker import image_uris
2930

@@ -221,6 +222,13 @@ def test_mxnet_local_data_local_script(
221222
):
222223
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
223224
script_path = os.path.join(data_path, "mnist.py")
225+
local_no_s3_session = LocalNoS3Session()
226+
local_no_s3_session.boto_session.resource = Mock(
227+
side_effect=local_no_s3_session.boto_session.resource
228+
)
229+
local_no_s3_session.boto_session.client = Mock(
230+
side_effect=local_no_s3_session.boto_session.client
231+
)
224232

225233
mx = MXNet(
226234
entry_point=script_path,
@@ -229,7 +237,7 @@ def test_mxnet_local_data_local_script(
229237
instance_type="local",
230238
framework_version=mxnet_training_latest_version,
231239
py_version=mxnet_training_latest_py_version,
232-
sagemaker_session=LocalNoS3Session(),
240+
sagemaker_session=local_no_s3_session,
233241
)
234242

235243
train_input = "file://" + os.path.join(data_path, "train")
@@ -243,6 +251,11 @@ def test_mxnet_local_data_local_script(
243251
predictor = mx.deploy(1, "local", endpoint_name=endpoint_name)
244252
data = numpy.zeros(shape=(1, 1, 28, 28))
245253
predictor.predict(data)
254+
# check if no boto_session s3 calls were made
255+
with pytest.raises(AssertionError):
256+
local_no_s3_session.boto_session.resource.assert_called_with("s3", region_name=ANY)
257+
with pytest.raises(AssertionError):
258+
local_no_s3_session.boto_session.client.assert_called_with("s3", region_name=ANY)
246259
finally:
247260
predictor.delete_endpoint()
248261

0 commit comments

Comments
 (0)