diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5d1e2a2b92..56a00478d3 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -478,7 +478,7 @@ def _script_mode_env_vars(self): dir_name = None if self.uploaded_code: script_name = self.uploaded_code.script_name - if self.enable_network_isolation(): + if self.repacked_model_data or self.enable_network_isolation(): dir_name = "/opt/ml/model/code" else: dir_name = self.uploaded_code.s3_prefix diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index b66fda908d..38aa894c10 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -665,3 +665,25 @@ def test_all_framework_models_add_jumpstart_base_name( sagemaker_session.create_model.reset_mock() sagemaker_session.endpoint_from_production_variants.reset_mock() + + +@patch("sagemaker.utils.repack_model") +def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagemaker_session): + + source_dir = "s3://blah/blah/blah" + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert ( + sagemaker_session.create_model.call_args_list[0][0][2]["Environment"][ + "SAGEMAKER_SUBMIT_DIRECTORY" + ] + == "/opt/ml/model/code" + )