@@ -866,16 +866,10 @@ def _create_sagemaker_model(
866
866
# _base_name, model_name are not needed under PipelineSession.
867
867
# the model_data may be Pipeline variable
868
868
# which may break the _base_name generation
869
- model_uri = None
870
- if isinstance (self .model_data , (str , PipelineVariable )):
871
- model_uri = self .model_data
872
- elif isinstance (self .model_data , dict ):
873
- model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
874
-
875
869
self ._ensure_base_name_if_needed (
876
870
image_uri = container_def ["Image" ],
877
871
script_uri = self .source_dir ,
878
- model_uri = model_uri ,
872
+ model_uri = self . _get_model_uri () ,
879
873
)
880
874
self ._set_model_name_if_needed ()
881
875
@@ -912,6 +906,14 @@ def _create_sagemaker_model(
912
906
)
913
907
self .sagemaker_session .create_model (** create_model_args )
914
908
909
+ def _get_model_uri (self ):
910
+ model_uri = None
911
+ if isinstance (self .model_data , (str , PipelineVariable )):
912
+ model_uri = self .model_data
913
+ elif isinstance (self .model_data , dict ):
914
+ model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
915
+ return model_uri
916
+
915
917
def _ensure_base_name_if_needed (self , image_uri , script_uri , model_uri ):
916
918
"""Create a base name from the image URI if there is no model name provided.
917
919
@@ -1496,7 +1498,7 @@ def deploy(
1496
1498
self ._ensure_base_name_if_needed (
1497
1499
image_uri = self .image_uri ,
1498
1500
script_uri = self .source_dir ,
1499
- model_uri = self .model_data ,
1501
+ model_uri = self ._get_model_uri () ,
1500
1502
)
1501
1503
if self ._base_name is not None :
1502
1504
self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
0 commit comments