diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 7b741a1269..d9122cacf1 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -866,16 +866,10 @@ def _create_sagemaker_model( # _base_name, model_name are not needed under PipelineSession. # the model_data may be Pipeline variable # which may break the _base_name generation - model_uri = None - if isinstance(self.model_data, (str, PipelineVariable)): - model_uri = self.model_data - elif isinstance(self.model_data, dict): - model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None) - self._ensure_base_name_if_needed( image_uri=container_def["Image"], script_uri=self.source_dir, - model_uri=model_uri, + model_uri=self._get_model_uri(), ) self._set_model_name_if_needed() @@ -912,6 +906,14 @@ def _create_sagemaker_model( ) self.sagemaker_session.create_model(**create_model_args) + def _get_model_uri(self): + model_uri = None + if isinstance(self.model_data, (str, PipelineVariable)): + model_uri = self.model_data + elif isinstance(self.model_data, dict): + model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None) + return model_uri + def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri): """Create a base name from the image URI if there is no model name provided. @@ -1496,7 +1498,7 @@ def deploy( self._ensure_base_name_if_needed( image_uri=self.image_uri, script_uri=self.source_dir, - model_uri=self.model_data, + model_uri=self._get_model_uri(), ) if self._base_name is not None: self._base_name = "-".join((self._base_name, compiled_model_suffix)) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 4d4248e0d6..bfd5af977d 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -1307,3 +1307,33 @@ def test_package_for_edge_with_sagemaker_config_injection(sagemaker_session): role=SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["RoleArn"], tags=None, ) + + +def test_model_source( + sagemaker_session, +): + model = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data={ + "S3DataSource": { + "S3Uri": "s3://tmybuckaet", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + ) + + assert model._get_model_uri() == "s3://tmybuckaet" + + model_1 = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data="s3://tmybuckaet", + ) + + assert model_1._get_model_uri() == "s3://tmybuckaet"