
Commit 8434b4f

feature: add serverless inference image_uri support
1 parent 9002d6f commit 8434b4f

19 files changed, +380 -58 lines changed


src/sagemaker/chainer/model.py

+12-4
@@ -143,7 +143,9 @@ def __init__(
 
         self.model_server_workers = model_server_workers
 
-    def prepare_container_def(self, instance_type=None, accelerator_type=None):
+    def prepare_container_def(
+        self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+    ):
         """Return a container definition with framework configuration set in model environment.
 
         Args:
@@ -159,14 +161,17 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
         """
         deploy_image = self.image_uri
         if not deploy_image:
-            if instance_type is None:
+            if instance_type is None and serverless_inference_config is None:
                 raise ValueError(
                     "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                 )
 
             region_name = self.sagemaker_session.boto_session.region_name
             deploy_image = self.serving_image_uri(
-                region_name, instance_type, accelerator_type=accelerator_type
+                region_name,
+                instance_type,
+                accelerator_type=accelerator_type,
+                serverless_inference_config=serverless_inference_config,
             )
 
         deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -178,7 +183,9 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
             deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
         return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
 
-    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
+    def serving_image_uri(
+        self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None
+    ):
         """Create a URI for the serving image.
 
         Args:
@@ -198,4 +205,5 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
             instance_type=instance_type,
             accelerator_type=accelerator_type,
             image_scope="inference",
+            serverless_inference_config=serverless_inference_config,
         )
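The practical effect for framework models such as Chainer is that a container definition can now be prepared with a serverless configuration instead of an instance type. A minimal sketch, not taken from the commit; the bucket, role, entry point, and version pins are illustrative placeholders, and the call requires AWS credentials because it resolves a regional image and uploads the entry point to the session's default bucket.

import sagemaker
from sagemaker.chainer import ChainerModel
from sagemaker.serverless import ServerlessInferenceConfig

# Placeholders: substitute a real execution role, S3 artifact, and inference script.
chainer_model = ChainerModel(
    model_data="s3://my-bucket/chainer/model.tar.gz",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    entry_point="inference.py",
    framework_version="5.0.0",
    py_version="py3",
    sagemaker_session=sagemaker.Session(),  # needs AWS credentials and a region
)

# With this commit, instance_type can be omitted when a serverless config is given;
# serving_image_uri then resolves a CPU inference image for the session's region.
container_def = chainer_model.prepare_container_def(
    serverless_inference_config=ServerlessInferenceConfig(
        memory_size_in_mb=2048, max_concurrency=5
    )
)
print(container_def["Image"])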

src/sagemaker/huggingface/model.py

+17-5
@@ -272,7 +272,7 @@ def deploy(
             is not None. Otherwise, return None.
         """
 
-        if not self.image_uri and instance_type.startswith("ml.inf"):
+        if not self.image_uri and instance_type is not None and instance_type.startswith("ml.inf"):
             self.image_uri = self.serving_image_uri(
                 region_name=self.sagemaker_session.boto_session.region_name,
                 instance_type=instance_type,
@@ -365,7 +365,9 @@ def register(
             drift_check_baselines=drift_check_baselines,
         )
 
-    def prepare_container_def(self, instance_type=None, accelerator_type=None):
+    def prepare_container_def(
+        self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+    ):
         """A container definition with framework configuration set in model environment variables.
 
         Args:
@@ -381,14 +383,17 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
         """
         deploy_image = self.image_uri
         if not deploy_image:
-            if instance_type is None:
+            if instance_type is None and serverless_inference_config is None:
                 raise ValueError(
                     "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                 )
 
             region_name = self.sagemaker_session.boto_session.region_name
             deploy_image = self.serving_image_uri(
-                region_name, instance_type, accelerator_type=accelerator_type
+                region_name,
+                instance_type,
+                accelerator_type=accelerator_type,
+                serverless_inference_config=serverless_inference_config,
            )
 
        deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -402,7 +407,13 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
             deploy_image, self.repacked_model_data or self.model_data, deploy_env
         )
 
-    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
+    def serving_image_uri(
+        self,
+        region_name,
+        instance_type=None,
+        accelerator_type=None,
+        serverless_inference_config=None,
+    ):
         """Create a URI for the serving image.
 
         Args:
@@ -432,4 +443,5 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
             accelerator_type=accelerator_type,
             image_scope="inference",
             base_framework_version=base_framework_version,
+            serverless_inference_config=serverless_inference_config,
         )
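For Hugging Face models, the guard on instance_type.startswith("ml.inf") now tolerates instance_type=None, so a serverless endpoint can be deployed without choosing an instance. A hedged sketch of the caller's side; the artifact, role, and version combination below are illustrative and should be replaced with values supported by your SDK release.

from sagemaker.huggingface import HuggingFaceModel
from sagemaker.serverless import ServerlessInferenceConfig

# Illustrative values; use a real role, model artifact, and supported versions.
hf_model = HuggingFaceModel(
    model_data="s3://my-bucket/huggingface/model.tar.gz",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    transformers_version="4.12",
    pytorch_version="1.9",
    py_version="py38",
)

# No instance_type: the image URI is resolved for CPU because the serverless
# config is passed down through prepare_container_def/serving_image_uri.
predictor = hf_model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(
        memory_size_in_mb=4096, max_concurrency=10
    )
)
print(predictor.endpoint_name)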

src/sagemaker/image_uris.py

+12-3
@@ -48,6 +48,7 @@ def retrieve(
     tolerate_deprecated_model=False,
     sdk_version=None,
     inference_tool=None,
+    serverless_inference_config=None,
 ) -> str:
     """Retrieves the ECR URI for the Docker image matching the given arguments.
 
@@ -159,7 +160,9 @@ def retrieve(
     repo = version_config["repository"]
 
     processor = _processor(
-        instance_type, config.get("processors") or version_config.get("processors")
+        instance_type,
+        config.get("processors") or version_config.get("processors"),
+        serverless_inference_config,
     )
 
     # if container version is available in .json file, utilize that
@@ -202,7 +205,9 @@
 
     tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
 
-    if _should_auto_select_container_version(instance_type, distribution):
+    if instance_type is not None and _should_auto_select_container_version(
+        instance_type, distribution
+    ):
         container_versions = {
             "tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
             "tensorflow-2.3.1-gpu-py37": "cu110-ubuntu18.04",
@@ -327,7 +332,7 @@ def _registry_from_region(region, registry_dict):
     return registry_dict[region]
 
 
-def _processor(instance_type, available_processors):
+def _processor(instance_type, available_processors, serverless_inference_config=None):
     """Returns the processor type for the given instance type."""
     if not available_processors:
         logger.info("Ignoring unnecessary instance type: %s.", instance_type)
@@ -337,6 +342,10 @@ def _processor(instance_type, available_processors):
         logger.info("Defaulting to only supported image scope: %s.", available_processors[0])
        return available_processors[0]
 
+    if serverless_inference_config is not None:
+        logger.info("Defaulting to CPU type when using serverless inference")
+        return "cpu"
+
     if not instance_type:
         raise ValueError(
             "Empty SageMaker instance type. For options, see: "

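At the image_uris.retrieve level, passing the config is enough to skip the instance-type requirement: _processor defaults to "cpu" when a serverless config is present. A sketch of the resulting call; the framework, region, and versions are illustrative, not taken from the commit.

from sagemaker import image_uris
from sagemaker.serverless import ServerlessInferenceConfig

# With serverless_inference_config set, the processor resolves to "cpu" and no
# instance_type is required; without it, this call would raise a ValueError.
uri = image_uris.retrieve(
    framework="pytorch",
    region="us-east-1",
    version="1.9.1",
    py_version="py38",
    image_scope="inference",
    serverless_inference_config=ServerlessInferenceConfig(),
)
print(uri)  # expect a CPU py38 inference image URI for us-east-1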
src/sagemaker/model.py

+15-4
@@ -383,7 +383,10 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
             self.sagemaker_session = session.Session()
 
     def prepare_container_def(
-        self, instance_type=None, accelerator_type=None
+        self,
+        instance_type=None,
+        accelerator_type=None,
+        serverless_inference_config=None,
     ):  # pylint: disable=unused-argument
         """Return a dict created by ``sagemaker.container_def()``.
 
@@ -498,7 +501,9 @@ def enable_network_isolation(self):
         """
         return self._enable_network_isolation
 
-    def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tags=None):
+    def _create_sagemaker_model(
+        self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None
+    ):
         """Create a SageMaker Model Entity
 
         Args:
@@ -515,7 +520,11 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
                 https://boto3.amazonaws.com/v1/documentation
                 /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
         """
-        container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
+        container_def = self.prepare_container_def(
+            instance_type,
+            accelerator_type=accelerator_type,
+            serverless_inference_config=serverless_inference_config,
+        )
 
         self._ensure_base_name_if_needed(
             image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
@@ -983,7 +992,9 @@ def deploy(
             if self._base_name is not None:
                 self._base_name = "-".join((self._base_name, compiled_model_suffix))
 
-        self._create_sagemaker_model(instance_type, accelerator_type, tags)
+        self._create_sagemaker_model(
+            instance_type, accelerator_type, tags, serverless_inference_config
+        )
 
        serverless_inference_config_dict = (
            serverless_inference_config._to_request_dict() if is_serverless else None
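In the base Model class the config now rides along from deploy() through _create_sagemaker_model() into prepare_container_def(), where framework subclasses use it for image resolution; the base class itself still needs an explicit image_uri. A sketch of the plumbing from the caller's side, with placeholder image URI, artifact, and role.

from sagemaker.model import Model
from sagemaker.serverless import ServerlessInferenceConfig

# Base Model still takes an explicit image; the serverless config is threaded
# through deploy() -> _create_sagemaker_model() -> prepare_container_def().
model = Model(
    image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/my-inference:latest",
    model_data="s3://my-bucket/model/model.tar.gz",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
)

predictor = model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(
        memory_size_in_mb=2048, max_concurrency=5
    )
)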

src/sagemaker/mxnet/model.py

+12-4
@@ -220,7 +220,9 @@ def register(
             customer_metadata_properties=customer_metadata_properties,
         )
 
-    def prepare_container_def(self, instance_type=None, accelerator_type=None):
+    def prepare_container_def(
+        self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+    ):
         """Return a container definition with framework configuration.
 
         Framework configuration is set in model environment variables.
@@ -238,14 +240,17 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
         """
         deploy_image = self.image_uri
         if not deploy_image:
-            if instance_type is None:
+            if instance_type is None and serverless_inference_config is None:
                 raise ValueError(
                     "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                 )
 
             region_name = self.sagemaker_session.boto_session.region_name
             deploy_image = self.serving_image_uri(
-                region_name, instance_type, accelerator_type=accelerator_type
+                region_name,
+                instance_type,
+                accelerator_type=accelerator_type,
+                serverless_inference_config=serverless_inference_config,
             )
 
         deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -259,7 +264,9 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
             deploy_image, self.repacked_model_data or self.model_data, deploy_env
         )
 
-    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
+    def serving_image_uri(
+        self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None
+    ):
         """Create a URI for the serving image.
 
         Args:
@@ -282,6 +289,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
             instance_type=instance_type,
             accelerator_type=accelerator_type,
             image_scope="inference",
+            serverless_inference_config=serverless_inference_config,
         )
 
     def _is_mms_version(self):
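The same pattern applies to MXNet; for example, serving_image_uri can now be asked for an image with no instance type as long as a serverless config is supplied. A hedged sketch with illustrative artifact, role, region, and framework version.

from sagemaker.mxnet import MXNetModel
from sagemaker.serverless import ServerlessInferenceConfig

# Placeholders for the model artifact, role, and inference script.
mxnet_model = MXNetModel(
    model_data="s3://my-bucket/mxnet/model.tar.gz",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    entry_point="inference.py",
    framework_version="1.8.0",
    py_version="py37",
)

# Resolves a CPU MXNet inference image even though instance_type is None.
image = mxnet_model.serving_image_uri(
    "us-west-2",
    instance_type=None,
    serverless_inference_config=ServerlessInferenceConfig(),
)
print(image)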

src/sagemaker/pytorch/model.py

+12-4
@@ -220,7 +220,9 @@ def register(
             customer_metadata_properties=customer_metadata_properties,
         )
 
-    def prepare_container_def(self, instance_type=None, accelerator_type=None):
+    def prepare_container_def(
+        self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+    ):
         """A container definition with framework configuration set in model environment variables.
 
         Args:
@@ -236,14 +238,17 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
         """
         deploy_image = self.image_uri
         if not deploy_image:
-            if instance_type is None:
+            if instance_type is None and serverless_inference_config is None:
                 raise ValueError(
                     "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                 )
 
             region_name = self.sagemaker_session.boto_session.region_name
             deploy_image = self.serving_image_uri(
-                region_name, instance_type, accelerator_type=accelerator_type
+                region_name,
+                instance_type,
+                accelerator_type=accelerator_type,
+                serverless_inference_config=serverless_inference_config,
             )
 
         deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -257,7 +262,9 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
             deploy_image, self.repacked_model_data or self.model_data, deploy_env
         )
 
-    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
+    def serving_image_uri(
+        self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None
+    ):
         """Create a URI for the serving image.
 
         Args:
@@ -280,6 +287,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
             instance_type=instance_type,
             accelerator_type=accelerator_type,
             image_scope="inference",
+            serverless_inference_config=serverless_inference_config,
         )
 
     def _is_mms_version(self):
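Putting it together for PyTorch, an end-to-end serverless deployment no longer needs an explicit image_uri or instance_type. A minimal sketch; the artifact, role, and versions are placeholders, and the calls require AWS credentials.

from sagemaker.pytorch import PyTorchModel
from sagemaker.serverless import ServerlessInferenceConfig

pytorch_model = PyTorchModel(
    model_data="s3://my-bucket/pytorch/model.tar.gz",     # placeholder artifact
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    entry_point="inference.py",
    framework_version="1.9.1",
    py_version="py38",
)

# The CPU inference image is resolved automatically from the serverless config.
predictor = pytorch_model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(
        memory_size_in_mb=3072, max_concurrency=5
    )
)
print(predictor.endpoint_name)
predictor.delete_endpoint()  # clean up the serverless endpoint when done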

src/sagemaker/sklearn/model.py

+5-2
@@ -208,7 +208,9 @@ def register(
             description,
         )
 
-    def prepare_container_def(self, instance_type=None, accelerator_type=None):
+    def prepare_container_def(
+        self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+    ):
         """Container definition with framework configuration set in model environment variables.
 
         Args:
@@ -244,7 +246,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
             )
         return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)
 
-    def serving_image_uri(self, region_name, instance_type):
+    def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None):
         """Create a URI for the serving image.
 
         Args:
@@ -261,4 +263,5 @@ def serving_image_uri(self, region_name, instance_type):
             version=self.framework_version,
             py_version=self.py_version,
             instance_type=instance_type,
+            serverless_inference_config=serverless_inference_config,
         )

0 commit comments
