From 7a4e4ce41e5d5f18557d1b2c91e9be41d7a6aab3 Mon Sep 17 00:00:00 2001 From: Saumitra Vikram Date: Tue, 10 Jan 2023 16:23:19 +0530 Subject: [PATCH 1/2] feature: support specifying env-vars when creating model from model package --- src/sagemaker/model.py | 3 +++ .../sagemaker/model/test_model_package.py | 27 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a8d6347d8c..f10cee8474 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1620,6 +1620,9 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar container_def = {"ModelPackageName": model_package_name} + if self.env != {}: + container_def["Environment"] = self.env + self._ensure_base_name_if_needed(model_package_name.split("/")[-1]) self._set_model_name_if_needed() diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index e65932ca99..32fd9fb413 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -52,6 +52,10 @@ "CertifyForMarketplace": False, } +ENV_KEY_1 = "env_key_1" +ENV_VALUE_1 = "env_key_1" +ENVIRONMENT = {ENV_KEY_1: ENV_VALUE_1} + @pytest.fixture def sagemaker_session(): @@ -115,6 +119,29 @@ def test_create_sagemaker_model_uses_model_name(name_from_base, sagemaker_sessio ) +def test_create_sagemaker_model_include_environment_variable(sagemaker_session): + model_name = "my-model" + model_package_name = "my-model-package" + + model_package = ModelPackage( + role="role", + name=model_name, + model_package_arn=model_package_name, + env=ENVIRONMENT, + sagemaker_session=sagemaker_session, + ) + + model_package._create_sagemaker_model() + + sagemaker_session.create_model.assert_called_with( + model_name, + "role", + {"ModelPackageName": model_package_name, "Environment": ENVIRONMENT}, + vpc_config=None, + enable_network_isolation=False, + ) + + @patch("sagemaker.utils.name_from_base") def test_create_sagemaker_model_generates_model_name(name_from_base, sagemaker_session): model_package_name = "my-model-package" From b04601d30b65415199f33fb765058d970dcce4f3 Mon Sep 17 00:00:00 2001 From: Saumitra Vikram Date: Tue, 10 Jan 2023 17:09:31 +0530 Subject: [PATCH 2/2] Updating unit test --- tests/unit/sagemaker/model/test_model_package.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 32fd9fb413..161940874f 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -52,10 +52,6 @@ "CertifyForMarketplace": False, } -ENV_KEY_1 = "env_key_1" -ENV_VALUE_1 = "env_key_1" -ENVIRONMENT = {ENV_KEY_1: ENV_VALUE_1} - @pytest.fixture def sagemaker_session(): @@ -122,12 +118,15 @@ def test_create_sagemaker_model_uses_model_name(name_from_base, sagemaker_sessio def test_create_sagemaker_model_include_environment_variable(sagemaker_session): model_name = "my-model" model_package_name = "my-model-package" + env_key = "env_key" + env_value = "env_value" + environment = {env_key: env_value} model_package = ModelPackage( role="role", name=model_name, model_package_arn=model_package_name, - env=ENVIRONMENT, + env=environment, sagemaker_session=sagemaker_session, ) @@ -136,7 +135,7 @@ def test_create_sagemaker_model_include_environment_variable(sagemaker_session): sagemaker_session.create_model.assert_called_with( model_name, "role", - {"ModelPackageName": model_package_name, "Environment": ENVIRONMENT}, + {"ModelPackageName": model_package_name, "Environment": environment}, vpc_config=None, enable_network_isolation=False, )