Skip to content

Commit bc5d530

Browse files
author
Keshav Chandak
committed
feat: Added register step in Jumpstart
1 parent d3f2841 commit bc5d530

File tree

6 files changed

+478
-77
lines changed

6 files changed

+478
-77
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,19 @@
3434
JUMPSTART_DEFAULT_REGION_NAME,
3535
JUMPSTART_LOGGER,
3636
)
37+
from sagemaker.model_metrics import ModelMetrics
38+
from sagemaker.metadata_properties import MetadataProperties
39+
from sagemaker.drift_check_baselines import DriftCheckBaselines
3740
from sagemaker.jumpstart.enums import JumpStartScriptScope
3841
from sagemaker.jumpstart.types import (
3942
JumpStartModelDeployKwargs,
4043
JumpStartModelInitKwargs,
44+
JumpStartModelRegisterKwargs,
4145
)
4246
from sagemaker.jumpstart.utils import (
4347
update_dict_if_key_not_present,
4448
resolve_model_sagemaker_config_field,
49+
verify_model_region_and_return_specs,
4550
)
4651

4752
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
@@ -507,6 +512,87 @@ def get_deploy_kwargs(
507512
return deploy_kwargs
508513

509514

515+
def get_register_kwargs(
516+
model_id: str,
517+
model_version: Optional[str] = None,
518+
region: Optional[str] = None,
519+
tolerate_deprecated_model: Optional[bool] = None,
520+
tolerate_vulnerable_model: Optional[bool] = None,
521+
sagemaker_session: Optional[Any] = None,
522+
supported_content_types: List[str] = None,
523+
response_types: List[str] = None,
524+
inference_instances: Optional[List[str]] = None,
525+
transform_instances: Optional[List[str]] = None,
526+
model_package_group_name: Optional[str] = None,
527+
image_uri: Optional[str] = None,
528+
model_metrics: Optional[ModelMetrics] = None,
529+
metadata_properties: Optional[MetadataProperties] = None,
530+
approval_status: Optional[str] = None,
531+
description: Optional[str] = None,
532+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
533+
customer_metadata_properties: Optional[Dict[str, str]] = None,
534+
validation_specification: Optional[str] = None,
535+
domain: Optional[str] = None,
536+
task: Optional[str] = None,
537+
sample_payload_url: Optional[str] = None,
538+
framework: Optional[str] = None,
539+
framework_version: Optional[str] = None,
540+
nearest_model_name: Optional[str] = None,
541+
data_input_configuration: Optional[str] = None,
542+
skip_model_validation: Optional[str] = None,
543+
) -> JumpStartModelRegisterKwargs:
544+
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
545+
546+
register_kwargs = JumpStartModelRegisterKwargs(
547+
model_id=model_id,
548+
model_version=model_version,
549+
region=region,
550+
tolerate_deprecated_model=tolerate_deprecated_model,
551+
tolerate_vulnerable_model=tolerate_vulnerable_model,
552+
sagemaker_session=sagemaker_session,
553+
content_types=supported_content_types,
554+
response_types=response_types,
555+
inference_instances=inference_instances,
556+
transform_instances=transform_instances,
557+
model_package_group_name=model_package_group_name,
558+
image_uri=image_uri,
559+
model_metrics=model_metrics,
560+
metadata_properties=metadata_properties,
561+
approval_status=approval_status,
562+
description=description,
563+
drift_check_baselines=drift_check_baselines,
564+
customer_metadata_properties=customer_metadata_properties,
565+
validation_specification=validation_specification,
566+
domain=domain,
567+
task=task,
568+
sample_payload_url=sample_payload_url,
569+
framework=framework,
570+
framework_version=framework_version,
571+
nearest_model_name=nearest_model_name,
572+
data_input_configuration=data_input_configuration,
573+
skip_model_validation=skip_model_validation,
574+
)
575+
576+
model_specs = verify_model_region_and_return_specs(
577+
model_id=model_id,
578+
version=model_version,
579+
region=region,
580+
scope=JumpStartScriptScope.INFERENCE,
581+
sagemaker_session=sagemaker_session,
582+
tolerate_deprecated_model=tolerate_deprecated_model,
583+
tolerate_vulnerable_model=tolerate_vulnerable_model,
584+
)
585+
586+
register_kwargs.content_types = (
587+
register_kwargs.content_types or model_specs.predictor_specs.supported_content_types
588+
)
589+
register_kwargs.response_types = (
590+
register_kwargs.response_types or model_specs.predictor_specs.supported_accept_types
591+
)
592+
593+
return register_kwargs
594+
595+
510596
def get_init_kwargs(
511597
model_id: str,
512598
model_from_estimator: bool = False,

src/sagemaker/jumpstart/model.py

Lines changed: 148 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
"""This module stores JumpStart implementation of Model class."""
1414

1515
from __future__ import absolute_import
16-
import re
1716

1817
from typing import Dict, List, Optional, Union
1918
from sagemaker import payloads
@@ -28,16 +27,23 @@
2827
get_default_predictor,
2928
get_deploy_kwargs,
3029
get_init_kwargs,
30+
get_register_kwargs,
3131
)
3232
from sagemaker.jumpstart.types import JumpStartSerializablePayload
3333
from sagemaker.jumpstart.utils import is_valid_model_id
3434
from sagemaker.utils import stringify_object
35-
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
35+
from sagemaker.model import (
36+
Model,
37+
ModelPackage,
38+
)
3639
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3740
from sagemaker.predictor import PredictorBase
3841
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
3942
from sagemaker.session import Session
4043
from sagemaker.workflow.entities import PipelineVariable
44+
from sagemaker.model_metrics import ModelMetrics
45+
from sagemaker.metadata_properties import MetadataProperties
46+
from sagemaker.drift_check_baselines import DriftCheckBaselines
4147

4248

4349
class JumpStartModel(Model):
@@ -309,11 +315,12 @@ def _is_valid_model_id_hook():
309315
self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
310316
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
311317
self.region = model_init_kwargs.region
312-
self.model_package_arn = model_init_kwargs.model_package_arn
313318
self.sagemaker_session = model_init_kwargs.sagemaker_session
314319

315320
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
316321

322+
self.model_package_arn = model_init_kwargs.model_package_arn
323+
317324
def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
318325
"""Returns all example payloads associated with the model.
319326
@@ -390,30 +397,29 @@ def _create_sagemaker_model(
390397
# inference endpoint.
391398
if self.model_package_arn and not self._model_data_is_set:
392399
# When a ModelPackageArn is provided we just create the Model
393-
match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
394-
if match:
395-
model_package_name = match.group(3)
396-
else:
397-
# model_package_arn can be just the name if your account owns the Model Package
398-
model_package_name = self.model_package_arn
399-
container_def = {"ModelPackageName": self.model_package_arn}
400-
401-
if self.env != {}:
402-
container_def["Environment"] = self.env
403-
404-
if self.name is None:
405-
self._base_name = model_package_name
406-
407-
self._set_model_name_if_needed()
408-
409-
self.sagemaker_session.create_model(
410-
self.name,
411-
self.role,
412-
container_def,
400+
model_package = ModelPackage(
401+
role=self.role,
402+
model_data=self.model_data,
403+
model_package_arn=self.model_package_arn,
404+
sagemaker_session=self.sagemaker_session,
405+
predictor_cls=self.predictor_cls,
413406
vpc_config=self.vpc_config,
414-
enable_network_isolation=self.enable_network_isolation(),
407+
)
408+
if self.name is not None:
409+
model_package.name = self.name
410+
if self.env is not None:
411+
model_package.env = self.env
412+
model_package._create_sagemaker_model(
413+
instance_type=instance_type,
414+
accelerator_type=accelerator_type,
415415
tags=tags,
416+
serverless_inference_config=serverless_inference_config,
417+
**kwargs,
416418
)
419+
if self._base_name is None and model_package._base_name is not None:
420+
self._base_name = model_package._base_name
421+
if self.name is None and model_package.name is not None:
422+
self.name = model_package.name
417423
else:
418424
super(JumpStartModel, self)._create_sagemaker_model(
419425
instance_type=instance_type,
@@ -565,6 +571,124 @@ def deploy(
565571
# If a predictor class was passed, do not mutate predictor
566572
return predictor
567573

574+
def register(
575+
self,
576+
content_types: List[Union[str, PipelineVariable]] = None,
577+
response_types: List[Union[str, PipelineVariable]] = None,
578+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
579+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
580+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
581+
image_uri: Optional[Union[str, PipelineVariable]] = None,
582+
model_metrics: Optional[ModelMetrics] = None,
583+
metadata_properties: Optional[MetadataProperties] = None,
584+
approval_status: Optional[Union[str, PipelineVariable]] = None,
585+
description: Optional[str] = None,
586+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
587+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
588+
validation_specification: Optional[Union[str, PipelineVariable]] = None,
589+
domain: Optional[Union[str, PipelineVariable]] = None,
590+
task: Optional[Union[str, PipelineVariable]] = None,
591+
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
592+
framework: Optional[Union[str, PipelineVariable]] = None,
593+
framework_version: Optional[Union[str, PipelineVariable]] = None,
594+
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
595+
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
596+
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
597+
):
598+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
599+
600+
Args:
601+
content_types (list[str] or list[PipelineVariable]): The supported MIME types
602+
for the input data.
603+
response_types (list[str] or list[PipelineVariable]): The supported MIME types
604+
for the output data.
605+
inference_instances (list[str] or list[PipelineVariable]): A list of the instance
606+
types that are used to generate inferences in real-time (default: None).
607+
transform_instances (list[str] or list[PipelineVariable]): A list of the instance types
608+
on which a transformation job can be run or on which an endpoint can be deployed
609+
(default: None).
610+
model_package_group_name (str or PipelineVariable): Model Package Group name,
611+
exclusive to `model_package_name`, using `model_package_group_name` makes the
612+
Model Package versioned. Defaults to ``None``.
613+
image_uri (str or PipelineVariable): Inference image URI for the container. Model class'
614+
self.image will be used if it is None. Defaults to ``None``.
615+
model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``.
616+
metadata_properties (MetadataProperties): MetadataProperties object.
617+
Defaults to ``None``.
618+
approval_status (str or PipelineVariable): Model Approval Status, values can be
619+
"Approved", "Rejected", or "PendingManualApproval". Defaults to
620+
``PendingManualApproval``.
621+
description (str): Model Package description. Defaults to ``None``.
622+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
623+
customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]):
624+
A dictionary of key-value paired metadata properties (default: None).
625+
domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION",
626+
"NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None).
627+
sample_payload_url (str or PipelineVariable): The S3 path where the sample payload
628+
is stored (default: None).
629+
task (str or PipelineVariable): Task values which are supported by Inference Recommender
630+
are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION",
631+
"IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
632+
framework (str or PipelineVariable): Machine learning framework of the model package
633+
container image (default: None).
634+
framework_version (str or PipelineVariable): Framework version of the Model Package
635+
Container Image (default: None).
636+
nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning
637+
benchmarked by Amazon SageMaker Inference Recommender (default: None).
638+
data_input_configuration (str or PipelineVariable): Input object for the model
639+
(default: None).
640+
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
641+
validation. Values can be "All" or "None" (default: None).
642+
643+
Returns:
644+
A `sagemaker.model.ModelPackage` instance.
645+
"""
646+
647+
register_kwargs = get_register_kwargs(
648+
model_id=self.model_id,
649+
model_version=self.model_version,
650+
region=self.region,
651+
tolerate_deprecated_model=self.tolerate_deprecated_model,
652+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
653+
sagemaker_session=self.sagemaker_session,
654+
supported_content_types=content_types,
655+
response_types=response_types,
656+
inference_instances=inference_instances,
657+
transform_instances=transform_instances,
658+
model_package_group_name=model_package_group_name,
659+
image_uri=image_uri,
660+
model_metrics=model_metrics,
661+
metadata_properties=metadata_properties,
662+
approval_status=approval_status,
663+
description=description,
664+
drift_check_baselines=drift_check_baselines,
665+
customer_metadata_properties=customer_metadata_properties,
666+
validation_specification=validation_specification,
667+
domain=domain,
668+
task=task,
669+
sample_payload_url=sample_payload_url,
670+
framework=framework,
671+
framework_version=framework_version,
672+
nearest_model_name=nearest_model_name,
673+
data_input_configuration=data_input_configuration,
674+
skip_model_validation=skip_model_validation,
675+
)
676+
677+
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
678+
679+
def register_deploy_wrapper(*args, **kwargs):
680+
if self.model_package_arn is not None:
681+
return self.deploy(*args, **kwargs)
682+
else:
683+
self.model_package_arn = model_package.model_package_arn
684+
predictor = self.deploy(*args, **kwargs)
685+
self.model_package_arn = None
686+
return predictor
687+
688+
model_package.deploy = register_deploy_wrapper
689+
690+
return model_package
691+
568692
def __str__(self) -> str:
569693
"""Overriding str(*) method to make more human-readable."""
570694
return stringify_object(self)

0 commit comments

Comments
 (0)