From b9df7e69b8f4d42930838750992122ecfc4a4e2b Mon Sep 17 00:00:00 2001 From: martinRenou Date: Tue, 10 Oct 2023 11:50:36 +0200 Subject: [PATCH] feature: Accept user-defined env variables for the entry-point --- src/sagemaker/model.py | 8 ++--- .../test_huggingface_pytorch_compiler.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ff340b58e9..5a2b27c54d 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: def _script_mode_env_vars(self): """Returns a mapping of environment variables for script mode execution""" - script_name = None - dir_name = None + script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "") + dir_name = self.env.get(DIR_PARAM_NAME.upper(), "") if self.uploaded_code: script_name = self.uploaded_code.script_name if self.repacked_model_data or self.enable_network_isolation(): @@ -783,8 +783,8 @@ def _script_mode_env_vars(self): else "file://" + self.source_dir ) return { - SCRIPT_PARAM_NAME.upper(): script_name or str(), - DIR_PARAM_NAME.upper(): dir_name or str(), + SCRIPT_PARAM_NAME.upper(): script_name, + DIR_PARAM_NAME.upper(): dir_name, CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 2b59113354..12162f799f 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -718,3 +718,32 @@ def test_register_hf_pytorch_model_auto_infer_framework( sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request ) + + +def test_accept_user_defined_environment_variables( + sagemaker_session, + huggingface_training_compiler_version, + huggingface_training_compiler_pytorch_version, + huggingface_training_compiler_pytorch_py_version, +): + program = "inference.py" + directory = "/opt/ml/model/code" + + hf_model = HuggingFaceModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + transformers_version=huggingface_training_compiler_version, + pytorch_version=huggingface_training_compiler_pytorch_version, + py_version=huggingface_training_compiler_pytorch_py_version, + sagemaker_session=sagemaker_session, + env={ + "SAGEMAKER_PROGRAM": program, + "SAGEMAKER_SUBMIT_DIRECTORY": directory, + }, + image_uri="fakeimage", + ) + + container_env = hf_model.prepare_container_def("ml.m4.xlarge")["Environment"] + + assert container_env["SAGEMAKER_PROGRAM"] == program + assert container_env["SAGEMAKER_SUBMIT_DIRECTORY"] == directory