Skip to content

Commit f6d9538

Browse files
committed
Change: Accept model_data as dictionary in the model deploy
1 parent 3ed0011 commit f6d9538

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -866,16 +866,10 @@ def _create_sagemaker_model(
866866
# _base_name, model_name are not needed under PipelineSession.
867867
# the model_data may be Pipeline variable
868868
# which may break the _base_name generation
869-
model_uri = None
870-
if isinstance(self.model_data, (str, PipelineVariable)):
871-
model_uri = self.model_data
872-
elif isinstance(self.model_data, dict):
873-
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
874-
875869
self._ensure_base_name_if_needed(
876870
image_uri=container_def["Image"],
877871
script_uri=self.source_dir,
878-
model_uri=model_uri,
872+
model_uri=self._get_model_uri(),
879873
)
880874
self._set_model_name_if_needed()
881875

@@ -912,6 +906,14 @@ def _create_sagemaker_model(
912906
)
913907
self.sagemaker_session.create_model(**create_model_args)
914908

909+
def _get_model_uri(self):
910+
model_uri = None
911+
if isinstance(self.model_data, (str, PipelineVariable)):
912+
model_uri = self.model_data
913+
elif isinstance(self.model_data, dict):
914+
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
915+
return model_uri
916+
915917
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
916918
"""Create a base name from the image URI if there is no model name provided.
917919
@@ -1496,7 +1498,7 @@ def deploy(
14961498
self._ensure_base_name_if_needed(
14971499
image_uri=self.image_uri,
14981500
script_uri=self.source_dir,
1499-
model_uri=self.model_data,
1501+
model_uri=self._get_model_uri(),
15001502
)
15011503
if self._base_name is not None:
15021504
self._base_name = "-".join((self._base_name, compiled_model_suffix))

0 commit comments

Comments
 (0)