Skip to content

Commit ec70e0d

Browse files
committed
Change: Accept model_data as dictionary in the model deploy
1 parent 8462f1a commit ec70e0d

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -831,16 +831,10 @@ def _create_sagemaker_model(
831831
# _base_name, model_name are not needed under PipelineSession.
832832
# the model_data may be Pipeline variable
833833
# which may break the _base_name generation
834-
model_uri = None
835-
if isinstance(self.model_data, (str, PipelineVariable)):
836-
model_uri = self.model_data
837-
elif isinstance(self.model_data, dict):
838-
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
839-
840834
self._ensure_base_name_if_needed(
841835
image_uri=container_def["Image"],
842836
script_uri=self.source_dir,
843-
model_uri=model_uri,
837+
model_uri=self._get_model_uri(),
844838
)
845839
self._set_model_name_if_needed()
846840

@@ -877,6 +871,14 @@ def _create_sagemaker_model(
877871
)
878872
self.sagemaker_session.create_model(**create_model_args)
879873

874+
def _get_model_uri(self):
875+
model_uri = None
876+
if isinstance(self.model_data, (str, PipelineVariable)):
877+
model_uri = self.model_data
878+
elif isinstance(self.model_data, dict):
879+
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
880+
return model_uri
881+
880882
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
881883
"""Create a base name from the image URI if there is no model name provided.
882884
@@ -1434,7 +1436,7 @@ def deploy(
14341436
self._ensure_base_name_if_needed(
14351437
image_uri=self.image_uri,
14361438
script_uri=self.source_dir,
1437-
model_uri=self.model_data,
1439+
model_uri=self._get_model_uri(),
14381440
)
14391441
if self._base_name is not None:
14401442
self._base_name = "-".join((self._base_name, compiled_model_suffix))

0 commit comments

Comments
 (0)