diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a8d6347d8c..f10cee8474 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1620,6 +1620,9 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar container_def = {"ModelPackageName": model_package_name} + if self.env != {}: + container_def["Environment"] = self.env + self._ensure_base_name_if_needed(model_package_name.split("/")[-1]) self._set_model_name_if_needed() diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index e65932ca99..161940874f 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -115,6 +115,32 @@ def test_create_sagemaker_model_uses_model_name(name_from_base, sagemaker_sessio ) +def test_create_sagemaker_model_include_environment_variable(sagemaker_session): + model_name = "my-model" + model_package_name = "my-model-package" + env_key = "env_key" + env_value = "env_value" + environment = {env_key: env_value} + + model_package = ModelPackage( + role="role", + name=model_name, + model_package_arn=model_package_name, + env=environment, + sagemaker_session=sagemaker_session, + ) + + model_package._create_sagemaker_model() + + sagemaker_session.create_model.assert_called_with( + model_name, + "role", + {"ModelPackageName": model_package_name, "Environment": environment}, + vpc_config=None, + enable_network_isolation=False, + ) + + @patch("sagemaker.utils.name_from_base") def test_create_sagemaker_model_generates_model_name(name_from_base, sagemaker_session): model_package_name = "my-model-package"