Skip to content

Commit 8c52f1b

Browse files
authored
feature: Add ModelStep for SageMaker Model Building Pipeline (#3076)
1 parent 92b1d47 commit 8c52f1b

25 files changed

+2832
-402
lines changed

src/sagemaker/model.py

+78-12
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import re
2121
import copy
22+
from typing import List, Dict
2223

2324
import sagemaker
2425
from sagemaker import (
@@ -38,6 +39,7 @@
3839
from sagemaker.async_inference import AsyncInferenceConfig
3940
from sagemaker.predictor_async import AsyncPredictor
4041
from sagemaker.workflow import is_pipeline_variable
42+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4143

4244
LOGGER = logging.getLogger("sagemaker")
4345

@@ -289,6 +291,7 @@ def __init__(
289291
self.uploaded_code = None
290292
self.repacked_model_data = None
291293

294+
@runnable_by_pipeline
292295
def register(
293296
self,
294297
content_types,
@@ -310,12 +313,12 @@ def register(
310313
"""Creates a model package for creating SageMaker models or listing on Marketplace.
311314
312315
Args:
313-
content_types (list): The supported MIME types for the input data (default: None).
314-
response_types (list): The supported MIME types for the output data (default: None).
316+
content_types (list): The supported MIME types for the input data.
317+
response_types (list): The supported MIME types for the output data.
315318
inference_instances (list): A list of the instance types that are used to
316-
generate inferences in real-time (default: None).
319+
generate inferences in real-time.
317320
transform_instances (list): A list of the instance types on which a transformation
318-
job can be run or on which an endpoint can be deployed (default: None).
321+
job can be run or on which an endpoint can be deployed.
319322
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
320323
using `model_package_name` makes the Model Package un-versioned (default: None).
321324
model_package_group_name (str): Model Package Group name, exclusive to
@@ -366,12 +369,50 @@ def register(
366369
model_package = self.sagemaker_session.create_model_package_from_containers(
367370
**model_pkg_args
368371
)
372+
if isinstance(self.sagemaker_session, PipelineSession):
373+
return None
369374
return ModelPackage(
370375
role=self.role,
371376
model_data=self.model_data,
372377
model_package_arn=model_package.get("ModelPackageArn"),
373378
)
374379

380+
@runnable_by_pipeline
def create(
    self,
    instance_type: str = None,
    accelerator_type: str = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
    tags: List[Dict[str, str]] = None,
):
    """Create a SageMaker Model entity.

    Public entry point that forwards every argument, unchanged, to
    ``_create_sagemaker_model``.

    Args:
        instance_type (str): The EC2 instance type that this Model will be
            used for. It is only used to determine whether the image needs
            GPU support or not (default: None).
        accelerator_type (str): Type of Elastic Inference accelerator to
            attach to an endpoint for model loading and inference, for
            example, 'ml.eia1.medium'. If not specified, no Elastic
            Inference accelerator will be attached to the endpoint
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to a serverless endpoint.
            Instance type is not provided in serverless inference, so this
            is used to find image URIs (default: None).
        tags (List[Dict[str, str]]): The list of tags to add to the model
            (default: None). Example::

                tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]

            For more information about tags, see
            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
    """
    # TODO: we should replace _create_sagemaker_model() with create()
    model_settings = dict(
        instance_type=instance_type,
        accelerator_type=accelerator_type,
        serverless_inference_config=serverless_inference_config,
        tags=tags,
    )
    self._create_sagemaker_model(**model_settings)
415+
375416
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
376417
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
377418
@@ -455,6 +496,24 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
455496
if repack and self.model_data is not None and self.entry_point is not None:
456497
if is_pipeline_variable(self.model_data):
457498
# model is not yet there, defer repacking to later during pipeline execution
499+
if not isinstance(self.sagemaker_session, PipelineSession):
500+
# TODO: link the doc in the warning once ready
501+
logging.warning(
502+
"The model_data is a Pipeline variable of type %s, "
503+
"which should be used under `PipelineSession` and "
504+
"leverage `ModelStep` to create or register model. "
505+
"Otherwise some functionalities e.g. "
506+
"runtime repack may be missing",
507+
type(self.model_data),
508+
)
509+
return
510+
self.sagemaker_session.context.need_runtime_repack.add(id(self))
511+
# Add the uploaded_code and repacked_model_data to update the container env
512+
self.repacked_model_data = self.model_data
513+
self.uploaded_code = fw_utils.UploadedCode(
514+
s3_prefix=self.repacked_model_data,
515+
script_name=os.path.basename(self.entry_point),
516+
)
458517
return
459518
if local_code and self.model_data.startswith("file://"):
460519
repacked_model_data = self.model_data
@@ -538,22 +597,29 @@ def _create_sagemaker_model(
538597
serverless_inference_config=serverless_inference_config,
539598
)
540599

541-
self._ensure_base_name_if_needed(
542-
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data
543-
)
544-
self._set_model_name_if_needed()
600+
if not isinstance(self.sagemaker_session, PipelineSession):
601+
# _base_name, model_name are not needed under PipelineSession.
602+
# the model_data may be Pipeline variable
603+
# which may break the _base_name generation
604+
self._ensure_base_name_if_needed(
605+
image_uri=container_def["Image"],
606+
script_uri=self.source_dir,
607+
model_uri=self.model_data,
608+
)
609+
self._set_model_name_if_needed()
545610

546611
enable_network_isolation = self.enable_network_isolation()
547612

548613
self._init_sagemaker_session_if_does_not_exist(instance_type)
549-
self.sagemaker_session.create_model(
550-
self.name,
551-
self.role,
552-
container_def,
614+
create_model_args = dict(
615+
name=self.name,
616+
role=self.role,
617+
container_defs=container_def,
553618
vpc_config=self.vpc_config,
554619
enable_network_isolation=enable_network_isolation,
555620
tags=tags,
556621
)
622+
self.sagemaker_session.create_model(**create_model_args)
557623

558624
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
559625
"""Create a base name from the image URI if there is no model name provided.

src/sagemaker/pipeline.py

+100-4
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Dict
17+
1618
import sagemaker
19+
from sagemaker import ModelMetrics
20+
from sagemaker.drift_check_baselines import DriftCheckBaselines
21+
from sagemaker.metadata_properties import MetadataProperties
1722
from sagemaker.session import Session
1823
from sagemaker.utils import name_from_image
1924
from sagemaker.transformer import Transformer
25+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
2026

2127

2228
class PipelineModel(object):
@@ -221,6 +227,17 @@ def deploy(
221227
return predictor
222228
return None
223229

230+
@runnable_by_pipeline
def create(self, instance_type: str):
    """Create a SageMaker Model entity for this pipeline model.

    Delegates directly to ``_create_sagemaker_pipeline_model``.

    Args:
        instance_type (str): The EC2 instance type that this Model will be
            used for. It is only used to determine whether the image needs
            GPU support or not.
    """
    self._create_sagemaker_pipeline_model(instance_type=instance_type)
240+
224241
def _create_sagemaker_pipeline_model(self, instance_type):
225242
"""Create a SageMaker Model Entity
226243
@@ -235,13 +252,92 @@ def _create_sagemaker_pipeline_model(self, instance_type):
235252
containers = self.pipeline_container_def(instance_type)
236253

237254
self.name = self.name or name_from_image(containers[0]["Image"])
238-
self.sagemaker_session.create_model(
239-
self.name,
240-
self.role,
241-
containers,
255+
create_model_args = dict(
256+
name=self.name,
257+
role=self.role,
258+
container_defs=containers,
242259
vpc_config=self.vpc_config,
243260
enable_network_isolation=self.enable_network_isolation,
244261
)
262+
self.sagemaker_session.create_model(**create_model_args)
263+
264+
@runnable_by_pipeline
def register(
    self,
    content_types: list,
    response_types: list,
    inference_instances: list,
    transform_instances: list,
    model_package_name: Optional[str] = None,
    model_package_group_name: Optional[str] = None,
    image_uri: Optional[str] = None,
    model_metrics: Optional[ModelMetrics] = None,
    metadata_properties: Optional[MetadataProperties] = None,
    marketplace_cert: bool = False,
    approval_status: Optional[str] = None,
    description: Optional[str] = None,
    drift_check_baselines: Optional[DriftCheckBaselines] = None,
    customer_metadata_properties: Optional[Dict[str, str]] = None,
):
    """Creates a model package for creating SageMaker models or listing on Marketplace.

    Args:
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time. When ``model_package_group_name`` is
            given, the first entry is passed to ``pipeline_container_def`` to build
            the container definitions.
        transform_instances (list): A list of the instance types on which a transformation
            job can be run or on which an endpoint can be deployed.
        model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
            using `model_package_name` makes the Model Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model Package
            versioned (default: None).
        image_uri (str): Inference image uri for the containers. Only consulted when
            ``model_package_group_name`` is None; in that case each model's own
            ``image_uri`` is used if this is None (default: None).
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        metadata_properties (MetadataProperties): MetadataProperties object (default: None).
        marketplace_cert (bool): A boolean value indicating if the Model Package is certified
            for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
            or "PendingManualApproval" (default: "PendingManualApproval").
        description (str): Model Package description (default: None).
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
        customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
            metadata properties (default: None).

    Returns:
        None. The result of ``create_model_package_from_containers`` is not
        propagated to the caller.
        NOTE(review): the ``runnable_by_pipeline`` decorator may substitute its
        own return value when running under a ``PipelineSession`` - confirm.

    Raises:
        ValueError: If any model in ``self.models`` has no ``model_data``.
    """
    # Every model in the pipeline must carry model data before packaging.
    if any(model.model_data is None for model in self.models):
        raise ValueError("SageMaker Model Package cannot be created without model data.")

    if model_package_group_name is not None:
        # Versioned package: container definitions come from the pipeline model itself.
        container_def = self.pipeline_container_def(inference_instances[0])
    else:
        # Un-versioned package: one minimal definition per model, with an
        # explicit ``image_uri`` overriding each model's own image.
        container_def = [
            {"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
            for model in self.models
        ]

    model_pkg_args = sagemaker.get_model_package_args(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_package_group_name=model_package_group_name,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        marketplace_cert=marketplace_cert,
        approval_status=approval_status,
        description=description,
        container_def_list=container_def,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
    )

    self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
245341

246342
def transformer(
247343
self,

src/sagemaker/session.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -2679,23 +2679,24 @@ def create_model(
26792679
primary_container=primary_container,
26802680
tags=tags,
26812681
)
2682-
LOGGER.info("Creating model with name: %s", name)
2683-
LOGGER.debug("CreateModel request: %s", json.dumps(create_model_request, indent=4))
26842682

2685-
try:
2686-
self.sagemaker_client.create_model(**create_model_request)
2687-
except ClientError as e:
2688-
error_code = e.response["Error"]["Code"]
2689-
message = e.response["Error"]["Message"]
2690-
2691-
if (
2692-
error_code == "ValidationException"
2693-
and "Cannot create already existing model" in message
2694-
):
2695-
LOGGER.warning("Using already existing model: %s", name)
2696-
else:
2697-
raise
2683+
def submit(request):
    """Send the CreateModel request, tolerating a model that already exists.

    Args:
        request (dict): Keyword arguments passed to
            ``sagemaker_client.create_model``.

    Raises:
        ClientError: Re-raised unless the error is a ``ValidationException``
            whose message reports that the model already exists.
    """
    LOGGER.info("Creating model with name: %s", name)
    LOGGER.debug("CreateModel request: %s", json.dumps(request, indent=4))
    try:
        self.sagemaker_client.create_model(**request)
    except ClientError as e:
        error_details = e.response["Error"]
        error_code = error_details["Code"]
        message = error_details["Message"]
        model_already_exists = (
            error_code == "ValidationException"
            and "Cannot create already existing model" in message
        )
        if not model_already_exists:
            raise
        # An existing model with this name is reused rather than treated as a failure.
        LOGGER.warning("Using already existing model: %s", name)
26982698

2699+
self._intercept_create_request(create_model_request, submit, self.create_model.__name__)
26992700
return name
27002701

27012702
def create_model_from_job(
@@ -2829,10 +2830,9 @@ def create_model_package_from_containers(
28292830
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
28302831
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
28312832
metadata properties (default: None).
2832-
28332833
"""
28342834

2835-
request = get_create_model_package_request(
2835+
model_pkg_request = get_create_model_package_request(
28362836
model_package_name,
28372837
model_package_group_name,
28382838
containers,
@@ -2849,16 +2849,22 @@ def create_model_package_from_containers(
28492849
customer_metadata_properties=customer_metadata_properties,
28502850
validation_specification=validation_specification,
28512851
)
2852-
if model_package_group_name is not None:
2853-
try:
2854-
self.sagemaker_client.describe_model_package_group(
2855-
ModelPackageGroupName=request["ModelPackageGroupName"]
2856-
)
2857-
except ClientError:
2858-
self.sagemaker_client.create_model_package_group(
2859-
ModelPackageGroupName=request["ModelPackageGroupName"]
2860-
)
2861-
return self.sagemaker_client.create_model_package(**request)
2852+
2853+
def submit(request):
    """Create the model package, creating its package group first if needed.

    Args:
        request (dict): Keyword arguments passed to
            ``sagemaker_client.create_model_package``.

    Returns:
        The value returned by ``sagemaker_client.create_model_package``.
    """
    if model_package_group_name is not None:
        group_name = request["ModelPackageGroupName"]
        try:
            # Probe for the group; any ClientError is treated as "group missing".
            self.sagemaker_client.describe_model_package_group(
                ModelPackageGroupName=group_name
            )
        except ClientError:
            self.sagemaker_client.create_model_package_group(
                ModelPackageGroupName=group_name
            )
    return self.sagemaker_client.create_model_package(**request)
2864+
2865+
return self._intercept_create_request(
2866+
model_pkg_request, submit, self.create_model_package_from_containers.__name__
2867+
)
28622868

28632869
def wait_for_model_package(self, model_package_name, poll=5):
28642870
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -4177,7 +4183,9 @@ def account_id(self) -> str:
41774183
)
41784184
return sts_client.get_caller_identity()["Account"]
41794185

4180-
def _intercept_create_request(self, request: typing.Dict, create):
4186+
def _intercept_create_request(
4187+
self, request: typing.Dict, create, func_name: str = None # pylint: disable=unused-argument
4188+
):
41814189
"""This function intercepts the create job request.
41824190
41834191
PipelineSession inherits this Session class and will override
@@ -4186,8 +4194,9 @@ def _intercept_create_request(self, request: typing.Dict, create):
41864194
Args:
41874195
request (dict): the create job request
41884196
create (functor): a functor that calls the sagemaker client create method
4197+
func_name (str): the name of the function whose create request is being intercepted
41894198
"""
4190-
create(request)
4199+
return create(request)
41914200

41924201

41934202
def get_model_package_args(

0 commit comments

Comments
 (0)