Skip to content

Commit 15691f8

Browse files
martinRenousagemaker-bot
authored andcommitted
Change: Accept model_data as dictionary in the model deploy (aws#4276)
* Change: Accept model_data as dictionary in the model deploy * Add unit test
1 parent 0f616f4 commit 15691f8

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -866,16 +866,10 @@ def _create_sagemaker_model(
866866
# _base_name, model_name are not needed under PipelineSession.
867867
# the model_data may be Pipeline variable
868868
# which may break the _base_name generation
869-
model_uri = None
870-
if isinstance(self.model_data, (str, PipelineVariable)):
871-
model_uri = self.model_data
872-
elif isinstance(self.model_data, dict):
873-
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
874-
875869
self._ensure_base_name_if_needed(
876870
image_uri=container_def["Image"],
877871
script_uri=self.source_dir,
878-
model_uri=model_uri,
872+
model_uri=self._get_model_uri(),
879873
)
880874
self._set_model_name_if_needed()
881875

@@ -912,6 +906,14 @@ def _create_sagemaker_model(
912906
)
913907
self.sagemaker_session.create_model(**create_model_args)
914908

909+
def _get_model_uri(self):
910+
model_uri = None
911+
if isinstance(self.model_data, (str, PipelineVariable)):
912+
model_uri = self.model_data
913+
elif isinstance(self.model_data, dict):
914+
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
915+
return model_uri
916+
915917
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
916918
"""Create a base name from the image URI if there is no model name provided.
917919
@@ -1496,7 +1498,7 @@ def deploy(
14961498
self._ensure_base_name_if_needed(
14971499
image_uri=self.image_uri,
14981500
script_uri=self.source_dir,
1499-
model_uri=self.model_data,
1501+
model_uri=self._get_model_uri(),
15001502
)
15011503
if self._base_name is not None:
15021504
self._base_name = "-".join((self._base_name, compiled_model_suffix))

tests/unit/sagemaker/model/test_model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,3 +1307,33 @@ def test_package_for_edge_with_sagemaker_config_injection(sagemaker_session):
13071307
role=SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["RoleArn"],
13081308
tags=None,
13091309
)
1310+
1311+
1312+
def test_model_source(
1313+
sagemaker_session,
1314+
):
1315+
model = Model(
1316+
entry_point=ENTRY_POINT_INFERENCE,
1317+
role=ROLE,
1318+
sagemaker_session=sagemaker_session,
1319+
image_uri=IMAGE_URI,
1320+
model_data={
1321+
"S3DataSource": {
1322+
"S3Uri": "s3://tmybuckaet",
1323+
"S3DataType": "S3Prefix",
1324+
"CompressionType": "None",
1325+
}
1326+
},
1327+
)
1328+
1329+
assert model._get_model_uri() == "s3://tmybuckaet"
1330+
1331+
model_1 = Model(
1332+
entry_point=ENTRY_POINT_INFERENCE,
1333+
role=ROLE,
1334+
sagemaker_session=sagemaker_session,
1335+
image_uri=IMAGE_URI,
1336+
model_data="s3://tmybuckaet",
1337+
)
1338+
1339+
assert model_1._get_model_uri() == "s3://tmybuckaet"

0 commit comments

Comments
 (0)