Skip to content

Commit 8359e93

Browse files
committed
feature: Accept user-defined env variables for the entry-point
1 parent a9ac311 commit 8359e93

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

src/sagemaker/model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -734,8 +734,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
734734

735735
def _script_mode_env_vars(self):
736736
"""Returns a mapping of environment variables for script mode execution"""
737-
script_name = None
738-
dir_name = None
737+
script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "")
738+
dir_name = self.env.get(DIR_PARAM_NAME.upper(), "")
739739
if self.uploaded_code:
740740
script_name = self.uploaded_code.script_name
741741
if self.repacked_model_data or self.enable_network_isolation():
@@ -751,8 +751,8 @@ def _script_mode_env_vars(self):
751751
else "file://" + self.source_dir
752752
)
753753
return {
754-
SCRIPT_PARAM_NAME.upper(): script_name or str(),
755-
DIR_PARAM_NAME.upper(): dir_name or str(),
754+
SCRIPT_PARAM_NAME.upper(): script_name,
755+
DIR_PARAM_NAME.upper(): dir_name,
756756
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
757757
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
758758
}

tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py

+28
Original file line numberDiff line numberDiff line change
@@ -718,3 +718,31 @@ def test_register_hf_pytorch_model_auto_infer_framework(
718718
sagemaker_session.create_model_package_from_containers.assert_called_with(
719719
**expected_create_model_package_request
720720
)
721+
722+
723+
def test_accept_user_defined_environment_variables(
724+
sagemaker_session,
725+
huggingface_training_compiler_version,
726+
huggingface_training_compiler_pytorch_version,
727+
huggingface_training_compiler_pytorch_py_version,
728+
):
729+
program = "inference.py"
730+
directory = "/opt/ml/model/code"
731+
732+
hf_model = HuggingFaceModel(
733+
model_data="s3://some/data.tar.gz",
734+
role=ROLE,
735+
transformers_version=huggingface_training_compiler_version,
736+
pytorch_version=huggingface_training_compiler_pytorch_version,
737+
py_version=huggingface_training_compiler_pytorch_py_version,
738+
sagemaker_session=sagemaker_session,
739+
env={
740+
"SAGEMAKER_PROGRAM": program,
741+
"SAGEMAKER_SUBMIT_DIRECTORY": directory,
742+
},
743+
)
744+
745+
container_env = hf_model.prepare_container_def("ml.m4.xlarge")["Environment"]
746+
747+
assert container_env["SAGEMAKER_PROGRAM"] == program
748+
assert container_env["SAGEMAKER_SUBMIT_DIRECTORY"] == directory

0 commit comments

Comments
 (0)