From 90c9fbeb031ee5137cd50cf0b5572cb38f303175 Mon Sep 17 00:00:00 2001 From: Samrudhi Sharma Date: Fri, 17 May 2024 13:03:04 -0700 Subject: [PATCH 1/3] Fix: Add Image URI overrides for transformers models --- .../serve/builder/transformers_builder.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index ead9b7425f..614290b132 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -78,6 +78,25 @@ def _prepare_for_mode(self): """Abstract method""" def _create_transformers_model(self) -> Type[Model]: + """Initializes HF model with or without image_uri""" + if self.image_uri is None: + pysdk_model = self._get_hf_metadata_create_model() + else: + pysdk_model = HuggingFaceModel( + image_uri=self.image_uri, + vpc_config=self.vpc_config, + env=self.env_vars, + role=self.role_arn, + sagemaker_session=self.sagemaker_session, + ) + + logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri) + + self._original_deploy = pysdk_model.deploy + pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper + return pysdk_model + + def _get_hf_metadata_create_model(self) -> Type[Model]: """Initializes the model after fetching image 1. Get the metadata for deciding framework @@ -132,22 +151,21 @@ def _create_transformers_model(self) -> Type[Model]: vpc_config=self.vpc_config, ) - if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER: + if self.mode == Mode.LOCAL_CONTAINER: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, "local" ) - elif not self.image_uri: + else: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, self.instance_type ) - logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri) + if pysdk_model is None or self.image_uri is None: + raise ValueError("PySDK model unable to be created, try overriding image_uri") if not pysdk_model.image_uri: pysdk_model.image_uri = self.image_uri - self._original_deploy = pysdk_model.deploy - pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper return pysdk_model @_capture_telemetry("transformers.deploy") From 8d8fb1f8e86c44f182c5fe76981bcceeb5cfbe8d Mon Sep 17 00:00:00 2001 From: Samrudhi Sharma Date: Fri, 17 May 2024 17:45:25 -0700 Subject: [PATCH 2/3] Increase coverage --- .../serve/test_serve_transformers.py | 2 +- .../builder/test_transformers_builder.py | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 64029f7290..33a1ae6708 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -127,4 +127,4 @@ def test_pytorch_transformers_sagemaker_endpoint( logger.exception(caught_ex) assert ( False - ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test" + ), f"{caught_ex} thrown when running pytorch transformers sagemaker endpoint test" diff --git a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py index d63eabf2a3..86f5836093 100644 --- a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py @@ -144,3 +144,28 @@ def test_image_uri_override( with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + @patch( + "sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata", + return_value=None, + ) + def test_failure_hf_md(self, mock_model_md, mock_get_nb_instance, mock_telemetry, + mock_build_for_transformers): + builder = ModelBuilder( + model=mock_model_id, + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + builder.build() + + mock_build_for_transformers.assert_called_once() From 29cc2146c44bbc4a9f9362ce7f8f1f49f765878a Mon Sep 17 00:00:00 2001 From: Samrudhi Sharma Date: Sat, 18 May 2024 00:50:57 +0000 Subject: [PATCH 3/3] Fix formatting --- .../sagemaker/serve/builder/test_transformers_builder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py index 86f5836093..9ea797adc2 100644 --- a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py @@ -155,8 +155,9 @@ def test_image_uri_override( "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata", return_value=None, ) - def test_failure_hf_md(self, mock_model_md, mock_get_nb_instance, mock_telemetry, - mock_build_for_transformers): + def test_failure_hf_md( + self, mock_model_md, mock_get_nb_instance, mock_telemetry, mock_build_for_transformers + ): builder = ModelBuilder( model=mock_model_id, schema_builder=mock_schema_builder,