diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 4f4ae6a3dd..d8ad818ce5 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -454,9 +454,14 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: if is_pipeline_variable(self.model_data): # model is not yet there, defer repacking to later during pipeline execution return - - bucket = self.bucket or self.sagemaker_session.default_bucket() - repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) + if local_code and self.model_data.startswith("file://"): + repacked_model_data = self.model_data + else: + bucket = self.bucket or self.sagemaker_session.default_bucket() + repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) + self.uploaded_code = fw_utils.UploadedCode( + s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point) + ) utils.repack_model( inference_script=self.entry_point, @@ -469,9 +474,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: ) self.repacked_model_data = repacked_model_data - self.uploaded_code = fw_utils.UploadedCode( - s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point) - ) def _script_mode_env_vars(self): """Returns a mapping of environment variables for script mode execution""" diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index af80be73f7..0c3bf140c3 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -24,6 +24,7 @@ import tests.integ.lock as lock from tests.integ import DATA_DIR +from mock import Mock, ANY from sagemaker import image_uris @@ -221,6 +222,13 @@ def test_mxnet_local_data_local_script( ): data_path = os.path.join(DATA_DIR, "mxnet_mnist") script_path = os.path.join(data_path, "mnist.py") + local_no_s3_session = LocalNoS3Session() + local_no_s3_session.boto_session.resource = Mock( + side_effect=local_no_s3_session.boto_session.resource + ) + local_no_s3_session.boto_session.client = Mock( + side_effect=local_no_s3_session.boto_session.client + ) mx = MXNet( entry_point=script_path, @@ -229,7 +237,7 @@ def test_mxnet_local_data_local_script( instance_type="local", framework_version=mxnet_training_latest_version, py_version=mxnet_training_latest_py_version, - sagemaker_session=LocalNoS3Session(), + sagemaker_session=local_no_s3_session, ) train_input = "file://" + os.path.join(data_path, "train") @@ -243,6 +251,11 @@ def test_mxnet_local_data_local_script( predictor = mx.deploy(1, "local", endpoint_name=endpoint_name) data = numpy.zeros(shape=(1, 1, 28, 28)) predictor.predict(data) + # check if no boto_session s3 calls were made + with pytest.raises(AssertionError): + local_no_s3_session.boto_session.resource.assert_called_with("s3", region_name=ANY) + with pytest.raises(AssertionError): + local_no_s3_session.boto_session.client.assert_called_with("s3", region_name=ANY) finally: predictor.delete_endpoint()