Skip to content

Commit 69f9fa7

Browse files
martinRenouroot
authored and
root
committed
feature: Accept user-defined env variables for the entry-point (aws#4175)
1 parent 605a3f4 commit 69f9fa7

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

src/sagemaker/model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
766766

767767
def _script_mode_env_vars(self):
768768
"""Returns a mapping of environment variables for script mode execution"""
769-
script_name = None
770-
dir_name = None
769+
script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "")
770+
dir_name = self.env.get(DIR_PARAM_NAME.upper(), "")
771771
if self.uploaded_code:
772772
script_name = self.uploaded_code.script_name
773773
if self.repacked_model_data or self.enable_network_isolation():
@@ -783,8 +783,8 @@ def _script_mode_env_vars(self):
783783
else "file://" + self.source_dir
784784
)
785785
return {
786-
SCRIPT_PARAM_NAME.upper(): script_name or str(),
787-
DIR_PARAM_NAME.upper(): dir_name or str(),
786+
SCRIPT_PARAM_NAME.upper(): script_name,
787+
DIR_PARAM_NAME.upper(): dir_name,
788788
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
789789
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
790790
}

tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py

+29
Original file line numberDiff line numberDiff line change
@@ -718,3 +718,32 @@ def test_register_hf_pytorch_model_auto_infer_framework(
718718
sagemaker_session.create_model_package_from_containers.assert_called_with(
719719
**expected_create_model_package_request
720720
)
721+
722+
723+
def test_accept_user_defined_environment_variables(
724+
sagemaker_session,
725+
huggingface_training_compiler_version,
726+
huggingface_training_compiler_pytorch_version,
727+
huggingface_training_compiler_pytorch_py_version,
728+
):
729+
program = "inference.py"
730+
directory = "/opt/ml/model/code"
731+
732+
hf_model = HuggingFaceModel(
733+
model_data="s3://some/data.tar.gz",
734+
role=ROLE,
735+
transformers_version=huggingface_training_compiler_version,
736+
pytorch_version=huggingface_training_compiler_pytorch_version,
737+
py_version=huggingface_training_compiler_pytorch_py_version,
738+
sagemaker_session=sagemaker_session,
739+
env={
740+
"SAGEMAKER_PROGRAM": program,
741+
"SAGEMAKER_SUBMIT_DIRECTORY": directory,
742+
},
743+
image_uri="fakeimage",
744+
)
745+
746+
container_env = hf_model.prepare_container_def("ml.m4.xlarge")["Environment"]
747+
748+
assert container_env["SAGEMAKER_PROGRAM"] == program
749+
assert container_env["SAGEMAKER_SUBMIT_DIRECTORY"] == directory

0 commit comments

Comments
 (0)