@@ -831,16 +831,10 @@ def _create_sagemaker_model(
831
831
# _base_name, model_name are not needed under PipelineSession.
832
832
# the model_data may be Pipeline variable
833
833
# which may break the _base_name generation
834
- model_uri = None
835
- if isinstance (self .model_data , (str , PipelineVariable )):
836
- model_uri = self .model_data
837
- elif isinstance (self .model_data , dict ):
838
- model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
839
-
840
834
self ._ensure_base_name_if_needed (
841
835
image_uri = container_def ["Image" ],
842
836
script_uri = self .source_dir ,
843
- model_uri = model_uri ,
837
+ model_uri = self . _get_model_uri () ,
844
838
)
845
839
self ._set_model_name_if_needed ()
846
840
@@ -877,6 +871,14 @@ def _create_sagemaker_model(
877
871
)
878
872
self .sagemaker_session .create_model (** create_model_args )
879
873
874
+ def _get_model_uri (self ):
875
+ model_uri = None
876
+ if isinstance (self .model_data , (str , PipelineVariable )):
877
+ model_uri = self .model_data
878
+ elif isinstance (self .model_data , dict ):
879
+ model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
880
+ return model_uri
881
+
880
882
def _ensure_base_name_if_needed (self , image_uri , script_uri , model_uri ):
881
883
"""Create a base name from the image URI if there is no model name provided.
882
884
@@ -1434,7 +1436,7 @@ def deploy(
1434
1436
self ._ensure_base_name_if_needed (
1435
1437
image_uri = self .image_uri ,
1436
1438
script_uri = self .source_dir ,
1437
- model_uri = self .model_data ,
1439
+ model_uri = self ._get_model_uri () ,
1438
1440
)
1439
1441
if self ._base_name is not None :
1440
1442
self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
0 commit comments